はじめに
現在、@ITでAutoMLのOSSに関する連載をしており、そこでAutoMLが実行できる様々なOSSをGoogle Colabで動作するノートブックとともに紹介しています。
その中で紹介できなかったTransmogrifAIについて、この記事で簡単に紹介したいと思います。
TransmogrifAIとは
TransmogrifAI(日本語であれば「トランスモグリファイ」と読みます)は、米Salesforce.comがオープンソースで公開した機械学習自動化ライブラリーで、AutoMLが実行できます。TransmogrifAIの技術は、同社のAIプラットフォームである「Einstein」でも使われているとのことです。
TransmogrifAIの公式サイトでは、以下の4つの特徴を挙げています。これらについて意訳と補足をして説明します。
- 自動化:特徴量エンジニアリングやモデルの選択などを自動化するAutoML機能が実装されている
- モジュール性と再利用性:機械学習のワークフロー定義とデータ操作を厳密に分離することで、モジュール性とで再利用性の高いコードを書くことができる
- コンパイル時の型安全性:Scalaで実装されており、Scalaで開発するため、開発中のコード補完や実行時エラーの減少など、コンパイル時の型安全性の多くの利点を享受できる
- 透明性:機械学習モデルのブラックボックス問題に対応できるような情報(どうしてその答えに至ったかという過程を人間に理解させるための情報)を出力する
その他に、Apache Spark上で動作し、大量のデータを分散して高速に処理できることも大きな特徴です。
この記事では、TransmogrifAIでKaggleの初心者向けコンペとして利用されている「タイタニックの生存予測」を行いたいと思います。
インストール前の準備
TransmogrifAIのインストールには、以下が必要になります。
- Java
- Scala
- Spark
この3つが、TransmogrifAIのREADME.mdに記載されているバージョンであるかを確認して下さい。Ubuntu 20.04にこれらがインストールされていないことを想定して、これから手順を解説していきます。
Javaのインストール
まずはJavaをインストールし、バージョンを確認します。
$ sudo apt install openjdk-8-jdk
$ java -version
openjdk version "1.8.0_292"
OpenJDK Runtime Environment (build 1.8.0_292-8u292-b10-0ubuntu1~20.04-b10)
OpenJDK 64-Bit Server VM (build 25.292-b10, mixed mode)
Scalaのインストール
次にScalaをインストールし、バージョンを確認します。
$ sudo apt install scala
$ scala -version
Scala code runner version 2.11.12 -- Copyright 2002-2017, LAMP/EPFL
Sparkのインストール
最後にSparkをインストールし、インストールディレクトリーを環境変数SPARK_HOME
にセットします。
$ wget https://archive.apache.org/dist/spark/spark-2.4.8/spark-2.4.8-bin-hadoop2.7.tgz
$ tar -zxvf spark-2.4.8-bin-hadoop2.7.tgz
$ export SPARK_HOME=/home/tamura/spark-2.4.8-bin-hadoop2.7
TransmogrifAIのインストール
まずは、以下のコマンドでTransmogrifAIをGitHubからクローンしてビルドします。
$ git clone https://github.com/salesforce/TransmogrifAI.git
$ cd TransmogrifAI/
$ ./gradlew compileTestScala installDist
タイタニックの生存予測の実行
ではタイタニックの生存予測の実行してみましょう。TransmogrifAIには、タイタニックの生存予測を行うためのサンプルが含まれており、以下のコマンドで実行できます。
$ cd helloworld/
$ ./gradlew -q sparkSubmit -Dmain=com.salesforce.hw.OpTitanicSimple -Dargs="\
`pwd`/src/main/resources/TitanicDataset/TitanicPassengersTrainData.csv"
が、その前にcom.salesforce.hw.OpTitanicSimple
クラスの実装を見てみましょう(※抜粋です)。
このクラスにはmain()
メソッドがあり、最初に引数のチェックなどをしています。
/**
* A simplified TransmogrifAI example classification app using the Titanic dataset
*/
object OpTitanicSimple {
/**
* Run this from the command line with
* ./gradlew sparkSubmit -Dmain=com.salesforce.hw.OpTitanicSimple -Dargs=/full/path/to/csv/file
*/
def main(args: Array[String]): Unit = {
if (args.isEmpty) {
println("You need to pass in the CSV file path as an argument")
sys.exit(1)
}
次に、データの型の指定をしています。
// Define features using the OP types based on the data
val survived = FeatureBuilder.RealNN[Passenger].extract(_.survived.toRealNN).asResponse
val pClass = FeatureBuilder.PickList[Passenger].extract(_.pClass.map(_.toString).toPickList).asPredictor
val name = FeatureBuilder.Text[Passenger].extract(_.name.toText).asPredictor
val sex = FeatureBuilder.PickList[Passenger].extract(_.sex.map(_.toString).toPickList).asPredictor
val age = FeatureBuilder.Real[Passenger].extract(_.age.toReal).asPredictor
val sibSp = FeatureBuilder.Integral[Passenger].extract(_.sibSp.toIntegral).asPredictor
val parCh = FeatureBuilder.Integral[Passenger].extract(_.parCh.toIntegral).asPredictor
val ticket = FeatureBuilder.PickList[Passenger].extract(_.ticket.map(_.toString).toPickList).asPredictor
val fare = FeatureBuilder.Real[Passenger].extract(_.fare.toReal).asPredictor
val cabin = FeatureBuilder.PickList[Passenger].extract(_.cabin.map(_.toString).toPickList).asPredictor
val embarked = FeatureBuilder.PickList[Passenger].extract(_.embarked.map(_.toString).toPickList).asPredictor
そして、簡単な特徴量エンジニアリングをしています。SibSp
とparch
を加算した数(家族の人数)を新たな列として定義したり、18歳以下はchild
、超えていればadult
でグループ化したりしています。
// Do some basic feature engineering using knowledge of the underlying dataset
val familySize = sibSp + parCh + 1
val estimatedCostOfTickets = familySize * fare
val pivotedSex = sex.pivot()
val normedAge = age.fillMissingWithMean().zNormalize()
val ageGroup = age.map[PickList](_.value.map(v => if (v > 18) "adult" else "child").toPickList)
さらに使用する特徴量をベクトル化し、
// Define a feature of type vector containing all the predictors you'd like to use
val passengerFeatures = Seq(
pClass, name, age, sibSp, parCh, ticket,
cabin, embarked, familySize, estimatedCostOfTickets,
pivotedSex, ageGroup, normedAge
).transmogrify()
目的変数の設定や入力データのロードなどをしています。OpLogisticRegression
を指定しているので、ロジスティック回帰のモデルのみで学習します。
// Define the model we want to use (here a simple logistic regression) and get the resulting output
val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
modelTypesToUse = Seq(OpLogisticRegression)
).setInput(survived, checkedFeatures).getOutput()
val evaluator = Evaluators.BinaryClassification().setLabelCol(survived).setPredictionCol(prediction)
// Define a way to read data into our Passenger class from our CSV file
val dataReader = DataReaders.Simple.csvCase[Passenger](path = Option(csvFilePath), key = _.id.toString)
// Define a new workflow and attach our data reader
val workflow = new OpWorkflow().setResultFeatures(survived, prediction).setReader(dataReader)
ここまで完了したら、train()
で学習します。
// Fit the workflow to the data
val model = workflow.train()
println(s"Model summary:\n${model.summaryPretty()}")
最後に、スコアやメトリクス情報などを出力します。
// Manifest the result features of the workflow
println("Scoring the model")
val (scores, metrics) = model.scoreAndEvaluate(evaluator = evaluator)
println("Metrics:\n" + metrics)
// Stop Spark gracefully
spark.stop()
正常に実行できれば、学習に使用したモデルのサマリーや評価指標などが出力されます。内容を少しずつ見ていきましょう。
まずは選択されたモデルに関する表が出力されています。
Model summary:
Evaluated OpLogisticRegression model using Train Validation Split and area under precision-recall metric.
Evaluated 8 OpLogisticRegression models with area under precision-recall metric between [0.603802779975893, 0.8164752136817999].
+--------------------------------------------------------+
| Selected Model - OpLogisticRegression |
+--------------------------------------------------------+
| Model Param | Value |
+------------------+-------------------------------------+
| aggregationDepth | 2 |
| elasticNetParam | 0.5 |
| family | auto |
| fitIntercept | true |
| maxIter | 50 |
| modelType | OpLogisticRegression |
| name | OpLogisticRegression_00000000001d_7 |
| regParam | 0.2 |
| standardization | true |
| tol | 1.0E-6 |
| uid | OpLogisticRegression_00000000001d |
+------------------+-------------------------------------+
8パターンのロジスティック回帰のモデルで学習したことが分かります。
次は評価指標に関する表です。
+------------------------------------------------------------------------+
| Model Evaluation Metrics |
+------------------------------------------------------------------------+
| Metric Name | Training Set Value | Hold Out Set Value |
+-----------------------------+--------------------+---------------------+
| area under ROC | 0.8464956967048288 | 0.8340909090909091 |
| area under precision-recall | 0.823522409190228 | 0.803515359057996 |
| error | 0.2135678391959799 | 0.21052631578947367 |
| f1 | 0.7098976109215017 | 0.7142857142857143 |
| false negative | 94.0 | 15.0 |
| false positive | 76.0 | 5.0 |
| precision | 0.7323943661971831 | 0.8333333333333334 |
| recall | 0.6887417218543046 | 0.625 |
| true negative | 418.0 | 50.0 |
| true positive | 208.0 | 25.0 |
+-----------------------------+--------------------+---------------------+
「error」が0.21・・・となっているので、正解率(Accuracy)は0.79程度のようです。精度は良さそうですが、正の相関と負の相関の上位を見ると(以下)、name
が入っているのが気になります。
+--------------------------------------------------------------------------------+
| Top Model Insights |
+--------------------------------------------------------------------------------+
| Top Positive Correlations | Correlation Value |
+--------------------------------------------------------+-----------------------+
| name | 0.33799375198478787 |
| cabin(cabin = other) | 0.31705405559197924 |
| pClass(pClass = 1) | 0.31243200321983444 |
| embarked(embarked = C) | 0.1763879191607089 |
| parCh | 0.1566032653405966 |
| fare | 0.1566032653405966 |
| sibSp | 0.1566032653405966 |
| age(age_1-stagesApplied_PickList_000000000012 = Child) | 0.10778346927270353 |
| pClass(pClass = 2) | 0.06776248741659234 |
| embarked(embarked = null) | 0.06418961185845397 |
| embarked(embarked = Q) | 0.0027272053572141345 |
| age(age_1-stagesApplied_PickList_000000000012 = Adult) | -0.009590292261602037 |
| age | -0.07981845335895164 |
| age(age_1-stagesApplied_PickList_000000000012 = null) | -0.0880038403991474 |
| embarked(embarked = S) | -0.16224431250352364 |
+--------------------------------------------------------+-----------------------+
+--------------------------------------------------+
| Top Negative Correlations | Correlation Value |
+---------------------------+----------------------+
| sex(sex = Male) | -0.5418034858913471 |
| name | -0.530656742043113 |
| pClass(pClass = 3) | -0.32348838530021073 |
| sibSp | -0.03255844085788405 |
| parCh | 0.019575673849908347 |
+---------------------------+----------------------+
name
に含まれる「Mr.」や「Miss.」などを自動的に読み取って新たな特徴量を生成(特徴量エンジニアリング)していたのであれば、上位であってもおかしくは無いと思いますが、そうだったのでしょうか??うーん...。
また、予測に対して貢献度の高い特徴量の上位を見ると、上位5件以外はContribution Value
が0になっています。5つの特徴量しか予測に使っていないわけないはずですが...謎です。この表の読み方が違うのかもしれませんが、ドキュメントに情報もなく...このあたりはソースコードを見ながら、デバッグをしてみなければ分かりませんが、今回は時間の都合上、そこまではできず...
+-------------------------------------------------------------------------------+
| Top Contributions | Contribution Value |
+--------------------------------------------------------+----------------------+
| sex(sex = Male) | 0.3052581202448197 |
| name | 0.274746124920819 |
| pClass(pClass = 3) | 0.07646940868483958 |
| pClass(pClass = 1) | 0.055088076460696654 |
| cabin(cabin = other) | 0.04900907064952525 |
| pClass(pClass = 2) | 0.0 |
| parCh | 0.0 |
| embarked(embarked = null) | 0.0 |
| embarked(embarked = Q) | 0.0 |
| embarked(embarked = C) | 0.0 |
| embarked(embarked = S) | 0.0 |
| age | 0.0 |
| age(age_1-stagesApplied_PickList_000000000012 = null) | 0.0 |
| age(age_1-stagesApplied_PickList_000000000012 = Child) | 0.0 |
| age(age_1-stagesApplied_PickList_000000000012 = Adult) | 0.0 |
+--------------------------------------------------------+----------------------+
+-----------------------------------------------------------------+
| Top CramersV | CramersV |
+-------------------------------------------+---------------------+
| sex | 0.5418034858913474 |
| pClass | 0.3524984173488759 |
| cabin | 0.31705405559197924 |
| embarked | 0.19183600439063905 |
| age_1-stagesApplied_PickList_000000000012 | 0.12671640510007356 |
+-------------------------------------------+---------------------+
最後にメモトリクス情報が出力されます。
Metrics:
{
"Precision" : 0.7420382165605095,
"Recall" : 0.6812865497076024,
"F1" : 0.7103658536585367,
"AuROC" : 0.8447975585594222,
"AuPR" : 0.8203340930560883,
"Error" : 0.2132435465768799,
"TP" : 233.0,
"TN" : 468.0,
"FP" : 81.0,
"FN" : 109.0,
"thresholds" : [ 0.6160572103205795, 0.5880701840712025, 0.5851451683603401, 0.5565268974903586, 0.5473829316115263, 0.5183041191804046, 0.45909945558587884, 0.43025238088360374, 0.42729807940552766, 0.3989753086433127, 0.3901426428697533, 0.3627228446516393, 0.3291210862611915, 0.3038531934818624, 0.3013078137509486, 0.2772926616702837, 0.26994393490358304, 0.24754291497497533, 0.20145741533047348, 0.18151374224576056 ],
"precisionByThreshold" : [ 0.9629629629629629, 0.9680851063829787, 0.9615384615384616, 0.9470588235294117, 0.9375, 0.7420382165605095, 0.739938080495356, 0.7317073170731707, 0.7341389728096677, 0.7217391304347827, 0.7225433526011561, 0.696236559139785, 0.6375545851528385, 0.6171548117154811, 0.6153846153846154, 0.5334507042253521, 0.5305410122164049, 0.38288288288288286, 0.3842696629213483, 0.3838383838383838 ],
"recallByThreshold" : [ 0.22807017543859648, 0.26608187134502925, 0.29239766081871343, 0.47076023391812866, 0.4824561403508772, 0.6812865497076024, 0.6988304093567251, 0.7017543859649122, 0.7105263157894737, 0.7280701754385965, 0.7309941520467836, 0.7573099415204678, 0.8538011695906432, 0.8625730994152047, 0.8654970760233918, 0.8859649122807017, 0.8888888888888888, 0.9941520467836257, 1.0, 1.0 ],
"falsePositiveRateByThreshold" : [ 0.00546448087431694, 0.00546448087431694, 0.007285974499089253, 0.01639344262295082, 0.020036429872495445, 0.14754098360655737, 0.15300546448087432, 0.16029143897996356, 0.16029143897996356, 0.17486338797814208, 0.17486338797814208, 0.2058287795992714, 0.302367941712204, 0.3333333333333333, 0.33697632058287796, 0.48269581056466304, 0.4899817850637523, 0.9981785063752276, 0.9981785063752276, 1.0 ]
}
複数の種類のモデルで学習する
今回使用したcom.salesforce.hw.OpTitanicSimple
クラスの実装では、ロジスティック回帰しか使っていないため、一般的に複数の種類のモデルを比較検証する機能がメインとなるAutoMLとは言い難いです。ということで、ロジスティック回帰以外も使用できるようにしてみましょう。
修正するのは以下の部分です。
// Define the model we want to use (here a simple logistic regression) and get the resulting output
val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
modelTypesToUse = Seq(OpLogisticRegression)
).setInput(survived, checkedFeatures).getOutput()
この中の次の行を削除してみましょう。
modelTypesToUse = Seq(OpLogisticRegression)
結果は以下のように変わります。
Model summary:
Evaluated OpLogisticRegression, OpRandomForestClassifier, OpGBTClassifier models using Train Validation Split and area under precision-recall metric.
Evaluated 8 OpLogisticRegression models with area under precision-recall metric between [0.593006491729722, 0.7463705553853126].
Evaluated 18 OpRandomForestClassifier models with area under precision-recall metric between [0.6511902997629747, 0.7729586931534125].
Evaluated 18 OpGBTClassifier models with area under precision-recall metric between [0.6925758416249905, 0.7762236639995906].
+--------------------------------------------------------+
| Selected Model - OpGBTClassifier |
+--------------------------------------------------------+
| Model Param | Value |
+-----------------------+--------------------------------+
| cacheNodeIds | false |
| checkpointInterval | 10 |
| featureSubsetStrategy | all |
| impurity | gini |
| lossType | logistic |
| maxBins | 32 |
| maxDepth | 3 |
| maxIter | 20 |
| maxMemoryInMB | 256 |
| minInfoGain | 0.01 |
| minInstancesPerNode | 10 |
| modelType | OpGBTClassifier |
| name | OpGBTClassifier_00000000001f_3 |
| seed | 1918769265 |
| stepSize | 0.1 |
| subsamplingRate | 1.0 |
| uid | OpGBTClassifier_00000000001f |
| validationTol | 0.01 |
+-----------------------+--------------------------------+
ランダムフォレストとGBTのモデルも学習に使用されるようになり、ここでは最終的にGBTのモデルが1つが選択されました。ちなみに、この結果は以下のようにSeq()
の引数にOpRandomForestClassifier
とOpGBTClassifier
を追加した場合と同じです。
// Define the model we want to use (here a simple logistic regression) and get the resulting output
val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
modelTypesToUse = Seq(OpLogisticRegression, OpRandomForestClassifier, OpGBTClassifier)
).setInput(survived, checkedFeatures).getOutput()
さらに決定木やXGBoostなども追加することができます。
// Define the model we want to use (here a simple logistic regression) and get the resulting output
val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
modelTypesToUse = Seq(OpLogisticRegression, OpRandomForestClassifier, OpGBTClassifier,
OpDecisionTreeClassifier, OpXGBoostClassifier)
).setInput(survived, checkedFeatures).getOutput()
TransmogrifAIを使用してみた感想
実装がScalaでGradleやMavenでビルドや実行を行うため、これらについての知識が少ないと扱うのは難しいように思いました。AutoMLが実行できる他のOSSと比較して、書かなければならないコードの量も多く、癖が強いので、Pythonでコーディングすることが多い機械学習エンジニアにとっては敷居が高いだろうなぁ、というのが正直な感想です。
使いこなせるようになってくると、Spark/Hadoopとの連携しやすさなどのメリットを受けられるのかもしれませんが、多くの人にとってはそこにたどり着くまでの学習コストが高いのではないかと考えます。もちろんScalaが得意な人が機械学習を始めるのであれば、いい選択肢になるとは思います。