LoginSignup
0
1

More than 3 years have passed since last update.

アンドロイドで線形回帰モデルを使って推論してみる[PyTorch Mobile]

Posted at

今回やること

python で線形回帰モデルを作ってそのモデルを使ってアンドロイド上で推論する。(アンドロイド上で学習させるわけではありません。)

今回のコードはgithubに載せているので適宜参照してください。(最下部にURL掲載)

今回作るやつ↓

PyTorch Mobileを使う

pytorch-mobile.png

モデルの作成

まずはアンドロイドで動かすための線形モデルを作っていく。
python環境がなくアンドロイドの方だけ読みたい方はアンドロイドで推論という見出しまで読み飛ばして、完成したモデルをダウンロードしてください。

なお今回掲載するコードはjupyter notebook上で動かしたものです。

データセット

今回使うデータセットはkaggleに載ってた Red Wine Qualityを使ってみる。

酸味、ph、度数などのワインの成分データからワインの10段階の品質を予測する感じ。
キャプチdfasdfadsffdvzscャ.PNG

今回は単に線形モデルをアンドロイドで動かしてみたいだけなのでシンプルな線形重回帰で10段階クオリティを連続値とみて線形モデルでフィッティングしていく。11カラムあるけど特にL1正則化とかは無しで。(うーん、精度悪くなりそう...)

データ整理

データを眺めたり、データの欠損地チェックや、データの整理を行う。

kaggle からダウンロードしたデータのインポート

import torch 
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns

wineQualityData = pd.read_csv('datas/winequality-red.csv')

一応相関をプロットしたり、欠損チェックしたり..

sns.pairplot(wineQualityData)

#欠損データのチェック
wineQualityData.isnull().sum()

cvahtr.PNG   yjktkj.PNG

特に欠損値とかもないので次にデータローダーを作っていく

データローダーの作成

#入力と正解ラベル
X = wineQualityData.drop(['quality'], 1)
y = wineQualityData['quality']

#8:2で分ける
X_train = torch.tensor(X.values[0:int(len(X)*0.8)], dtype=torch.float32)
X_test = torch.tensor(X.values[int(len(X)*0.8):len(X)], dtype=torch.float32)

#8:2で分ける
y_train = torch.tensor(y.values[0:int(len(y)*0.8)], dtype=torch.float32)
y_test = torch.tensor(y.values[int(len(y)*0.8):len(y)], dtype=torch.float32)

#データローダー作成
train = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train, batch_size=100, shuffle=True)
test = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(test, batch_size=50, shuffle=False)

pytorchにデータローダーを簡単に作れるメソッドが用意されてて楽。
今回一応テスト用データも作ってますが今回は使いません。

モデルの作成

続いて線形モデルを作っていく。

from torch import nn, optim

#モデル
model = nn.Linear(in_features = 11, out_features=1, bias=True)
#学習率
lr = 0.01
#2乗誤差
loss_fn=nn.MSELoss()
#損失関数のログ
losses_train= []
#最適化関数
optimizer = optim.Adam(model.parameters(), lr=lr)

モデルの学習

作成したモデルを学習させる

from tqdm import tqdm
for epoch in tqdm(range(100)):
    print("epoch:", epoch)

    for x,y in train_loader:
        # 前回の勾配をゼロに
        optimizer.zero_grad()
        # 予測
        y_pred = model(x)
        # MSE loss とwによる微分を計算
        loss = loss_fn(y_pred, y)
        if(epoch != 0):  #誤差が小さくなったら終了
            if abs(losses_train[-1] - loss.item()) < 1e-1:
                break
        loss.backward()
        optimizer.step()
        losses_train.append(loss.item())

    print("train_loss", loss.item())

学習結果

損失関数の推移

plt.plot(losses_train)

adsffgda.png
一応収束はしてるっぽい。

ちょっとできたモデルを試してみる

for i in range(len(X_test)):
    print("推論結果:",model(X_test[i]).data, "正解ラベル:", y_test[i].data)

adsfjkasdf.PNG

んん?全然合ってないな。確かにただの線形重回帰だけどこんなに合わないのかなー。
データをもう一回見てみるとクオリティの56%が5だった。つまり、損失を少なくするようにほとんど5の値に収束してしまったのかな。そもそもこういうデータを連続値ラベルとみて線形重回帰するのは厳しかったのか。分類でやった方がよかったかも。

ただ、今回はモデルの精度を求めるのがメインではないので、とりあえずはこれでモデル完成ということにしておく。

もし、今回の精度が悪かった原因がコードのここが悪いよとかわかる方いましたら、コメントで教えてください。

モデルの保存

アンドロイドにモデルを入れるためにモデルを保存する

import torchvision

model.eval()
#入力テンソルのサイズ
example = torch.rand(1,11) 
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("wineModel.pt")

うまく実行できると同じフォルダないにptファイルが生成されるはず。

アンドロイドで推論

読み飛ばした方はgithubから学習済みモデルをダウンロードしてください。

ここから、アンドロイドスタジオを使っていきます。

依存関係

2020年3月時点

build.gradle
dependencies {
    implementation 'org.pytorch:pytorch_android:1.4.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}

モデルを入れる

アンドロイドスタジオに先ほどダウンロードまたは作成した学習済みモデル(wineModel.pt)を入れる。

まずはassetフォルダを作る(「resフォルダとか適当な場所を右クリック->新規->フォルダ->assetフォルダ」で作れる)
そこに学習済みモデルをコピペする。
jkjjkl.PNG

レイアウト

推論結果を表示するレイアウトをつくる。といってもtextView3個並べただけ。

activity_main.xml
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <TextView
        android:id="@+id/result"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Hello World!"
        android:textSize="24sp"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintLeft_toLeftOf="parent"
        app:layout_constraintRight_toRightOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/label" />

    <TextView
        android:id="@+id/label"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="TextView"
        android:textSize="30sp"
        app:layout_constraintBottom_toTopOf="@+id/result"
        app:layout_constraintEnd_toEndOf="@+id/result"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toStartOf="@+id/result"
        app:layout_constraintTop_toTopOf="parent" />

    <TextView
        android:id="@+id/textView2"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="赤ワイン品質予測"
        android:textSize="30sp"
        app:layout_constraintBottom_toTopOf="@+id/label"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />
</androidx.constraintlayout.widget.ConstraintLayout>

推論

モデルをロードして、テンソルを入れて推論

MainActivity.kt
class MainActivity : AppCompatActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        //ワインのテスト用データ
        val inputArray = floatArrayOf(7.1f, 0.46f, 0.2f, 1.9f, 0.077f, 28f, 54f, 0.9956f, 3.37f, 0.64f, 10.4f)
        //テンソルの生成: 引数(floatArray, テンソルのサイズ)
        val inputTensor = Tensor.fromBlob(inputArray, longArrayOf(1,11))
        //モデルのロード
        val module = Module.load(assetFilePath(this, "wineModel.pt"))
        //推論
        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
        val scores = outputTensor.dataAsFloatArray
        //結果の表示
        result.text ="予測値: ${scores[0]}"
        label.text = "正解ラベル:6"
    }

    //assetフォルダからパスを取得する関数
    fun assetFilePath(context: Context, assetName: String): String {
        val file = File(context.filesDir, assetName)
        if (file.exists() && file.length() > 0) {
            return file.absolutePath
        }
        context.assets.open(assetName).use { inputStream ->
            FileOutputStream(file).use { outputStream ->
                val buffer = ByteArray(4 * 1024)
                var read: Int
                while (inputStream.read(buffer).also { read = it } != -1) {
                    outputStream.write(buffer, 0, read)
                }
                outputStream.flush()
            }
            return file.absolutePath
        }
    }
}

完成!!
ここまで来て実行すると冒頭の画面が出てくるはず。

おわり

画像系はチュートリアルでもあるけど、普通の線形とかはあまり載ってなかったのでこの記事を書いてみた。
モデルの精度がイマイチだったのが引っかかる所だけど、一応線形モデルを動かすことができた。
今度は分類をやってみようかな。

今回のコードはこちら
pythonコード
android studio コード

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1