LoginSignup
8
6

More than 5 years have passed since last update.

ML kit カスタムモデルを使用してみた

Posted at

はじめに

業務では主にAndroid、iOSをやっています

今回、ML kitを触って見たかったので、馬体で画像分類モデルを作成して配布する事を目標にやって見ました。

Labelは下記の2種類です。

  • place:重賞で3着以内に入った馬の馬体
  • other:重賞で着外だった馬の馬体

重賞で3着以内に入る馬は、馬体の完成度が高く特徴がある、という前提で行きます。

スクレイピング

 seleniumで馬の画像収集

今回は、あるサイトで2004~2017までの重賞出走馬の画像がまとまっていたので、これを取得しました。

    implementation 'org.seleniumhq.selenium:htmlunit-driver:2.52.0'
    val driver = HtmlUnitDriver(false)

    (2017 downTo 200).forEach { year ->
        // xpathでゴニョゴニョ
               ....

        ImageIO.write(image, "jpg", File("./$year/$raceName/$horceName.jpg")
    }

取得した結果はこんな感じです。
スクリーンショット 2018-12-17 1.11.37.png

Label毎にフォルダに振り分けます。
手動でやったら思ったより時間がかかったので、とりあえず今回は2017年の画像のみでやりました。

batai
 - image
  - place
   重賞で3着以内に入った馬の馬体画像
  - other
    重賞で着外だった馬の馬体画像

TensorFlow

TensorFlowでModelを作成しconverterでTensorFlow Lite Modelに変換します。

tflite-architecture.jpg

TensorFlow Modelの作成

image_retrainingのtutorialを参考にやって行きます。
https://www.tensorflow.org/hub/tutorials/image_retraining

retrain.pyを使用してretrained_graph.pbとretrained_labels.txtを作成します。

curl -LO https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py
python retrain.py --image_dir $HOME/batai/image  \
--architecture mobilenet_1.0_224  \
--output_graph=$HOME/batai/retrained_graph.pb \
--output_labels=$HOME/batai/retrained_labels.txt \

確認したい場合はlabel_imagew.pyを使用します。

curl -LO https://github.com/tensorflow/tensorflow/raw/master/tensorflow/examples/label_image/label_image.py
python label_image.py \
  --graph=$HOME/batai/retrained_graph.pb  \
  --labels=$HOME/batai/retrained_labels.txt \
  --input_layer=Placeholder \
  --output_layer=final_result \
  --image=$HOME/batai/selenium/image/2015/有馬記念/ゴールドアクター.jpg
place 0.77323943
other 0.22676052

python label_image.py \
  --graph=$HOME/batai/retrained_graph.pb  \
  --labels=$HOME/batai/retrained_labels.txt \
  --input_layer=Placeholder \
  --output_layer=final_result \
  --image=$HOME/batai/selenium/image/2015/有馬記念/リアファル.jpg
other 0.58568174
place 0.41431832

※ゴールドアクター -> 2015年の有馬記念1着
※リアファル -> 2015年の有馬記念16着

TensorFlow Lite

cmdline_examplesを参考にretrained_graph.pbをTensorFlow Liteにconvertしていきます。

tflite_convertコマンドでretrained_graph.tfliteを作成します。

tflite_convert \
  --graph_def_file=$HOME/batai/retrained_graph.pb \
  --output_file=$HOME/batai/retrained_graph.tflite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --input_shape=1,299,299,3 \
  --input_array=Placeholder \
  --output_array=final_result \
  --input_data_type=FLOAT \
  --default_ranges_min=0  \
  --default_ranges_max=6  \
  --inference_type=QUANTIZED_UINT8  \
  --inference_input_type=QUANTIZED_UINT8  \
  --mean_values=128 \
  --std_dev_values=128 \

コンバート前と後で約4分の1くらいのファイルサイズになりました。

87498466 retrained_graph.pb
21894208 retrained_graph.tflite

Firebase ML Kit

ML kitのドキュメントを参考に進めていきます
https://firebase.google.com/docs/ml-kit/android/use-custom-models?hl=ja

Firebaseの設定

ML kit カスタムタブからカスタムモデルを追加します。

スクリーンショット 2018-12-18 4.20.20.png
スクリーンショット 2018-12-18 4.21.17.png

作成したretrained_graph.tfliteを公開します。
スクリーンショット 2018-12-18 4.23.08.png

アプリの設定

こちらのサンプルアプリを使用させていただきました。
https://github.com/googlecodelabs/tensorflow-for-poets-2/tree/master/android/tflite

ダウンロードconditionsの設定をします。
requireWifi、requireCharging、requireDeviceIdleと状況によって使い分けられそうです。

    FirebaseModelDownloadConditions.Builder conditionsBuilder =
            new FirebaseModelDownloadConditions.Builder().requireWifi();
    if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
      // Enable advanced conditions on Android Nougat and newer.
      conditionsBuilder = conditionsBuilder
              .requireCharging()
              .requireDeviceIdle();
    }
    FirebaseModelDownloadConditions conditions = conditionsBuilder.build();

// Build a FirebaseCloudModelSource object by specifying the name you assigned the model
// when you uploaded it in the Firebase console.
    FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("batai")
            .enableModelUpdates(true)
            .setInitialDownloadConditions(conditions)
            .setUpdatesDownloadConditions(conditions)
            .build();
    FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource);

ローカルのバンドル設定をします。
retrained_graph_local.tfliteをassetsに配置します。


    FirebaseLocalModelSource localSource =
            new FirebaseLocalModelSource.Builder("batai_local")  // Assign a name for this model
                    .setAssetFilePath("retrained_graph_local.tflite")
                    .build();
    FirebaseModelManager.getInstance().registerLocalModelSource(localSource);

FirebaseModelInterpreterの設定をします。


        FirebaseModelOptions options = new FirebaseModelOptions.Builder()
                .setCloudModelName("batai")
                .setLocalModelName("batai_local")
                .build();
        FirebaseModelInterpreter firebaseInterpreter =
                FirebaseModelInterpreter.getInstance(options);

入力データの指定を行います。


        FirebaseModelInputOutputOptions inputOutputOptions =
                new FirebaseModelInputOutputOptions.Builder()
                        .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 2})
                        .build();
        Bitmap bitmap =
                textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y);

        int batchNum = 0;
        float[][][][] input = new float[1][224][224][3];
        for (int x = 0; x < 224; x++) {
            for (int y = 0; y < 224; y++) {
                int pixel = bitmap.getPixel(x, y);
                // Normalize channel values to [0.0, 1.0]. This requirement varies by
                // model. For example, some models might require values to be normalized
                // to the range [-1.0, 1.0] instead.
                input[batchNum][x][y][0] = ((Color.red(pixel) - 128) / 128);
                input[batchNum][x][y][1] = ((Color.green(pixel) - 128) / 128);
                input[batchNum][x][y][2] = ((Color.blue(pixel) - 128) / 128);
            }
        }

        FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
                .add(input)  // add() as many input arrays as your model requires
                .build();


        firebaseInterpreter.run(inputs, inputOutputOptions)
                .addOnSuccessListener(
                        new OnSuccessListener<FirebaseModelOutputs>() {
                            @Override
                            public void onSuccess(FirebaseModelOutputs result) {
                                // ...
                                float[][] output = result.getOutput(0);
                                float[] probabilities = output[0];
                                BufferedReader reader = null;
                                try {
                                    reader = new BufferedReader(
                                            new InputStreamReader(getActivity().getAssets().open("output_labels.txt")));
                                    StringBuilder builder = new StringBuilder();
                                    for (int i = 0; i < probabilities.length; i++) {
                                        String label = reader.readLine();
                                        Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
                                        builder.append(String.format("%s: %1.4f", label, probabilities[i]));
                                        builder.append(System.lineSeparator());
                                    }
                                    showToast(builder.toString());
                                } catch (IOException e) {
                                    e.printStackTrace();
                                }

                            }
                        })
                .addOnFailureListener(
                        new OnFailureListener() {
                            @Override
                            public void onFailure(@NonNull Exception e) {
                                // Task failed with an exception
                                // ...
                                e.printStackTrace();
                            }
                        });

ただ、このモデルを使用すると下記が発生しました。

Cannot convert between a TensorFlowLite tensor with type UINT8 and a Java object of type [[[[F (which is compatible with the TensorFlowLite type FLOAT32).

tflite_convertでFLOATを指定すると出なくなりますが、ファイルサイズが大きくなります。(retrained_graph_float.tflite)

tflite_convert \
  --graph_def_file=$HOME/batai/retrained_graph.pb \
  --output_file=$HOME/batai/retrained_graph_float.tflite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --input_shape=1,299,299,3 \
  --input_array=Placeholder \
  --output_array=final_result \
  --input_data_type=FLOAT \
  --default_ranges_min=0  \
  --default_ranges_max=6  \
  --inference_type=FLOAT  \
  --inference_input_type=FLOAT  \
  --mean_values=128 \
  --std_dev_values=128 \
87498466 retrained_graph.pb
21894208 retrained_graph.tflite
87144108 retrained_graph_float.tflite

ML kitにretrained_graph_float.tfliteを上げようとするとサイズ制限でエラーとなります。
スクリーンショット 2018-12-18 4.50.14.png

この辺りを今回は解決できませんでした。

朝日杯の予想

Float版をアプリにローカルバンドルして試したのですが、実用化は難しいですね。

  • 値が全然安定しない
  • そもそも学習に使っている画像が左向きの画像のみなので、右向かれると対応できない
  • 画像に馬体以外のものが写ってしまう
  • 短距離レースと長距離レースが得意な馬とでは体型が違うことが予想される(競馬場別、年齢別も)

機械学習に関する知識をもっとつけていきたいと思いました。

8
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
6