1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

KenmaroAdvent Calendar 2022

Day 12

Kotlin + TensorFlow Lite でセマンティックセグメンテーションをオープンデータで実行

Last updated at Posted at 2022-12-11

概要

Kotlin を使って、Android 端末(私はPixel4Lを使用)でセマンティックセグメンテーションのモデルを実行してみたかったので、オープンデータでモックアップを作ってみました。

オープンデータとしてはCityScape のデータセットなどを用いています。

その時に少し難しかったところなどをまとめたいと思います。

Kotlin で機械学習推論を行う時の構成

私が今回実装した構成のポイントを列挙していきます。

モデルの作成

tensor-flow keras を用いて、セマンティックセグメンテーションを実行するモデルを構築しました。
調査したアルゴリズムは、比較的古めのUNetや、その他アテンションなどを用いている最近のモデルなどを

ここをみながら既存のコードや、自分でカスタマイズを加えながら実装し、学習を実行しました。
今回はモデルの学習自体の解説については割愛します。

コードはこちらに一部ですが上げています。

別記事で、この時に調べたセマンティックセグメンテーションのモデルについてまとめれたらなと思っています。

モデルのエクスポート

tensorflow keras で学習したモデルを、アンドロイドで使用可能にするために、
tflite の形式にエクスポートしました。

プログラムは上のものを使っています。keras から出力されたh5ファイルをそのままtflite 形式へとエクスポートしました。

Kotlinで簡単なモックアップの作成

tflite のモデルがエクスポートできたので、実際にAndroid端末からアプリを介してモデルの推論を実行できるか試してみます。

プログラムはこちらに上げています。

レイアウト

Screen Shot 2022-11-29 at 10.54.06.png

中央上部にあるのは VideoViewで、mp4 ファイルを再生します。
中央中部にあるのは Predict Buttonで、クリックされると再生中のビデオから画像を取得し、
モデルで推論を実行します。

元画像は下部左にある ImageViewに表示され、推論された後の画像は下部右にあるImageViewに表示されます。

MainActivity の解説

あまり長いコードでもないので、一通りの流れを解説してみます。
今回はモックアップなので、画面遷移などは実装せず、MainActivityに全てのプログラムを書きました。

こちらのファイルになります。
要所要所を解説していきます。

MainActivity.kt
class MainActivity : AppCompatActivity() {

    private fun loadModelFile(modelName: String): MappedByteBuffer {
        var fileDescriptor: AssetFileDescriptor = this.assets.openFd(modelName)
        var inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        var fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }

}

loadModelFiletfliteモデルをロードするために使われ、ここで出力されたモデルが Interpreterへと渡されます。

MainActivity.kt
    private fun predict_from_bitmap(bitmapInput: Bitmap){
        var mat1 = Mat()
        var mat2 = Mat()
        val ppTimeStart = SystemClock.uptimeMillis()
        bitmapToMat(bitmapInput, mat1)
        Imgproc.resize(mat1, mat2, Size(128.0, 128.0))
        var bitmap = Bitmap.createBitmap(128, 128, Bitmap.Config.ARGB_8888)
        matToBitmap(mat2, bitmap)

まずは入力となっている bitmap をOpencvのオブジェクトとして変換し、リサイズしています。

MainActivity.kt

        val input = ByteBuffer.allocateDirect(128*128*3*4).order(ByteOrder.nativeOrder())
        for (y in 0 until 128) {
            for (x in 0 until 128) {
                val px = bitmap.getPixel(x, y)

                val r = (px shr 16) and 0xFF
                val g = (px shr 8) and 0xFF
                val b = px and 0xFF

                val rf = r / 255f
                val gf = g / 255f
                val bf = b / 255f

                input.putFloat(rf)
                input.putFloat(gf)
                input.putFloat(bf)
            }
        }


その後、モデルが待ち受けているByteBufferの形に変換します。
この時、ByteBufferは32bit の浮動小数(それぞれ4バイト)をリサイズされた画像サイズ(128, 128, 3) の数だけ格納するbuffer で、
待ち受けている入力の形に従って、浮動小数を詰めていきます。
このとき、bitmapが保持しているピクセル値は1ピクセル3バイトの値になっており、その3バイトでR, G, Bのそれぞれ1バイトずつの整数を保有しているので、右シフトなどを使ってそれぞれを取り出してあげる必要がありました。
このあたりが実際の実装としては慣れていなかったので結構難しいところでした。

MainActivity.kt


        val bufferSize = 128 * 128 * java.lang.Float.SIZE / java.lang.Byte.SIZE
        val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
        val ppTimeStop = SystemClock.uptimeMillis()
        val ppTimeDiff = ppTimeStop - ppTimeStart
        Handler(mainLooper).post({
            textView2.setText("PP Time: $ppTimeDiff")
        })
        // Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors.
        interpreter?.run(input, modelOutput)
        val popTimeStart = SystemClock.uptimeMillis()

interpreter?.run のところで推論を実行しています。
この後は、モデルからの出力であるmodelOutputに対して後処理をして、描画できるように格納する必要があります。

MainActivity.kt

        var floatArrayOutput = FloatArray(128*128)

        for (i in 0..127){
            for (j in 0..127){
                floatArrayOutput[i*128 + j] = modelOutput.getFloat((i*128+j)*4)
            }
        }

        val byteBufferParsedResult = ByteBuffer.allocateDirect(128*128*4).order(ByteOrder.nativeOrder())
        byteBufferParsedResult.rewind()

        for(i in 0..127){
            for(j in 0..127){
                byteBufferParsedResult.put(0.toByte()) // B
                byteBufferParsedResult.put(0.toByte()) // G
                if(floatArrayOutput[i*128+j] > 0.3){
                    byteBufferParsedResult.put(255.toByte()) // R
                }
                else{
                    byteBufferParsedResult.put(0.toByte()) // Red
                }
                byteBufferParsedResult.put(255.toByte()) // A
            }
        }
        byteBufferParsedResult.rewind()

ここでは、後処理された結果byteBufferParsedResultがシグモイド関数を通った値になっていることを考慮して、閾値を設けて白黒の2値へと変換しています。

MainActivity.kt

    @RequiresApi(Build.VERSION_CODES.N)
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        binding = ActivityMainBinding.inflate(layoutInflater)
        setContentView(binding.root)

//        val delegate = GpuDelegate(GpuDelegate.Options().setQuantizedModelsAllowed(true)) // DEQUANTIZE not supported
        val delegate = GpuDelegate(GpuDelegate.Options().setQuantizedModelsAllowed(false)) // TRANSPOSE_CONV: Max version supported: 2. Requested version 3.
        val options = Interpreter.Options().addDelegate(delegate).setNumThreads(4)
//        val options = Interpreter.Options().setUseNNAPI(true) // 1200
//        val options = Interpreter.Options().setNumThreads(4) //500
//        val options = Interpreter.Options().setUseNNAPI(true).setNumThreads(4) // 500

//        val options = Interpreter.Options()
        val tmpFile = loadModelFile("test.tflite") //okay
        interpreter = Interpreter(tmpFile, options)


    }


}

ここの interpreterに渡すオプションを変更することで、マルチスレッドで推論を実行したり、
GPUに計算をデリゲートしたりすることができます。この辺はいろいろ試して、実行時間の変化を観察したりしました。

まとめ

今回は、Kotlin + tflite
でセマンティックセグメンテーションのモデルを実行してみたかったので、
オープンデータでモックアップを作ってみました。

モックアップの見た目を作ったり、Modelを学習してtflite に変更したりするのももちろん大変ですが、
Kotlin で実際に推論を実行するための、OpencvのMat, bitmap, ByteBufferなどをいろいろ行き来したり、どのように数値が格納されているのかを理解するのがとても骨が折れました。
この辺は慣れや経験も必要かなと思いました。

いろいろデバイスを変えて実行しようとすると難しいですが、やはりデバイス上で自分の作ったモデルが動くと嬉しいので、これからもいろんなデバイスで実験してみたいなと思います。

今回はこの辺で。

@kenmaro

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?