Help us understand the problem. What is going on with this article?

Create MLを使ってSwiftで機械学習モデルを構築する

More than 1 year has passed since last update.

この記事はJX通信社 Advent Calendar 2018の13日目です。

WWDC 2018で発表されたCreate MLについて、今回は表形式データを使った住宅価格を予測する機械学習モデルの構築を例に紹介していきます。

Create ML

Create MLは機械学習モデルを構築するためのフレームワークです。
Xcode 10のPlaygroundで簡単にモデルを構築することができます。

ワークフロー

Create ML - Overview

学習データ(Data)

住宅価格を予測するモデルを構築するためにKaggleで公開されているデータを使ってトレーニングを行います。
Boston Housing(ボストン近郊の不動産情報)のtrain.csvを使って実際にコードを書いていきます。CSV内のデータ構成は次の通りです。medvの予測をするモデルを構築するのが今回の目的です。

カラム 説明
crim 一人当たりの犯罪率
zn 25,000平方フィート以上の住宅地の割合
indus 非小売業の土地面積の割合
chas チャールズ川沿いか否か
nox 窒素酸化物濃度
rm 平均部屋数
age 1940年以前に建てられた住宅の割合
dis 5つのボストン雇用センターまでの距離の加重平均
rad 高速道路へのアクセスのしやすさ
tax 10,000ドルあたりの固定資産税の税率
ptratio 生徒と教師の比率
black 黒人居住者の割合
lstat 低所得者の割合
medv 住宅価格の中央値(1000ドル単位)

トレーニング(Training)

データが用意できたら次はトレーニングです。
CSVファイルを読み込んだ後に、randomSplit(by:seed:)でトレーニングとテスト用のサブセットに分割します。CSVが元からトレーニングとテスト用に分かれている場合は不要です。
あとはアルゴリズムを指定してRun Playgroundでトレーニングを開始します。

import CreateML
import Foundation

do {
    let trainingCSV = URL(fileURLWithPath: "path/to/train.csv")
    let houseData = try MLDataTable(contentsOf: trainingCSV)
    let (trainingData, testData) = houseData.randomSplit(by: 0.8, seed: 0)

    let housePricer = try MLLinearRegressor(trainingData: trainingData, targetColumn: "medv")
} catch {
    print(error)
}

サポートされているアルゴリズム

回帰

https://developer.apple.com/documentation/createml/mlregressor

手法 Type
線形回帰 MLLinearRegressor
決定木 MLDecisionTreeRegressor
ランダムフォレスト MLRandomForestRegressor
ブースティング木 MLBoostedTreeRegressor

分類

https://developer.apple.com/documentation/createml/mlclassifier

手法 Type
決定木 MLDecisionTreeClassifier
ランダムフォレスト MLRandomForestClassifier
ブースティング木 MLBoostedTreeClassifier
ロジスティック回帰 MLLogisticRegressionClassifier
サポートベクターマシン MLSupportVectorClassifier

モデルの評価(Evaluation)

トレーニングしたモデルのパフォーマンスはmetricsで確認することができます。

import CreateML
import Foundation

do {
    let trainingCSV = URL(fileURLWithPath: "path/to/train.csv")
    let houseData = try MLDataTable(contentsOf: trainingCSV)
    let (trainingData, testData) = houseData.randomSplit(by: 0.8, seed: 0)

    let housePricer = try MLLinearRegressor(trainingData: trainingData, targetColumn: "medv")

    let metrics = housePricer.evaluation(on: testData)

    // Max error: 23.29
    // Root mean squared error: 5.67
    print(metrics)
} catch {
    print(error)
}

回帰アルゴリズムの場合、metricsには予測値とテストデータの値との誤差が記録されています。
最大誤差(Max Error)と二乗平均平方根誤差(Root-Mean-Square Error)でモデルのパフォーマンスを測定し、その誤差を最小化することがトレーニングの目的です。

メトリクスについては、Apple DeveloperサイトのDocumentationページよりも、AppleがCreate MLと同じく機械学習ツールとして提供しているTuri Createドキュメントの説明がより分かりやすいです。

モデルの書き出しとアプリへの組み込み

学習が完了したらモデルを書き出してアプリに組み込みましょう。
パスを指定してwrite(to:)またはwrite(toFile:)で書き出します。

import CreateML
import Foundation

do {
    let trainingCSV = URL(fileURLWithPath: "path/to/train.csv")
    let houseData = try MLDataTable(contentsOf: trainingCSV)
    let (trainingData, testData) = houseData.randomSplit(by: 0.8, seed: 0)

    let housePricer = try MLLinearRegressor(trainingData: trainingData, targetColumn: "medv")

    // ...

    try housePricer.write(to: URL(fileURLWithPath: OUTPUT_PATH))
} catch {
    print(error)
}

書き出したモデルをアプリのプロジェクトに追加したら次のように呼び出します。
プロジェクトに追加した時に、モデルクラスとInput/Output用の構造体が自動生成されています。

import CoreML

do {
    let input = HousePricerInput(
        ID: 3,
        crim: 0.02729,
        zn: 0,
        indus: 7.07,
        chas: 0,
        nox: 0.469,
        rm: 7.185,
        age: 61.1,
        dis: 4.9671,
        rad: 2,
        tax: 242,
        ptratio: 17.8,
        black: 392.83,
        lstat: 4.03)

    let housePricer = try HousePricer(contentsOf: HousePricer.urlOfModelInThisBundle)
    let output = try housePricer.prediction(input: input)
    print(output.medv)
} catch {
    print(error)
}

まとめ

Create MLについて表形式データを使ったサンプルを紹介しました。
今はまだ情報が少なく、Create MLを使った画像分類の紹介記事はいくつかあるのですが、テキストと表形式データを使った記事が少ないと感じていたため今回の記事を書くことにしました。
機械学習に興味はあるものの、なかなか手が付けられずにいたアプリエンジニアの方々の学習のきっかけになれば幸いです。

WWDC 2017でのCore MLの発表、そして2018でのCreate MLとAppleが力を入れている分野なので、このタイミングで機械学習に触れておけば次のWWDCはより楽しめるものになるはずです。

rychhr
jxpress
技術力で「ニュースの産業革命」を起こす。言語処理・データ解析分野の専門家が集まる、News Techベンチャー。
https://jxpress.net/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away