Pytorch Mobileを利用した画像解析アプリケーション
- resnet
- pytorch mobile
- camerax
を利用して、撮影した画像に対して画像解析を行うアプリケーションを作成しました。
起動すると、カメラプレビューが表示されます。
「撮影」ボタンを押すと、プレビューされている画像を撮影でき、
「解析」ボタンを押すと、学習済みモデルに画像を入力し、何が写っているか判定を行います。
「解析結果」の横に結果が表示されます。
環境構築
- 必要なモノ
# 最新versionだとPytorch Mobileがうまく動作しない可能性あり
conda install pytorch==1.4.0 torchvision==0.5.0 -c pytorch
# ちゃんと動くか確認します
python
# [入力] 1行ずつ入力
from __future__ import print_function
import torch
x = torch.rand(5, 3)
x
# [出力]
tensor([[0.3380, 0.3845, 0.3217],
[0.8337, 0.9050, 0.2650],
[0.2979, 0.7141, 0.9069],
[0.1449, 0.1132, 0.1375],
[0.4675, 0.3947, 0.1426]])
※Android Studioについてはインストール方法を詳しく説明している記事が他に多数あるので説明を省きます
Pythonモデル構築面
今回は画像分類の学習済みモデルresnet18を利用するので、それをダウンロードしてAndroid側で読み込めるようにするだけです。
本来であれば、何かしらモデルを構築しておく手順をここで行います。
今回は以下のようなPythonスクリプトをプロジェクト内に作成しました。
これを実行することで、プロジェクトのassetsに学習済みモデルが保存されます。
import torch
import torchvision
# resnetモデルを利用する
model = torchvision.models.resnet18(pretrained=True)
# 推論モードにする
model.eval()
# サンプル入力を与える
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
# モデルを保存
traced_script_module.save("app/src/main/assets/resnet.pt")
Android側で、撮影した画像をresnet.ptに入力することで推論を行うことができる。といった流れです。
Android実装面
プロジェクトの作成
NewProject-->Empty Activity-->LanguageをKotlinにしてFinish
ビルド依存関係の追加とオプション
ライブラリを取り込む
dependencies {
...
// camerax関連
def camerax_version = "1.0.0-beta07"
// CameraX core library using camera2 implementation
implementation "androidx.camera:camera-camera2:$camerax_version"
// CameraX Lifecycle Library
implementation "androidx.camera:camera-lifecycle:$camerax_version"
// CameraX View class
implementation "androidx.camera:camera-view:1.0.0-alpha14"
// pytorch関連
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
Java version1.8でビルドするように指定
android {
compileSdkVersion 29
...
...
// 以下を追記
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
CameraXによる写真撮影
CameraXはAndroidで利用できるカメラ関連のライブラリであり、イメージのプレビューや写真の撮影、画像解析などを行える。
今回のソースコードはほぼGetting Started with CameraX | Google Codelabsのチュートリアル通りです。
カメラの権限付与
<uses-feature android:name="android.hardware.camera.any" />
<uses-permission android:name="android.permission.CAMERA" />
<application
...
画面の作成
撮影ボタン・解析開始ボタン・イメージプレビュー・撮影した画像の表示エリア・解析結果表示エリアなどを配置します。
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
xmlns:app="http://schemas.android.com/apk/res-auto"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<Button
android:id="@+id/camera_capture_button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="撮影"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintHorizontal_bias="0.261"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/frameLayout" />
<androidx.camera.view.PreviewView
android:id="@+id/viewFinder"
android:layout_width="207dp"
android:layout_height="220dp"
android:layout_marginTop="16dp"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintHorizontal_bias="0.503"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent"
app:srcCompat="@mipmap/ic_launcher_round" />
<FrameLayout
android:id="@+id/frameLayout"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginTop="8dp"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/viewFinder">
<ImageView
android:id="@+id/capturedImg"
android:layout_width="207dp"
android:layout_height="215dp" />
</FrameLayout>
<Button
android:id="@+id/inferBtn"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginStart="32dp"
android:text="解析"
app:layout_constraintBottom_toBottomOf="@+id/camera_capture_button"
app:layout_constraintStart_toEndOf="@+id/camera_capture_button"
app:layout_constraintTop_toTopOf="@+id/camera_capture_button" />
<TextView
android:id="@+id/resultText"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="解析結果"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintHorizontal_bias="0.278"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/camera_capture_button"
android:layout_marginTop="5dp"/>
</androidx.constraintlayout.widget.ConstraintLayout>
こんな感じになります。
イメージプレビュー・撮影した画像の確認を動作させる
MainActivity.ktを編集して機能を追加していきます。
長くなるのでソースコード全体は載せません。
ソースコードはgithubにアップロードしているので、そちらで確認してください。
ただ、cameraxのチュートリアル通りなので、一度チュートリアルをこなすといいかなと思います。
// Activity作成時に一度だけ呼び出されるやつ
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
// カメラの権限チェックをして、権限があればカメラ起動処理を呼び出す
if (allPermissionsGranted()) {
startCamera()
} else {
ActivityCompat.requestPermissions(
this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS)
}
// 撮影ボタンにリスナーを追加する
camera_capture_button.setOnClickListener { takePhoto() }
// 写真を吐き出すディレクトリの取得
outputDirectory = getOutputDirectory()
// カメラを別スレッドで動かす
cameraExecutor = Executors.newSingleThreadExecutor()
}
OnCreateで呼び出しているtakePhoto()などを追加してビルドすると、以下のような感じになります。
解析処理の実装
ここまでの作業で写真を撮影できるようになったので、撮影した写真に対して画像解析を適用する処理を追加していきます。
Androidプロジェクトへのモデルの追加
まずはassetsフォルダを作成します。
Android Studioの左側にあるファイル群のペインを右クリックして、New-->Folder-->asset folder-->finish
をすると、assetsフォルダが生成されます。
次に、モデルをダウンロードしてassetsフォルダに保存するpythonスクリプトを実行します。
私はpythonスクリプトをappフォルダ, gradleフォルダなどがある階層の一つ上に配置しました。
import torch
import torchvision
# resnetモデルを利用する
model = torchvision.models.resnet18(pretrained=True)
# 推論モードにする
model.eval()
# サンプル入力を与える
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
# モデルを保存
traced_script_module.save("app/src/main/assets/resnet.pt")
これを実行して、assetsフォルダにresnet.pt
が追加されていればOKです。
解析処理をMainAcivityに追加
まずは、解析ボタンを押すと解析処理を開始するようにOnCreate()にリスナーを追加します。
override fun onCreate(savedInstanceState: Bundle?){
...
// 撮影ボタンにリスナーを追加する
camera_capture_button.setOnClickListener { takePhoto() }
// 解析開始ボタンにリスナーを追加する
inferBtn.setOnClickListener { capture_analyze() }
// capture_analyzeが解析処理に対応する関数です。
実際の解析処理を書いていきます。
解析処理といっても、撮影した画像をモデルに入力して結果を得るだけです。
private fun capture_analyze(){
// resnetモデルを取得
val resnet = Module.load(getAssetFilePath(this, "resnet.pt"))
//撮影した写真をリサイズ
val imgDataResized = Bitmap.createScaledBitmap(imgData!!, 224, 224, true)
// bitmapをtensorに変換
val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
imgDataResized,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB
)
//フォワードプロパゲーション
val outputTensor = resnet.forward(IValue.from(inputTensor)).toTensor()
val scores = outputTensor.dataAsFloatArray
var maxScore = 0F
var maxScoreIdx = 0
for (i in scores.indices) {
if (scores[i] > maxScore) {
maxScore = scores[i]
maxScoreIdx = i
}
}
// 一番近いスコアを持つカテゴリを取得してそれを"解析結果"の部分に表示させる
val inferCategory = ImageNetCategory().IMAGENET_CLASSES[maxScoreIdx]
resultText.text = "解析結果:${inferCategory}"
}
最後の部分のカテゴリ取得については、ImageNetCategory.kt
というクラスを作って、そこから取得してくるようにしています。
そのため、ImageNetCategory.kt
を作成します。
MainActivity.ktがある階層に、ImageNetCategory.ktを作成します。
githubからコピーしてください。
class ImageNetCategory {
var IMAGENET_CLASSES = arrayOf(
"tench, Tinca tinca",
"goldfish, Carassius auratus",
"great white shark, white shark, man-eat
...
終わりに
モバイル端末での画像解析は初めてしました。
pytorch mobile便利そう。色々応用できそうなので今後も頑張ります。