前回に引き続き,こちらのpytorchmobileのデモを動かしたメモ書きを残す.今回はObject Ditectionとして,yolov5を使用したリアルタイム推論が紹介されていたので,こちらを実行してみる.
android-demoのgit cloneとREADMEで紹介されているyolov5のgit cloneまでは実行しているものとする.
ptlファイルを作成するdocker環境の編集
READMEに従って,yolov5リポジトリから,ptlファイルを生成する. 手元にあった未編集のpytorch/pytorch:1.9.1-cuda11.1-cudnn8-devel
イメージを使用してexport.pyを実行した(GPU環境は使用していない).
元の環境
torch 1.10.0
torchelastic 0.2.0
torchtext 0.11.0
torchvision 0.11.1
そのままrequire.txtに従ってインストールを実行したかったが,いくつか最新バージョンがインストールされる恐れがあり,元のバージョンに更新が入ることを避けるため,手動でいくつかpip installした.
opencvも必要だったため,apt-getでインストール.
docker内の環境で下記を実行
$ pip install pandas
$ pip install opencv-python
$ apt update
$ apt-get install libopencv-dev
$ apt-get install libgl1-mesa-dev
$ apt-get install git
penCVのバージョン確認
$ python -c "import cv2; print(cv2.__version__)"
4.8.0
$ apt list --installed
libopencv-dev/bionic-updates,bionic-security,now 3.2.0+dfsg-4ubuntu0.1 amd64 [installed]
yolov5のclone
READ.MEの指示の通り,Yolov5をダウンロード.
$ git clone https://github.com/ultralytics/yolov5
$ cd yolov5
export.pyの編集
ここで,READMEで指定されていたコードの改変と,yolov5側のコードが異なっており,READMEと違う修正が必要
@try_export
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
# YOLOv5 TorchScript model export
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
f = file.with_suffix('.torchscript')
fl = file.with_suffix('.torchscript.ptl') # 追加
ts = torch.jit.trace(model, im, strict=False)
d = {'shape': im.shape, 'stride': int(max(model.stride)), 'names': model.names}
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
optimize_for_mobile(ts)._save_for_lite_interpreter(str(fl), _extra_files=extra_files) # 追加
else:
ts.save(str(f), _extra_files=extra_files)
return f, None
export.pyを実行
実行する際には,--optimize
オプションをつける.
$ python export.py --weights yolov5s.pt --optimize --include torchscript
上記export.py実行後yolov5/
配下に生成されたyolov5s.torchscript.ptl
を,ObjectDitection/app/src/main/asset/
配下にコピー
ちなみに他サイズのモデルも,同様の方法で生成可能.ただし,javaコード側でモデルの名前が直接指定されているため,sサイズのモデルしか指定できないので注意.
$ python export.py --weights yolov5m.pt --optimize --include torchscript
$ python export.py --weights yolov5l.pt --optimize --include torchscript
$ python export.py --weights yolov5x.pt --optimize --include torchscript
androidプロジェクト側の修正
前回と同様に各種gradle.buildファイル,manifestファイルを修正.
- gradle.build(ObjectDitection)
buildscript {
repositories {
google()
//jcenter()
mavenCentral()
}
dependencies {
classpath 'com.android.tools.build:gradle:4.2.0'
// classpath 'com.android.tools.build:gradle:8.1.1'
// classpath 'com.android.tools.build:gradle:3.2.0-alpha13'
}
}
- gradle.build(App)
apply plugin: 'com.android.application'
android {
compileSdkVersion 31
buildToolsVersion "30.0.2"
defaultConfig {
applicationId "org.pytorch.demo.objectdetection"
minSdkVersion 28
targetSdkVersion 31
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation fileTree(dir: "libs", include: ["*.jar"])
implementation 'androidx.appcompat:appcompat:1.2.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
def camerax_version = "1.0.0-alpha05"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
implementation 'org.pytorch:pytorch_android_lite:1.11'
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.11'
}
エラー対応
- プロジェクトのmakeはうまくいくものの,気になる記述が出る.実行不能エラーというわけではないので,一旦無視する.
注意:~/android-image-rec-pytorch/docker/opt/android-demo-app/ObjectDetection/app/src/main/java/org/pytorch/demo/objectdetection/MainActivity.javaは推奨されないAPIを使用またはオーバーライドしています。
- 実行すると,アプリが開かない.ログストリームで確認すると下記のエラーを吐いている
PytorchStreamReader failed locating file bytecode.pkl: file not found ()
上記エラーはptlファイルが量子化されていないことが原因.モデルが_save_for_lite_interpreter()
で保存されているか確認する.