#今回やること
python で線形回帰モデルを作ってそのモデルを使ってアンドロイド上で推論する。(アンドロイド上で学習させるわけではありません。)
今回のコードはgithubに載せているので適宜参照してください。(最下部にURL掲載)
・PyTorch Mobileを使う
#モデルの作成
まずはアンドロイドで動かすための線形モデルを作っていく。
python環境がなくアンドロイドの方だけ読みたい方はアンドロイドで推論という見出しまで読み飛ばして、完成したモデルをダウンロードしてください。
なお今回掲載するコードはjupyter notebook上で動かしたものです。
##データセット
今回使うデータセットはkaggleに載ってた Red Wine Qualityを使ってみる。
酸味、ph、度数などのワインの成分データからワインの10段階の品質を予測する感じ。
今回は単に線形モデルをアンドロイドで動かしてみたいだけなのでシンプルな線形重回帰で10段階クオリティを連続値とみて線形モデルでフィッティングしていく。11カラムあるけど特にL1正則化とかは無しで。(うーん、精度悪くなりそう...)
##データ整理
データを眺めたり、データの欠損地チェックや、データの整理を行う。
kaggle からダウンロードしたデータのインポート
import torch
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
wineQualityData = pd.read_csv('datas/winequality-red.csv')
一応相関をプロットしたり、欠損チェックしたり..
sns.pairplot(wineQualityData)
#欠損データのチェック
wineQualityData.isnull().sum()
特に欠損値とかもないので次にデータローダーを作っていく
##データローダーの作成
#入力と正解ラベル
X = wineQualityData.drop(['quality'], 1)
y = wineQualityData['quality']
#8:2で分ける
X_train = torch.tensor(X.values[0:int(len(X)*0.8)], dtype=torch.float32)
X_test = torch.tensor(X.values[int(len(X)*0.8):len(X)], dtype=torch.float32)
#8:2で分ける
y_train = torch.tensor(y.values[0:int(len(y)*0.8)], dtype=torch.float32)
y_test = torch.tensor(y.values[int(len(y)*0.8):len(y)], dtype=torch.float32)
#データローダー作成
train = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train, batch_size=100, shuffle=True)
test = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(test, batch_size=50, shuffle=False)
pytorchにデータローダーを簡単に作れるメソッドが用意されてて楽。
今回一応テスト用データも作ってますが今回は使いません。
##モデルの作成
続いて線形モデルを作っていく。
from torch import nn, optim
#モデル
model = nn.Linear(in_features = 11, out_features=1, bias=True)
#学習率
lr = 0.01
#2乗誤差
loss_fn=nn.MSELoss()
#損失関数のログ
losses_train= []
#最適化関数
optimizer = optim.Adam(model.parameters(), lr=lr)
##モデルの学習
作成したモデルを学習させる
from tqdm import tqdm
for epoch in tqdm(range(100)):
print("epoch:", epoch)
for x,y in train_loader:
# 前回の勾配をゼロに
optimizer.zero_grad()
# 予測
y_pred = model(x)
# MSE loss とwによる微分を計算
loss = loss_fn(y_pred, y)
if(epoch != 0): #誤差が小さくなったら終了
if abs(losses_train[-1] - loss.item()) < 1e-1:
break
loss.backward()
optimizer.step()
losses_train.append(loss.item())
print("train_loss", loss.item())
##学習結果
損失関数の推移
plt.plot(losses_train)
ちょっとできたモデルを試してみる
for i in range(len(X_test)):
print("推論結果:",model(X_test[i]).data, "正解ラベル:", y_test[i].data)
んん?全然合ってないな。確かにただの線形重回帰だけどこんなに合わないのかなー。
データをもう一回見てみるとクオリティの56%が5だった。つまり、損失を少なくするようにほとんど5の値に収束してしまったのかな。そもそもこういうデータを連続値ラベルとみて線形重回帰するのは厳しかったのか。分類でやった方がよかったかも。
ただ、今回はモデルの精度を求めるのがメインではないので、とりあえずはこれでモデル完成ということにしておく。
もし、今回の精度が悪かった原因がコードのここが悪いよとかわかる方いましたら、コメントで教えてください。
##モデルの保存
アンドロイドにモデルを入れるためにモデルを保存する
import torchvision
model.eval()
#入力テンソルのサイズ
example = torch.rand(1,11)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("wineModel.pt")
うまく実行できると同じフォルダないにptファイルが生成されるはず。
#アンドロイドで推論
読み飛ばした方はgithubから学習済みモデルをダウンロードしてください。
ここから、アンドロイドスタジオを使っていきます。
##依存関係
2020年3月時点
dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
##モデルを入れる
アンドロイドスタジオに先ほどダウンロードまたは作成した学習済みモデル(wineModel.pt)を入れる。
まずはassetフォルダを作る(「resフォルダとか適当な場所を右クリック->新規->フォルダ->assetフォルダ」で作れる)
そこに学習済みモデルをコピペする。
##レイアウト
推論結果を表示するレイアウトをつくる。といってもtextView3個並べただけ。
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<TextView
android:id="@+id/result"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Hello World!"
android:textSize="24sp"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toBottomOf="@+id/label" />
<TextView
android:id="@+id/label"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="TextView"
android:textSize="30sp"
app:layout_constraintBottom_toTopOf="@+id/result"
app:layout_constraintEnd_toEndOf="@+id/result"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="@+id/result"
app:layout_constraintTop_toTopOf="parent" />
<TextView
android:id="@+id/textView2"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="赤ワイン品質予測"
android:textSize="30sp"
app:layout_constraintBottom_toTopOf="@+id/label"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />
</androidx.constraintlayout.widget.ConstraintLayout>
##推論
モデルをロードして、テンソルを入れて推論
class MainActivity : AppCompatActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
//ワインのテスト用データ
val inputArray = floatArrayOf(7.1f, 0.46f, 0.2f, 1.9f, 0.077f, 28f, 54f, 0.9956f, 3.37f, 0.64f, 10.4f)
//テンソルの生成: 引数(floatArray, テンソルのサイズ)
val inputTensor = Tensor.fromBlob(inputArray, longArrayOf(1,11))
//モデルのロード
val module = Module.load(assetFilePath(this, "wineModel.pt"))
//推論
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
val scores = outputTensor.dataAsFloatArray
//結果の表示
result.text ="予測値: ${scores[0]}"
label.text = "正解ラベル:6"
}
//assetフォルダからパスを取得する関数
fun assetFilePath(context: Context, assetName: String): String {
val file = File(context.filesDir, assetName)
if (file.exists() && file.length() > 0) {
return file.absolutePath
}
context.assets.open(assetName).use { inputStream ->
FileOutputStream(file).use { outputStream ->
val buffer = ByteArray(4 * 1024)
var read: Int
while (inputStream.read(buffer).also { read = it } != -1) {
outputStream.write(buffer, 0, read)
}
outputStream.flush()
}
return file.absolutePath
}
}
}
完成!!
ここまで来て実行すると冒頭の画面が出てくるはず。
#おわり
画像系はチュートリアルでもあるけど、普通の線形とかはあまり載ってなかったのでこの記事を書いてみた。
モデルの精度がイマイチだったのが引っかかる所だけど、一応線形モデルを動かすことができた。
今度は分類をやってみようかな。
今回のコードはこちら
pythonコード
android studio コード