Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

複数のmlmodelを統一的に扱う

複数のmlmodelを統一的に扱う

はじめに

僕は機械学習系の研究室に所属している大学院生で、普段はPythonを用いてコードを書いています。また、Apple信者でもあるのでSwiftでアプリを作ることもあります。

SwiftにはCore ML1というApple製の機械学習ライブラリが存在します。Core MLは、mlmodelファイルをプロジェクトにドラッグ&ドロップするだけでモデルを使うためのコードが自動生成されるため、とても便利です。一方で、複数のモデルを扱うときに少し不便なことがあります。

機械学習系の研究をする際、複数の機械学習モデルを比較することが大いにあると思います。Pythonの場合は、

models = {
    "Random Forest": RandomForestClassifier(),
    "SVM": SVC(),
    "VGG16": ...
}

for model in models:
    model.fit()

    outputs = model.predict()
    ...

のようにして、辞書(もしくは配列)にしておくことでfor文で複数のモデルを実行することができます。

しかし、Swiftの場合、静的型付けであるため動的型付けであるPythonのように違うクラス・型のものを辞書や配列に入れて扱うのは難しいと思います。そこで今回は、複数のCore MLモデル(mlmodel)を比較したり切り替えたいというときに、コードを共通化する方法を記事にしたいと思います。

(今回は、Visionを用いたものではなくMLMultiArray2を入力とするCore MLモデルを想定しています)

環境

  • Xcode 12.2
  • macOS 11.0.1

mlmodelを統一的に扱う

プロトコルを作る

MyModel.mlmodelファイルをXcodeにドラッグ&ドロップすると自動的にSwiftのコードが生成され、MyModelInputMyModelOutputMyModelの3つのクラスができます。MyModelInputMyModelOutputMLFeatureProvider3というプロトコルに準拠していますが、MyModelMLModelを持っているだけで単体のクラスとして生成されます。

そこで、プロトコルを作成しMyModelを拡張して準拠させることで統一したルールで扱うことができるようにします。また、生成されたコードのpredictionの戻り値はMyModelOutputでありmlmodelごとに異なるため、predictionの戻り値も統一的に扱えるようにします。今回は戻り値をStringにしましたが、クラスラベルの列挙型を作成しても良いと思います。

protocol MLModelUnification {
    func prediction(input: MLMultiArray) throws -> String
    func predictions(inputs: [MLMultiArray]) throws -> [String]
}

モデルのクラスを拡張する

次に作成したプロトコルをモデルのクラスに準拠させるためにモデルのクラスを拡張します。

VGG16.mlmodelResNet50.mlmodelの2つのモデルがあった場合の例を示します。2つのモデルの出力にあたるVGG16OutputResNet50Outputは、

  • classLabel: String: 予測ラベル
  • Identity: [String : Double]: 各ラベルの予測確率

を持っているとします。

VGG16のextension
extension VGG16: MLModelUnification {
    func prediction(input: MLMultiArray) throws -> String {
        let output = try self.prediction(input: VGG16Input(input: input))
            return output.classLabel
    }

    func predictions(inputs: [MLMultiArray]) throws -> [String] {
        var results: [String] = []
        try inputs.forEach { (input) in
                let output = try self.prediction(input: VGG16Input(input: input))
                results.append(output.classLabel)
            }
            return results
    }
}
ResNet50のextension
extension ResNet50: MLModelUnification {
    func prediction(input: MLMultiArray) throws -> String {
        let output = try self.prediction(input: ResNet50Input(input: input))
            return output.classLabel
    }

    func predictions(inputs: [MLMultiArray]) throws -> [String] {
        var results: [String] = []
        try inputs.forEach { (input) in
                let output = try self.prediction(input: ResNet50Input(input: input))
                results.append(output.classLabel)
            }
            return results
    }
}

複数のモデルで共通のコードを使うことができる

このようにプロトコルを作成しそれぞれのモデルのクラスを準拠させることで、予測部分のコードを共通化することができます。

共通化の例(モデルを比較)
var models: [MLModelUnification] = [
            {
                do {
                    return try VGG16(configuration: config)
                } catch {
                    fatalError("Couldn't create VGG16")
                }
            }(),
            {
                do {
                    return try ResNet18(configuration: config)
                } catch {
                    fatalError("Couldn't create ResNet 18")
                }
            }()
]

for model in models {
    let outputs = try model.prediction(inputs: inputs)
    ...
}
共通化の例(モデルの切り替え)
var model: MLModelUnification

// モデルを列挙したenumでswitch文
switch selectedModel {
 case .vgg16:
    model = {
        do {
            return try VGG16(configuration: config)
        } catch {
            print(error)
            fatalError("Couldn't create VGG16")
        }
    }()
case .resnet18:
    model = {
        do {
            return try ResNet18(configuration: config)
        } catch {
            print(error)
            fatalError("Couldn't create ResNet 18")
        }
    }()
}

let output = try model.prediction(input: input)
...

おわりに

今回は、Core MLモデル(mlmodel)で生成される機械学習モデルのクラスを統一的に扱う方法についてまとめました。個人的に、Swiftの便利なところの1つはextensionだなぁと思いました。同じようなことをしようとしている記事を見かけなかったので、CoreML上では複数のモデルを切り替えたり比較することがあまりないかも知れませんが、もし同じようなことがしたい方がいれば参考にしてくれると嬉しいです。

M1 Mac mini欲しさに、「ちょっとした工夫で効率化!【PR】パソナテック Advent Calendar 2020」に向けて記事を書きました。初めてQiita記事を書いたため、拙い部分もあるかと思いますが、ここまで読んでくれてありがとうございました。

参考記事

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away