2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

mlr3を使ってみる(その1)

Last updated at Posted at 2020-09-09

1.はじめに

mlrとはRにおける機械学習のフレームワークのことで、mlrの他に'tidymodels'などのフレームワークがあります。

【R】tidymodelsとworkflowを中心に〜機械学習のフレームワーク

そのmlrは2019年からメンテナンスモードになっており、著者らは変わりにmlr3を推奨しているところですが、mlr3については、あまり日本語による情報はないようなので紹介してみます。
 リニューアルしました!
更新:2024年11月現在

mlr3_ecosystem.png
図1: mlr3のエコシステムの概要:グレーの破線のパッケージはまだ開発中であり、それ以外は安定したインターフェースを持っています。

mlrとの最大の違いはmlr3は完全なオブジェクト指向になったことです。
mlr3はmlrをエコシステム、R6やdata.tableに対応した統合的なパッケージとなりました。

※R6とは完全にオブジェクト指向に対応するためにプログラミング言語のひとつ
CRANで’R6'パッケージで提供されている

最初にmlr3で使われているオブジェクト指向R6とデーターベース:data.tableについて簡単に紹介します。

1.1 R6の紹介

R6は、オブジェクト指向プログラミングのためのRのより新しいパラダイムの一つです。オブジェクト指向プログラミングの経験があれば(例えばpythonのクラスとか)、R6を身近に感じることができる
ここでは簡単に紹介します。

library(R6)
# R6クラスの定義をする。nameとageのフィールドを持ち、greetのメソッドを持つ
Person = R6Class("Person", 
                 public = list( name = NULL,
                                age = NULL,
                                initialize = function(name, age) 
                                  { self$name = name 
                                    self$age = age }, 
                                greet = function()
                                  { cat("Hello, my name is",
                                        self$name, "and I am", 
                                        self$age, "years old.\n") } )
# インスタンスの作成 
# $new()という初期化メソッドを使ってR6Classのインスタンスを作成する
john = Person$new(name = "John", age = 30)  
#メッソッドの実行
#メソッドによって、ユーザはオブジェクトの状態を調べたり、情報を取得したり、オブジェクトの内部状態を変更するアクションを実行したりすることができる
john$greet() 

#> Hello, my name is John and I am 30 years old.
# R6オブジェクトは環境であり、参照セマンティクスを持つ
# オブジェクトをコピーするには、$clone(deep = TRUE)メソッドを使う
Ann = john$clone()
Ann$name = "Ann"   # 名前は変更したが、ageのフィールドは元(30)のまま保持されている
Ann$greet()

#> Hello, my name is Ann and I am 30 years old.

1.2 data.table の紹介

data.tableパッケージはdata.table()を実装しており、Rのdata.frame()に代わるものとして人気がでてきています。data.tableは非常に高速で、大きなデータにも対応できます。

library(data.table)
# using data.table
dt = data.table(x = 1:6, y = rep(letters[1:3], each = 2))
dt


#>       x      y
#>   <int> <char>
#>1:     1      a
#>2:     2      a
#>3:     3      b
#>4:     4      b
#>5:     5      c
#>6:     6      c

data.tablesはdata.frameと同じように使用できますが、複雑な操作を簡単にする追加機能があります。例えば、[ 演算子の by 引数でデータをグループごとにまとめたり、:= 演算子でインプレースで変更したりすることができます。

# yをグループ化してx列のデータの平均を求める
dt[, mean(x), by = "y"]

#>        y    V1
#>   <char> <num>
#>1:      a   1.5
#>2:      b   3.5
#>3:      c   5.5

2.タスクの作成

タスクは、機械学習問題を定義する(通常は表形式の)データと追加のメタデータを含むオブジェクトです。メタデータには、例えば教師あり機械学習問題のターゲット特徴の名前が含ます。R6オブジェクトで定義します。

データ:「良い」ローンと「悪い」ローンの分類問題を用います(tidymodelsパッケージから)

library(mlr3)
library(tidymodels)
data('credit_data')
head(credit_data)

image.png

分類タスクを作成するには、TaskClassifを用います。
backend フィールドにデータ、target フィールドに目的変数を指定します。

library(mlr3)
library(mlr3learners)
task_credit = TaskClassif$new(id = "credit", backend = credit_data, target = "Status")
task_credit

#> <TaskClassif:credit> (4454 x 14)
#> * Target: Status
#> * Properties: twoclass
#> * Features (13):
#>   - int (9): Age, Amount, Assets, Debt, Expenses, Income, Price,
#>     Seniority, Time
#>   - fct (4): Home, Job, Marital, Records

mlr3の良いところは、データの情報が簡単にコンパクトに表示されるところです。
例えば、欠測値を調べたいとすると・・

task_credit$missings()

#>   Status       Age    Amount    Assets      Debt  Expenses      Home 
#>        0         0         0        47        18         0         6 
#>   Income       Job   Marital     Price   Records Seniority      Time 
#>      381         2         1         0         0         0         0 

簡単にわかります。

さらに、ターゲットとの特徴量の関係を知りたい。
("Status", "Age","Amount" , "Assets" , "Debt", "Expenses" , "Home")に絞って実行します。

library(mlr3viz)
task_credit$select(task_credit$feature_names[1:6]) #selectメソッドで列選択
autoplot(task_credit,type = "pairs")

image.png

mlr3vizのautoplotを用いました。
ただし、気をつけないとR6オブジェクトは参照セマンティクスなので、メソッドを実行すると本体が変更されます

task_credit

#> <TaskClassif:credit> (4454 x 7)
#> * Target: Status
#> * Properties: twoclass
#> * Features (6):
#>   - int (5): Age, Amount, Assets, Debt, Expenses
#>   - fct (1): Home

本体を変更する意図がない場合は、クローンしたものを使いましょう。
(task_creditのデータを元に戻して以下、進みます)

3.Learnerの設定

mlr3ではLearner(学習器)に機械学習で用いるアルゴリズムを指定します。
ここで、フレームワークごとに設定方法は異なるのですが、例えば’rpart’パッケージを設定したいと思っても
どのようにkeyを設定したらよいかいちいち覚えていないのですが、
mlr3では辞書機能があるので、便利です

mlr_learners

#> <DictionaryLearner> with 27 stored values
#> Keys: classif.cv_glmnet, classif.debug, classif.featureless,
#>   classif.glmnet, classif.kknn, classif.lda, classif.log_reg,
#>   classif.multinom, classif.naive_bayes, classif.nnet, classif.qda,
#>   classif.ranger, classif.rpart, classif.svm, classif.xgboost,
#>   regr.cv_glmnet, regr.debug, regr.featureless, regr.glmnet,
#>   regr.kknn, regr.km, regr.lm, regr.nnet, regr.ranger, regr.rpart,
#>   regr.svm, regr.xgboost

classifが分類、regrが回帰モデルになります。
classif.rpartと書けばよいことがわかります。上記では27なので種類が少ないように思えますが、mlr3extralearnersを使うと、一気に152まで増えますので、ほぼ必要なものは揃っています。
※機械学習で有名なlightGBMやCatboostも扱えます

学習器の設定
sugar関数(短縮形)lrn( )でLearnerを設定します。(R6クラス)

lrn_rpart = lrn("classif.rpart")
lrn_rpart

#> <LearnerClassifRpart:classif.rpart>: Classification Tree
#> * Model: -
#> * Parameters: xval=0
#> * Packages: mlr3, rpart
#> * Predict Types:  [response], prob
#> * Feature Types: logical, integer, numeric, factor, ordered
#> * Properties: importance, missings, multiclass, selected_features,
#>   twoclass, weights

すべてのLearnerオブジェクトには以下のメタデータが含まれています:

  • feature_types:Learnerが扱える機能のタイプ。
  • package: Learnerを使用するためにインストールが必要なパッケージ。
  • properties 学習器のプロパティ。例えば、"missings "プロパティはモデルが欠損データを扱えることを意味し、"importance "は各特徴の相対的な重要度を計算できることを意味します。
  • predict_types:モデルが予測できるタイプ。[ ]が現在の状態。上記では response{0,1}になっている

使用可能な使用可能なハイパーパラメータの集合を表示する

rn_rpart$param_set

image.png

rpartは欠測値が扱えることを確認できました。
この機能も地味に便利。機械学習のアルゴリズムの詳細をいちいち覚えていないので。

4 データの分割と学習

データを訓練データとテストデータに分割します。

splits = partition(task_credit, ratio = 0.7)

訓練データを学習させます。
partitionは、行IDを分割しているだけなので※行IDを指定します。
※行番号ではない。pythonのpandasモジュールのindexに相当する。

lrn_rpart$train(task_credit, row_ids = splits$train)

5 テストデータと評価

テストデータを予測します。

pred = lrn_rpart$predict(task_credit, row_ids = splits$test)

ここで予測結果について、真実とどれだけ離れているかについて、評価する必要が有りますが、この評価方法(measure)も、設定方法のkeyをいちいち覚えられませんが、辞書機能で簡単に調べられます。

mlr_measures

#> <DictionaryMeasure> with 65 stored values
#> Keys: aic, bic, classif.acc, classif.auc, classif.bacc,
#>   classif.bbrier, classif.ce, classif.costs, classif.dor,
#>   classif.fbeta, classif.fdr, classif.fn, classif.fnr, classif.fomr,
#>   classif.fp, classif.fpr, classif.logloss, classif.mauc_au1p,
#>   classif.mauc_au1u, classif.mauc_aunp, classif.mauc_aunu,
#>   classif.mauc_mu, classif.mbrier, classif.mcc, classif.npv,
#>   classif.ppv, classif.prauc, classif.precision, classif.recall,
#>   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
#>   classif.tp, classif.tpr, debug_classif, internal_valid_score,
#>   oob_error, regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae,
#>   regr.medae, regr.medse, regr.mse, regr.msle, regr.pbias,
#>   regr.pinball, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
#>   regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
#>   selected_features, sim.jaccard, sim.phi, time_both, time_predict,
#>   time_train

上記の中から、精度(ACC)で評価します。

measure = msr("classif.acc")
measure

#> <MeasureClassifSimple:classif.acc>: Classification Accuracy
#> * Packages: mlr3, mlr3measures
#> * Range: [0, 1]
#> * Minimize: FALSE
#> * Average: macro
#> * Parameters: list()
#> * Properties: -
#> * Predict type: response

mlr3では評価指標についても簡単に概要を見ることができます。
この場合、predict_typeはresponseに必ず指定していないといけないことがわかります。
評価する場合はメソッドscore()を使います。

pred$score(measure)

#> classif.acc 
#>   0.7612275 

精度は0.76となりました。

6.参考

Applied Machine Learning Using mlr3 in R

7.enjoy

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?