LoginSignup
0
1

androidでYOLOv8物体検出モデルを使う

Last updated at Posted at 2024-04-26

Androidで物体検出

物体検出AIをandroidデバイスで動かせると、サーバー通信なしで便利な物体検出機能を世界中のユーザーに使ってもらえます。

Yolov8はポピュラーな物体検出AIです。
androidは世界一ユーザーの多いモバイルOSです。
本記事はandroidデバイス上でyolov8物体検出を行う方法です。

以下のリポジトリのコードを参考にしています。

↑ androidデバイスのカメラでリアルタイム検出を行うコードです。

僕もシンプルなサンプルを作りましたので、よかったら参照してください。
カメラ機能などを省いた検出部分のみのサンプルです。

手順1:Pytorch形式からtflite形式に変換

YOLOv8はpytorch形式で構築されています。
これをandroidで使えるようにtfliteに変換します。

YOLOv8のインストール

Ultralyticsというフレームワークをインストールします。
Yolov8はこのフレームワークに含まれます。

pip install ultralytics

tfliteに変換

変換コードで変換します。
下記のコードで事前トレーニング済みモデルの重みがダウンロードされます。
もし自前のカスタムデータでトレーニングしたモデルの重みチェックポイントファイルがある場合、yolov8s.ptの部分を置き換えます。

from ultralytics import YOLO
model = YOLO('yolov8s.pt')
model.export(format="tflite")

yolov8s_saved_model/yolov8s_float16.tfliteが生成されるので、これを使います。

変換エラーが出たら。。。

もし以下のエラーが出たらtensorflowのバージョンによるものなので、適合したバージョンをインストールします。

ImportError: generic_type: cannot initialize type "StatusCode": an object with that name is already defined

例えば、tensorflowを以下のバージョンに変更します。

pip install tensorflow==2.13.0

androidでtfliteファイルを実行

ここからは、android studio のプロジェクトでyolov8のtfliteファイルを実行する部分です。

tfliteファイルをプロジェクトに追加

android studioプロジェクトのappディレクトリにassetsディレクトリを作成(File→New→Folder→Asset Folder)し、tfliteファイル(yolov8s_float32.tflite)とlabels.txtを追加します。 コピペで追加できます。

labels.txtはYOLOv8モデルのクラス名が以下のように記述されたテキストファイルです。
カスタムクラスを設定している場合は、そのクラスを記述します。
デフォルトのYOLOv8の事前トレーニング済みモデルだと以下です。

labels.txt
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

tfliteのインストール

app/build.gradle.kts の dependencies に以下を追加して tflite フレームワークをインストールします。

app/build.gradle.kts
implementation("org.tensorflow:tensorflow-lite:2.14.0")
implementation("org.tensorflow:tensorflow-lite-support:0.4.4")

上記を追記したら、Sync Nowを押してインストールします。

必要なモジュールのインポート

import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.CastOp
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.BufferedReader
import java.io.IOException
import java.io.InputStream
import java.io.InputStreamReader

必要なクラス・プロパティ

private val modelPath = "yolov8s_float32.tflite"
private val labelPath = "labels.txt"
private var interpreter: Interpreter? = null

private var tensorWidth = 0
private var tensorHeight = 0
private var numChannel = 0
private var numElements = 0

private var labels = mutableListOf<String>()

private val imageProcessor = ImageProcessor.Builder()
    .add(NormalizeOp(INPUT_MEAN, INPUT_STANDARD_DEVIATION))
    .add(CastOp(INPUT_IMAGE_TYPE))
    .build() // preprocess input

companion object {
    private const val INPUT_MEAN = 0f
    private const val INPUT_STANDARD_DEVIATION = 255f
    private val INPUT_IMAGE_TYPE = DataType.FLOAT32
    private val OUTPUT_IMAGE_TYPE = DataType.FLOAT32
    private const val CONFIDENCE_THRESHOLD = 0.3F
    private const val IOU_THRESHOLD = 0.5F
}

モデルの初期化

tfliteモデルを初期化します。モデルファイルを取得し、tfliteのInterpreter に渡します。
オプションで使用するスレッド数を渡します。
Activityでないクラスで使う場合は、contextをクラスに渡す必要があります。

val model = FileUtil.loadMappedFile(context, modelPath)
val options = Interpreter.Options()
options.numThreads = 4
interpreter = Interpreter(model, options)

Interpreterからyolov8sの入力と出力の形状を取得します。

val inputShape = interpreter.getInputTensor(0).shape()
val outputShape = interpreter.getOutputTensor(0).shape()

tensorWidth = inputShape[1]
tensorHeight = inputShape[2]
numChannel = outputShape[1]
numElements = outputShape[2]

label.txtファイルからクラス名を読み取ります。
InputStreamとInputStreamReaderは明示的に閉じる必要があります。

try {
    val inputStream: InputStream = context.assets.open(labelPath)
    val reader = BufferedReader(InputStreamReader(inputStream))

    var line: String? = reader.readLine()
    while (line != null && line != "") {
        labels.add(line)
        line = reader.readLine()
    }

    reader.close()
    inputStream.close()
} catch (e: IOException) {
    e.printStackTrace()
}

画像を入力して実行する

bitmapを入力としますが、モデルの入力形式に合わせて以下の前処理を行います。

1、モデルの入力形状に合わせてリサイズ
2、tensorにする
3、ピクセル値を255で割って正規化する(0~1の範囲の値にする)
4、モデルの入力タイプにキャストする
5、入力用のimageBufferを取得する

val resizedBitmap = Bitmap.createScaledBitmap(bitmap, tensorWidth, tensorHeight, false)

val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(resizedBitmap)
val processedImage = imageProcessor.process(tensorImage)
val imageBuffer = processedImage.buffer

実行します。
モデルの出力形状に合わせた 出力用の tensor buffer を作成し、
上記の入力 imageBuffer と一緒に interpreter に渡して実行します。

val output = TensorBuffer.createFixedSize(intArrayOf(1 , numChannel, numElements), OUTPUT_IMAGE_TYPE)
interpreter.run(imageBuffer, output.buffer)

出力の後処理を行う

出力ボックスはBoudingBoxクラスとして扱います。
クラス、ボックス、信頼度を持つクラスです。
x1 y1は始点。
x2 y2は終点。
cx cyはセンター。
w は幅。
h は高さ。
です。

data class BoundingBox(
    val x1: Float,
    val y1: Float,
    val x2: Float,
    val y2: Float,
    val cx: Float,
    val cy: Float,
    val w: Float,
    val h: Float,
    val cnf: Float,
    val cls: Int,
    val clsName: String
)

以下の処理で、たくさんある出力ボックス候補の中から、信頼度が高いものを選びます。
1、confidence threshold より信頼度が高いボックスを抽出する。
2、重なっているボックスの中で、一番信頼度が高いボックスを残す。(nms)

private fun bestBox(array: FloatArray) : List<BoundingBox>? {

    val boundingBoxes = mutableListOf<BoundingBox>()

    for (c in 0 until numElements) {
        var maxConf = -1.0f
        var maxIdx = -1
        var j = 4
        var arrayIdx = c + numElements * j
        while (j < numChannel){
            if (array[arrayIdx] > maxConf) {
                maxConf = array[arrayIdx]
                maxIdx = j - 4
            }
            j++
            arrayIdx += numElements
        }

        if (maxConf > CONFIDENCE_THRESHOLD) {
            val clsName = labels[maxIdx]
            val cx = array[c] // 0
            val cy = array[c + numElements] // 1
            val w = array[c + numElements * 2]
            val h = array[c + numElements * 3]
            val x1 = cx - (w/2F)
            val y1 = cy - (h/2F)
            val x2 = cx + (w/2F)
            val y2 = cy + (h/2F)
            if (x1 < 0F || x1 > 1F) continue
            if (y1 < 0F || y1 > 1F) continue
            if (x2 < 0F || x2 > 1F) continue
            if (y2 < 0F || y2 > 1F) continue

            boundingBoxes.add(
                BoundingBox(
                    x1 = x1, y1 = y1, x2 = x2, y2 = y2,
                    cx = cx, cy = cy, w = w, h = h,
                    cnf = maxConf, cls = maxIdx, clsName = clsName
                )
            )
        }
    }

    if (boundingBoxes.isEmpty()) return null

    return applyNMS(boundingBoxes)
}

private fun applyNMS(boxes: List<BoundingBox>) : MutableList<BoundingBox> {
    val sortedBoxes = boxes.sortedByDescending { it.cnf }.toMutableList()
    val selectedBoxes = mutableListOf<BoundingBox>()

    while(sortedBoxes.isNotEmpty()) {
        val first = sortedBoxes.first()
        selectedBoxes.add(first)
        sortedBoxes.remove(first)

        val iterator = sortedBoxes.iterator()
        while (iterator.hasNext()) {
            val nextBox = iterator.next()
            val iou = calculateIoU(first, nextBox)
            if (iou >= IOU_THRESHOLD) {
                iterator.remove()
            }
        }
    }

    return selectedBoxes
}

private fun calculateIoU(box1: BoundingBox, box2: BoundingBox): Float {
    val x1 = maxOf(box1.x1, box2.x1)
    val y1 = maxOf(box1.y1, box2.y1)
    val x2 = minOf(box1.x2, box2.x2)
    val y2 = minOf(box1.y2, box2.y2)
    val intersectionArea = maxOf(0F, x2 - x1) * maxOf(0F, y2 - y1)
    val box1Area = box1.w * box1.h
    val box2Area = box2.w * box2.h
    return intersectionArea / (box1Area + box2Area - intersectionArea)
}

ここまでで、yolov8の出力が得られます。

val bestBoxes = bestBox(output.floatArray)

出力のボックスを画像に描画する

fun drawBoundingBoxes(bitmap: Bitmap, boxes: List<BoundingBox>): Bitmap {
    val mutableBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true)
    val canvas = Canvas(mutableBitmap)
    val paint = Paint().apply {
        color = Color.RED
        style = Paint.Style.STROKE
        strokeWidth = 8f
    }
    val textPaint = Paint().apply {
        color = Color.WHITE
        textSize = 40f
        typeface = Typeface.DEFAULT_BOLD
    }

    for (box in boxes) {
        val rect = RectF(
            box.x1 * mutableBitmap.width,
            box.y1 * mutableBitmap.height,
            box.x2 * mutableBitmap.width,
            box.y2 * mutableBitmap.height
        )
        canvas.drawRect(rect, paint)
        canvas.drawText(box.clsName, rect.left, rect.bottom, textPaint)
    }

    return mutableBitmap
}

うまくいかない時

モデルパスが間違っていたりして、interpreterがnullになっていることが多かったので、その辺りをチェックするといいかもです。

🐣


フリーランスエンジニアです。
AIについて色々記事を書いていますのでよかったらプロフィールを見てみてください。

もし以下のようなご要望をお持ちでしたらお気軽にご相談ください。
AIサービスを開発したい、ビジネスにAIを組み込んで効率化したい、AIを使ったスマホアプリを開発したい、
ARを使ったアプリケーションを作りたい、スマホアプリを作りたいけどどこに相談したらいいかわからない…

いずれも中間コストを省いたリーズナブルな価格でお請けできます。

お仕事のご相談はこちらまで
rockyshikoku@gmail.com

機械学習やAR技術を使ったアプリケーションを作っています。
機械学習/AR関連の情報を発信しています。

X
Medium
GitHub

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1