LoginSignup
13
2

More than 1 year has passed since last update.

Salesforceがオープンソースで公開した機械学習ライブラリー「TransmogrifAI」でAutoMLしてみる

Posted at

はじめに

現在、@ITでAutoMLのOSSに関する連載をしており、そこでAutoMLが実行できる様々なOSSをGoogle Colabで動作するノートブックとともに紹介しています。
logos.png
その中で紹介できなかったTransmogrifAIについて、この記事で簡単に紹介したいと思います。

TransmogrifAIとは

TransmogrifAI(日本語であれば「トランスモグリファイ」と読みます)は、米Salesforce.comがオープンソースで公開した機械学習自動化ライブラリーで、AutoMLが実行できます。TransmogrifAIの技術は、同社のAIプラットフォームである「Einstein」でも使われているとのことです。
logo_tr.png
TransmogrifAIの公式サイトでは、以下の4つの特徴を挙げています。これらについて意訳と補足をして説明します。

  • 自動化:特徴量エンジニアリングやモデルの選択などを自動化するAutoML機能が実装されている
  • モジュール性と再利用性:機械学習のワークフロー定義とデータ操作を厳密に分離することで、モジュール性とで再利用性の高いコードを書くことができる
  • コンパイル時の型安全性:Scalaで実装されており、Scalaで開発するため、開発中のコード補完や実行時エラーの減少など、コンパイル時の型安全性の多くの利点を享受できる
  • 透明性:機械学習モデルのブラックボックス問題に対応できるような情報(どうしてその答えに至ったかという過程を人間に理解させるための情報)を出力する

その他に、Apache Spark上で動作し、大量のデータを分散して高速に処理できることも大きな特徴です。

この記事では、TransmogrifAIでKaggleの初心者向けコンペとして利用されている「タイタニックの生存予測」を行いたいと思います。

インストール前の準備

TransmogrifAIのインストールには、以下が必要になります。

  • Java
  • Scala
  • Spark

この3つが、TransmogrifAIのREADME.mdに記載されているバージョンであるかを確認して下さい。Ubuntu 20.04にこれらがインストールされていないことを想定して、これから手順を解説していきます。

Javaのインストール

まずはJavaをインストールし、バージョンを確認します。

$ sudo apt install openjdk-8-jdk
$ java -version
openjdk version "1.8.0_292"
OpenJDK Runtime Environment (build 1.8.0_292-8u292-b10-0ubuntu1~20.04-b10)
OpenJDK 64-Bit Server VM (build 25.292-b10, mixed mode)

Scalaのインストール

次にScalaをインストールし、バージョンを確認します。

$ sudo apt install scala
$ scala -version
Scala code runner version 2.11.12 -- Copyright 2002-2017, LAMP/EPFL

Sparkのインストール

最後にSparkをインストールし、インストールディレクトリーを環境変数SPARK_HOMEにセットします。

$ wget https://archive.apache.org/dist/spark/spark-2.4.8/spark-2.4.8-bin-hadoop2.7.tgz
$ tar -zxvf spark-2.4.8-bin-hadoop2.7.tgz
$ export SPARK_HOME=/home/tamura/spark-2.4.8-bin-hadoop2.7

TransmogrifAIのインストール

まずは、以下のコマンドでTransmogrifAIをGitHubからクローンしてビルドします。

$ git clone https://github.com/salesforce/TransmogrifAI.git
$ cd TransmogrifAI/
$ ./gradlew compileTestScala installDist

タイタニックの生存予測の実行

ではタイタニックの生存予測の実行してみましょう。TransmogrifAIには、タイタニックの生存予測を行うためのサンプルが含まれており、以下のコマンドで実行できます。

$ cd helloworld/
$ ./gradlew -q sparkSubmit -Dmain=com.salesforce.hw.OpTitanicSimple -Dargs="\
`pwd`/src/main/resources/TitanicDataset/TitanicPassengersTrainData.csv"

が、その前にcom.salesforce.hw.OpTitanicSimpleクラスの実装を見てみましょう(※抜粋です)。

このクラスにはmain()メソッドがあり、最初に引数のチェックなどをしています。

/**
 * A simplified TransmogrifAI example classification app using the Titanic dataset
 */
object OpTitanicSimple {

  /**
   * Run this from the command line with
   * ./gradlew sparkSubmit -Dmain=com.salesforce.hw.OpTitanicSimple -Dargs=/full/path/to/csv/file
   */
  def main(args: Array[String]): Unit = {
    if (args.isEmpty) {
      println("You need to pass in the CSV file path as an argument")
      sys.exit(1)
    }

次に、データの型の指定をしています。

    // Define features using the OP types based on the data
    val survived = FeatureBuilder.RealNN[Passenger].extract(_.survived.toRealNN).asResponse
    val pClass = FeatureBuilder.PickList[Passenger].extract(_.pClass.map(_.toString).toPickList).asPredictor
    val name = FeatureBuilder.Text[Passenger].extract(_.name.toText).asPredictor
    val sex = FeatureBuilder.PickList[Passenger].extract(_.sex.map(_.toString).toPickList).asPredictor
    val age = FeatureBuilder.Real[Passenger].extract(_.age.toReal).asPredictor
    val sibSp = FeatureBuilder.Integral[Passenger].extract(_.sibSp.toIntegral).asPredictor
    val parCh = FeatureBuilder.Integral[Passenger].extract(_.parCh.toIntegral).asPredictor
    val ticket = FeatureBuilder.PickList[Passenger].extract(_.ticket.map(_.toString).toPickList).asPredictor
    val fare = FeatureBuilder.Real[Passenger].extract(_.fare.toReal).asPredictor
    val cabin = FeatureBuilder.PickList[Passenger].extract(_.cabin.map(_.toString).toPickList).asPredictor
    val embarked = FeatureBuilder.PickList[Passenger].extract(_.embarked.map(_.toString).toPickList).asPredictor

そして、簡単な特徴量エンジニアリングをしています。SibSpparchを加算した数(家族の人数)を新たな列として定義したり、18歳以下はchild、超えていればadultでグループ化したりしています。

    // Do some basic feature engineering using knowledge of the underlying dataset
    val familySize = sibSp + parCh + 1
    val estimatedCostOfTickets = familySize * fare
    val pivotedSex = sex.pivot()
    val normedAge = age.fillMissingWithMean().zNormalize()
    val ageGroup = age.map[PickList](_.value.map(v => if (v > 18) "adult" else "child").toPickList)

さらに使用する特徴量をベクトル化し、

    // Define a feature of type vector containing all the predictors you'd like to use
    val passengerFeatures = Seq(
      pClass, name, age, sibSp, parCh, ticket,
      cabin, embarked, familySize, estimatedCostOfTickets,
      pivotedSex, ageGroup, normedAge
    ).transmogrify()

目的変数の設定や入力データのロードなどをしています。OpLogisticRegressionを指定しているので、ロジスティック回帰のモデルのみで学習します。

    // Define the model we want to use (here a simple logistic regression) and get the resulting output
    val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
      modelTypesToUse = Seq(OpLogisticRegression)
    ).setInput(survived, checkedFeatures).getOutput()

    val evaluator = Evaluators.BinaryClassification().setLabelCol(survived).setPredictionCol(prediction)

    // Define a way to read data into our Passenger class from our CSV file
    val dataReader = DataReaders.Simple.csvCase[Passenger](path = Option(csvFilePath), key = _.id.toString)

    // Define a new workflow and attach our data reader
    val workflow = new OpWorkflow().setResultFeatures(survived, prediction).setReader(dataReader)

ここまで完了したら、train()で学習します。

    // Fit the workflow to the data
    val model = workflow.train()
    println(s"Model summary:\n${model.summaryPretty()}")

最後に、スコアやメトリクス情報などを出力します。

    // Manifest the result features of the workflow
    println("Scoring the model")
    val (scores, metrics) = model.scoreAndEvaluate(evaluator = evaluator)
    println("Metrics:\n" + metrics)

    // Stop Spark gracefully
    spark.stop()

正常に実行できれば、学習に使用したモデルのサマリーや評価指標などが出力されます。内容を少しずつ見ていきましょう。

まずは選択されたモデルに関する表が出力されています。

Model summary:
Evaluated OpLogisticRegression model using Train Validation Split and area under precision-recall metric.
Evaluated 8 OpLogisticRegression models with area under precision-recall metric between [0.603802779975893, 0.8164752136817999].
+--------------------------------------------------------+
|         Selected Model - OpLogisticRegression          |
+--------------------------------------------------------+
| Model Param      | Value                               |
+------------------+-------------------------------------+
| aggregationDepth | 2                                   |
| elasticNetParam  | 0.5                                 |
| family           | auto                                |
| fitIntercept     | true                                |
| maxIter          | 50                                  |
| modelType        | OpLogisticRegression                |
| name             | OpLogisticRegression_00000000001d_7 |
| regParam         | 0.2                                 |
| standardization  | true                                |
| tol              | 1.0E-6                              |
| uid              | OpLogisticRegression_00000000001d   |
+------------------+-------------------------------------+

8パターンのロジスティック回帰のモデルで学習したことが分かります。

次は評価指標に関する表です。

+------------------------------------------------------------------------+
|                        Model Evaluation Metrics                        |
+------------------------------------------------------------------------+
| Metric Name                 | Training Set Value | Hold Out Set Value  |
+-----------------------------+--------------------+---------------------+
| area under ROC              | 0.8464956967048288 | 0.8340909090909091  |
| area under precision-recall | 0.823522409190228  | 0.803515359057996   |
| error                       | 0.2135678391959799 | 0.21052631578947367 |
| f1                          | 0.7098976109215017 | 0.7142857142857143  |
| false negative              | 94.0               | 15.0                |
| false positive              | 76.0               | 5.0                 |
| precision                   | 0.7323943661971831 | 0.8333333333333334  |
| recall                      | 0.6887417218543046 | 0.625               |
| true negative               | 418.0              | 50.0                |
| true positive               | 208.0              | 25.0                |
+-----------------------------+--------------------+---------------------+

「error」が0.21・・・となっているので、正解率(Accuracy)は0.79程度のようです。精度は良さそうですが、正の相関と負の相関の上位を見ると(以下)、nameが入っているのが気になります。

+--------------------------------------------------------------------------------+
|                               Top Model Insights                               |
+--------------------------------------------------------------------------------+
| Top Positive Correlations                              |     Correlation Value |
+--------------------------------------------------------+-----------------------+
| name                                                   |   0.33799375198478787 |
| cabin(cabin = other)                                   |   0.31705405559197924 |
| pClass(pClass = 1)                                     |   0.31243200321983444 |
| embarked(embarked = C)                                 |    0.1763879191607089 |
| parCh                                                  |    0.1566032653405966 |
| fare                                                   |    0.1566032653405966 |
| sibSp                                                  |    0.1566032653405966 |
| age(age_1-stagesApplied_PickList_000000000012 = Child) |   0.10778346927270353 |
| pClass(pClass = 2)                                     |   0.06776248741659234 |
| embarked(embarked = null)                              |   0.06418961185845397 |
| embarked(embarked = Q)                                 | 0.0027272053572141345 |
| age(age_1-stagesApplied_PickList_000000000012 = Adult) | -0.009590292261602037 |
| age                                                    |  -0.07981845335895164 |
| age(age_1-stagesApplied_PickList_000000000012 = null)  |   -0.0880038403991474 |
| embarked(embarked = S)                                 |  -0.16224431250352364 |
+--------------------------------------------------------+-----------------------+
+--------------------------------------------------+
| Top Negative Correlations |    Correlation Value |
+---------------------------+----------------------+
| sex(sex = Male)           |  -0.5418034858913471 |
| name                      |   -0.530656742043113 |
| pClass(pClass = 3)        | -0.32348838530021073 |
| sibSp                     | -0.03255844085788405 |
| parCh                     | 0.019575673849908347 |
+---------------------------+----------------------+

nameに含まれる「Mr.」や「Miss.」などを自動的に読み取って新たな特徴量を生成(特徴量エンジニアリング)していたのであれば、上位であってもおかしくは無いと思いますが、そうだったのでしょうか??うーん...:thinking:

また、予測に対して貢献度の高い特徴量の上位を見ると、上位5件以外はContribution Valueが0になっています。5つの特徴量しか予測に使っていないわけないはずですが...:thinking:謎です。この表の読み方が違うのかもしれませんが、ドキュメントに情報もなく...このあたりはソースコードを見ながら、デバッグをしてみなければ分かりませんが、今回は時間の都合上、そこまではできず...

+-------------------------------------------------------------------------------+
| Top Contributions                                      |   Contribution Value |
+--------------------------------------------------------+----------------------+
| sex(sex = Male)                                        |   0.3052581202448197 |
| name                                                   |    0.274746124920819 |
| pClass(pClass = 3)                                     |  0.07646940868483958 |
| pClass(pClass = 1)                                     | 0.055088076460696654 |
| cabin(cabin = other)                                   |  0.04900907064952525 |
| pClass(pClass = 2)                                     |                  0.0 |
| parCh                                                  |                  0.0 |
| embarked(embarked = null)                              |                  0.0 |
| embarked(embarked = Q)                                 |                  0.0 |
| embarked(embarked = C)                                 |                  0.0 |
| embarked(embarked = S)                                 |                  0.0 |
| age                                                    |                  0.0 |
| age(age_1-stagesApplied_PickList_000000000012 = null)  |                  0.0 |
| age(age_1-stagesApplied_PickList_000000000012 = Child) |                  0.0 |
| age(age_1-stagesApplied_PickList_000000000012 = Adult) |                  0.0 |
+--------------------------------------------------------+----------------------+
+-----------------------------------------------------------------+
| Top CramersV                              |            CramersV |
+-------------------------------------------+---------------------+
| sex                                       |  0.5418034858913474 |
| pClass                                    |  0.3524984173488759 |
| cabin                                     | 0.31705405559197924 |
| embarked                                  | 0.19183600439063905 |
| age_1-stagesApplied_PickList_000000000012 | 0.12671640510007356 |
+-------------------------------------------+---------------------+

最後にメモトリクス情報が出力されます。

Metrics:
{
  "Precision" : 0.7420382165605095,
  "Recall" : 0.6812865497076024,
  "F1" : 0.7103658536585367,
  "AuROC" : 0.8447975585594222,
  "AuPR" : 0.8203340930560883,
  "Error" : 0.2132435465768799,
  "TP" : 233.0,
  "TN" : 468.0,
  "FP" : 81.0,
  "FN" : 109.0,
  "thresholds" : [ 0.6160572103205795, 0.5880701840712025, 0.5851451683603401, 0.5565268974903586, 0.5473829316115263, 0.5183041191804046, 0.45909945558587884, 0.43025238088360374, 0.42729807940552766, 0.3989753086433127, 0.3901426428697533, 0.3627228446516393, 0.3291210862611915, 0.3038531934818624, 0.3013078137509486, 0.2772926616702837, 0.26994393490358304, 0.24754291497497533, 0.20145741533047348, 0.18151374224576056 ],
  "precisionByThreshold" : [ 0.9629629629629629, 0.9680851063829787, 0.9615384615384616, 0.9470588235294117, 0.9375, 0.7420382165605095, 0.739938080495356, 0.7317073170731707, 0.7341389728096677, 0.7217391304347827, 0.7225433526011561, 0.696236559139785, 0.6375545851528385, 0.6171548117154811, 0.6153846153846154, 0.5334507042253521, 0.5305410122164049, 0.38288288288288286, 0.3842696629213483, 0.3838383838383838 ],
  "recallByThreshold" : [ 0.22807017543859648, 0.26608187134502925, 0.29239766081871343, 0.47076023391812866, 0.4824561403508772, 0.6812865497076024, 0.6988304093567251, 0.7017543859649122, 0.7105263157894737, 0.7280701754385965, 0.7309941520467836, 0.7573099415204678, 0.8538011695906432, 0.8625730994152047, 0.8654970760233918, 0.8859649122807017, 0.8888888888888888, 0.9941520467836257, 1.0, 1.0 ],
  "falsePositiveRateByThreshold" : [ 0.00546448087431694, 0.00546448087431694, 0.007285974499089253, 0.01639344262295082, 0.020036429872495445, 0.14754098360655737, 0.15300546448087432, 0.16029143897996356, 0.16029143897996356, 0.17486338797814208, 0.17486338797814208, 0.2058287795992714, 0.302367941712204, 0.3333333333333333, 0.33697632058287796, 0.48269581056466304, 0.4899817850637523, 0.9981785063752276, 0.9981785063752276, 1.0 ]
}

複数の種類のモデルで学習する

今回使用したcom.salesforce.hw.OpTitanicSimpleクラスの実装では、ロジスティック回帰しか使っていないため、一般的に複数の種類のモデルを比較検証する機能がメインとなるAutoMLとは言い難いです。ということで、ロジスティック回帰以外も使用できるようにしてみましょう。

修正するのは以下の部分です。

    // Define the model we want to use (here a simple logistic regression) and get the resulting output
    val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
      modelTypesToUse = Seq(OpLogisticRegression)
    ).setInput(survived, checkedFeatures).getOutput()

この中の次の行を削除してみましょう。

      modelTypesToUse = Seq(OpLogisticRegression)

結果は以下のように変わります。

Model summary:
Evaluated OpLogisticRegression, OpRandomForestClassifier, OpGBTClassifier models using Train Validation Split and area under precision-recall metric.
Evaluated 8 OpLogisticRegression models with area under precision-recall metric between [0.593006491729722, 0.7463705553853126].
Evaluated 18 OpRandomForestClassifier models with area under precision-recall metric between [0.6511902997629747, 0.7729586931534125].
Evaluated 18 OpGBTClassifier models with area under precision-recall metric between [0.6925758416249905, 0.7762236639995906].
+--------------------------------------------------------+
|            Selected Model - OpGBTClassifier            |
+--------------------------------------------------------+
| Model Param           | Value                          |
+-----------------------+--------------------------------+
| cacheNodeIds          | false                          |
| checkpointInterval    | 10                             |
| featureSubsetStrategy | all                            |
| impurity              | gini                           |
| lossType              | logistic                       |
| maxBins               | 32                             |
| maxDepth              | 3                              |
| maxIter               | 20                             |
| maxMemoryInMB         | 256                            |
| minInfoGain           | 0.01                           |
| minInstancesPerNode   | 10                             |
| modelType             | OpGBTClassifier                |
| name                  | OpGBTClassifier_00000000001f_3 |
| seed                  | 1918769265                     |
| stepSize              | 0.1                            |
| subsamplingRate       | 1.0                            |
| uid                   | OpGBTClassifier_00000000001f   |
| validationTol         | 0.01                           |
+-----------------------+--------------------------------+

ランダムフォレストとGBTのモデルも学習に使用されるようになり、ここでは最終的にGBTのモデルが1つが選択されました。ちなみに、この結果は以下のようにSeq()の引数にOpRandomForestClassifierOpGBTClassifierを追加した場合と同じです。

    // Define the model we want to use (here a simple logistic regression) and get the resulting output
    val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
      modelTypesToUse = Seq(OpLogisticRegression, OpRandomForestClassifier, OpGBTClassifier)
    ).setInput(survived, checkedFeatures).getOutput()

さらに決定木やXGBoostなども追加することができます。

    // Define the model we want to use (here a simple logistic regression) and get the resulting output
    val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
      modelTypesToUse = Seq(OpLogisticRegression, OpRandomForestClassifier, OpGBTClassifier, 
        OpDecisionTreeClassifier, OpXGBoostClassifier)
    ).setInput(survived, checkedFeatures).getOutput()

TransmogrifAIを使用してみた感想

実装がScalaでGradleやMavenでビルドや実行を行うため、これらについての知識が少ないと扱うのは難しいように思いました。AutoMLが実行できる他のOSSと比較して、書かなければならないコードの量も多く、癖が強いので、Pythonでコーディングすることが多い機械学習エンジニアにとっては敷居が高いだろうなぁ、というのが正直な感想です。

使いこなせるようになってくると、Spark/Hadoopとの連携しやすさなどのメリットを受けられるのかもしれませんが、多くの人にとってはそこにたどり着くまでの学習コストが高いのではないかと考えます。もちろんScalaが得意な人が機械学習を始めるのであれば、いい選択肢になるとは思います。

参考

13
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
2