5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Androidで画像解析をしてみる(Pytorch MobileとKotlinを利用)

Last updated at Posted at 2020-11-03

Pytorch Mobileを利用した画像解析アプリケーション

  • resnet
  • pytorch mobile
  • camerax

を利用して、撮影した画像に対して画像解析を行うアプリケーションを作成しました。

起動すると、カメラプレビューが表示されます。
「撮影」ボタンを押すと、プレビューされている画像を撮影でき、
「解析」ボタンを押すと、学習済みモデルに画像を入力し、何が写っているか判定を行います。
「解析結果」の横に結果が表示されます。

1.起動時

2.写真撮影

3.画像解析

環境構築

pytorchのインストール
# 最新versionだとPytorch Mobileがうまく動作しない可能性あり
conda install pytorch==1.4.0 torchvision==0.5.0 -c pytorch

# ちゃんと動くか確認します
python
  # [入力] 1行ずつ入力
  from __future__ import  print_function
  import torch
  x = torch.rand(5, 3)
  x
  
  # [出力]
  tensor([[0.3380, 0.3845, 0.3217],
        [0.8337, 0.9050, 0.2650],
        [0.2979, 0.7141, 0.9069],
        [0.1449, 0.1132, 0.1375],
        [0.4675, 0.3947, 0.1426]])

※Android Studioについてはインストール方法を詳しく説明している記事が他に多数あるので説明を省きます

Pythonモデル構築面

今回は画像分類の学習済みモデルresnet18を利用するので、それをダウンロードしてAndroid側で読み込めるようにするだけです。
本来であれば、何かしらモデルを構築しておく手順をここで行います。

今回は以下のようなPythonスクリプトをプロジェクト内に作成しました。
これを実行することで、プロジェクトのassetsに学習済みモデルが保存されます。

createModel.py
import torch
import torchvision

# resnetモデルを利用する
model = torchvision.models.resnet18(pretrained=True)
# 推論モードにする
model.eval()

# サンプル入力を与える
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
# モデルを保存
traced_script_module.save("app/src/main/assets/resnet.pt")

Android側で、撮影した画像をresnet.ptに入力することで推論を行うことができる。といった流れです。

Android実装面

プロジェクトの作成

NewProject-->Empty Activity-->LanguageをKotlinにしてFinish

ビルド依存関係の追加とオプション

ライブラリを取り込む

build.gradle
dependencies {
...

  // camerax関連
  def camerax_version = "1.0.0-beta07"
  // CameraX core library using camera2 implementation
  implementation "androidx.camera:camera-camera2:$camerax_version"
  // CameraX Lifecycle Library
  implementation "androidx.camera:camera-lifecycle:$camerax_version"
  // CameraX View class
  implementation "androidx.camera:camera-view:1.0.0-alpha14"

  // pytorch関連
  implementation 'org.pytorch:pytorch_android:1.4.0'
  implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}

Java version1.8でビルドするように指定

build.gradle
android {
    compileSdkVersion 29

    ...
    
    ...
    
    // 以下を追記
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
}

CameraXによる写真撮影

CameraXはAndroidで利用できるカメラ関連のライブラリであり、イメージのプレビューや写真の撮影、画像解析などを行える。
今回のソースコードはほぼGetting Started with CameraX  |  Google Codelabsのチュートリアル通りです。

カメラの権限付与

AndroidManifest.xml
<uses-feature android:name="android.hardware.camera.any" />
<uses-permission android:name="android.permission.CAMERA" />

<application
...

画面の作成

撮影ボタン・解析開始ボタン・イメージプレビュー・撮影した画像の表示エリア・解析結果表示エリアなどを配置します。

activity_main.xml
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">


    <Button
        android:id="@+id/camera_capture_button"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="撮影"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.261"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/frameLayout" />


    <androidx.camera.view.PreviewView
        android:id="@+id/viewFinder"
        android:layout_width="207dp"
        android:layout_height="220dp"
        android:layout_marginTop="16dp"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.503"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent"
        app:srcCompat="@mipmap/ic_launcher_round" />

    <FrameLayout
        android:id="@+id/frameLayout"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="8dp"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/viewFinder">

        <ImageView
            android:id="@+id/capturedImg"
            android:layout_width="207dp"
            android:layout_height="215dp" />

    </FrameLayout>

    <Button
        android:id="@+id/inferBtn"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginStart="32dp"
        android:text="解析"
        app:layout_constraintBottom_toBottomOf="@+id/camera_capture_button"
        app:layout_constraintStart_toEndOf="@+id/camera_capture_button"
        app:layout_constraintTop_toTopOf="@+id/camera_capture_button" />

    <TextView
        android:id="@+id/resultText"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="解析結果"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.278"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/camera_capture_button"
        android:layout_marginTop="5dp"/>


</androidx.constraintlayout.widget.ConstraintLayout>

こんな感じになります。

イメージプレビュー・撮影した画像の確認を動作させる

MainActivity.ktを編集して機能を追加していきます。
長くなるのでソースコード全体は載せません。
ソースコードはgithubにアップロードしているので、そちらで確認してください。
ただ、cameraxのチュートリアル通りなので、一度チュートリアルをこなすといいかなと思います。

MainActivity.ktのonCreateの部分
// Activity作成時に一度だけ呼び出されるやつ
override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        // カメラの権限チェックをして、権限があればカメラ起動処理を呼び出す
        if (allPermissionsGranted()) {
            startCamera()
        } else {
            ActivityCompat.requestPermissions(
                this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS)
        }

        // 撮影ボタンにリスナーを追加する
        camera_capture_button.setOnClickListener { takePhoto() }
        
        // 写真を吐き出すディレクトリの取得
        outputDirectory = getOutputDirectory()
        
        // カメラを別スレッドで動かす
        cameraExecutor = Executors.newSingleThreadExecutor()
    }

OnCreateで呼び出しているtakePhoto()などを追加してビルドすると、以下のような感じになります。

解析処理の実装

ここまでの作業で写真を撮影できるようになったので、撮影した写真に対して画像解析を適用する処理を追加していきます。

Androidプロジェクトへのモデルの追加

まずはassetsフォルダを作成します。
Android Studioの左側にあるファイル群のペインを右クリックして、New-->Folder-->asset folder-->finishをすると、assetsフォルダが生成されます。

次に、モデルをダウンロードしてassetsフォルダに保存するpythonスクリプトを実行します。
私はpythonスクリプトをappフォルダ, gradleフォルダなどがある階層の一つ上に配置しました。

createModel.py
import torch
import torchvision

# resnetモデルを利用する
model = torchvision.models.resnet18(pretrained=True)
# 推論モードにする
model.eval()

# サンプル入力を与える
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
# モデルを保存
traced_script_module.save("app/src/main/assets/resnet.pt")

これを実行して、assetsフォルダにresnet.ptが追加されていればOKです。

解析処理をMainAcivityに追加

まずは、解析ボタンを押すと解析処理を開始するようにOnCreate()にリスナーを追加します。

MainActivity.ktのOnCreateの一部
override fun onCreate(savedInstanceState: Bundle?){
...

  // 撮影ボタンにリスナーを追加する
  camera_capture_button.setOnClickListener { takePhoto() }
  // 解析開始ボタンにリスナーを追加する
  inferBtn.setOnClickListener { capture_analyze() }
  
  // capture_analyzeが解析処理に対応する関数です。

実際の解析処理を書いていきます。
解析処理といっても、撮影した画像をモデルに入力して結果を得るだけです。

MainActivity.ktのcapture_analyze()
private fun capture_analyze(){
        // resnetモデルを取得
        val resnet = Module.load(getAssetFilePath(this, "resnet.pt"))

        //撮影した写真をリサイズ
        val imgDataResized = Bitmap.createScaledBitmap(imgData!!, 224, 224, true)
        // bitmapをtensorに変換
        val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
            imgDataResized,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
            TensorImageUtils.TORCHVISION_NORM_STD_RGB
        )

        //フォワードプロパゲーション
        val outputTensor = resnet.forward(IValue.from(inputTensor)).toTensor()
        val scores = outputTensor.dataAsFloatArray

        var maxScore = 0F
        var maxScoreIdx = 0
        for (i in scores.indices) {
            if (scores[i] > maxScore) {
                maxScore = scores[i]
                maxScoreIdx = i
            }
        }

        // 一番近いスコアを持つカテゴリを取得してそれを"解析結果"の部分に表示させる
        val inferCategory = ImageNetCategory().IMAGENET_CLASSES[maxScoreIdx]
        resultText.text = "解析結果:${inferCategory}"
    }

最後の部分のカテゴリ取得については、ImageNetCategory.ktというクラスを作って、そこから取得してくるようにしています。
そのため、ImageNetCategory.ktを作成します。

MainActivity.ktがある階層に、ImageNetCategory.ktを作成します。
githubからコピーしてください。

ImageNetCategory.kt
class ImageNetCategory {
    var IMAGENET_CLASSES = arrayOf(
        "tench, Tinca tinca",
        "goldfish, Carassius auratus",
        "great white shark, white shark, man-eat
        ...

終わりに

モバイル端末での画像解析は初めてしました。
pytorch mobile便利そう。色々応用できそうなので今後も頑張ります。

参考サイト

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?