LoginSignup
5
8

More than 3 years have passed since last update.

アンドロイドで画面に数字を書いて画像認識するアプリを作る(PyTorch Mobile)[CNNネットワーク作成編]

Last updated at Posted at 2020-01-25

今回作成するアプリ

画面に書いた数字を認識する画像認識アプリをPytorch Mobileとkotlinで作る。
画像認識用のモデルとアンドロイドの機能を1から全部作る。
CNNネットワーク作成編(Python)アンドロイド実装編(kotlin)の全2回に分けます。

Python環境がないアンドロイドエンジニアの方やモデル作成がめんどいって方はアンドロイドで画面に書いた数字を判別する画像認識アプリを作る(PyTorch Mobile)[アンドロイド実装編]へ行って実装編で学習済みモデルをダウンロードして進めてください。

Githubに今回のpythonコード挙げてます
Github: https://github.com/SY-BETA/CNN_PyTorch

これ↓

作成の流れ

1.MNISTをダウンロードする (※チャネル数を3チャネルに直す必要あり)
2. 簡単なCNNモデルをpython(PyTorch)で作成
3. モデルを学習させる
4. モデルを保存
5. アンドロイドで絵を描ける機能を実装
6. アンドロイドにモデルを実装してforwardプロパゲーションする

この回でやること

1~4までやる。
python使ってモデルの保存までやる。今回使用するライブラリはPyTorch 実行環境はjupyter notebook
MNISTのデータセットをダウンロードしシンプルなCNNモデルを作成し学習させる。

MNISTダウンロード

みんな大好き手書き数字データセットMNISTをtorchvisionを使ってダウンロードする

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
        transforms.ToTensor()])
train = torchvision.datasets.MNIST(
    root="data/train", train=True, transform=transform, target_transform=None, download=True)
test = torchvision.datasets.MNIST(
    root="data/test", train=False, transform=transform, target_transform=None, download=True)

MNISTを見てみる

どんなデータセットか見てみる

from matplotlib import pyplot as plt
import numpy as np

print(train.data.size())
print(test.data.size())
img = train.data[0].numpy()
plt.imshow(img, cmap='gray')
print('Label:', train.targets[0])

実行結果
キャプチaaaaャ.PNG

グレースケールからRGBに変更する

MNISTのカラーチャネル数を1から3にする。

なんでそんな計算量増える無駄ことをわざわざするのか? -> アンドロイドで画像を扱うときにbitmap形式で扱う、それをpytorch mobileでテンソルに変換するときにチャネル数3のテンソルにしか変換できなかった...
(ちょっとやり方が分からなかったので知っている方いらっしゃれば教えてください...)
なのでモデルを学習させるデータをRGBにして学習させる。

train_data_resized = train.data.numpy()  #torchテンソルからnumpyに
test_data_resized = test.data.numpy()

train_data_resized = torch.FloatTensor(np.stack((train_data_resized,)*3, axis=1))  #RGBに変換
test_data_resized =  torch.FloatTensor(np.stack((test_data_resized,)*3, axis=1))
print(train_data_resized.size())

これでデータセットのサイズがtorch.Size([60000, 28, 28])からtorch.Size([60000, 3, 28, 28])になった。

データセットを自作する

カスタムデータセットクラスを作る

今回はチャネル数の関係でMNISTのデータセットはそのまま使用できないので、pytorchのDatasetを継承してカスタムデータセットを作る。
また、画像の前処理である標準化するクラスもここで作る。

import torch.utils.data as data

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

#画像の前処理
class ImgTransform():
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # テンソル変換
            transforms.Normalize(mean, std)  # 標準化
        ])

    def __call__(self, img):
        return self.transform(img)

#Datasetクラスを継承
class _3ChannelMnistDataset(data.Dataset):
    def __init__(self, img_data, target, transform):
        #[データ数,高さ,横,チャネル数]に
        self.data = img_data.numpy().transpose((0, 2, 3, 1)) /255
        self.target = target
        self.img_transform = transform #画像前処理クラスのインスタンス

    def __len__(self):
        #画像の枚数を返す
        return len(self.data)

    def __getitem__(self, index):
        #画像の前処理(標準化)したデータを返す
        img_transformed = self.img_transform(self.data[index])
        return img_transformed, self.target[index]

なおmeanstdはVGG16とかでも標準化によく使ういつもの値。アンドロイドでテンソルに変換するときに必ず標準化する、その時の値がこれ。
値がわからなかったら android studio でpytroch mobileのImageUtilsを確認してもよい。
aaaaキャプチャ.PNG

上記で作成したクラスを使ってデータセット作成

train_dataset = _3ChannelMnistDataset(train_data_resized, train.targets, transform=ImgTransform())
test_dataset = _3ChannelMnistDataset(test_data_resized, test.targets, transform=ImgTransform())

# データセットをテストしてみる
index = 0
print(train_dataset.__getitem__(index)[0].size())
print(train_dataset.__getitem__(index)[1])
print(train_dataset.__getitem__(index)[0][1]) #ちゃんと標準化されていることがわかる

データローダー作成

作ったデータセットでカスタムデータローダーを作る。バッチサイズは適当に100

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False)

CNNネットワークを作成

畳み込み1層、全結合3層のシンプルなネットワークを適当に作成。(学習に時間かかるのも嫌だし)

from torch import nn
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(3)
        self.conv = nn.Conv2d(3, 10, kernel_size=4)
        self.fc1 = nn.Linear(640, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size()[0], -1) #行列を線形処理できるようにベクトルに(view(高さ、横))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

model = Model()
print(model)

こんなネットワーク
キャプfadsfcvaチャ.PNG

ネットワークを学習させる

訓練modeと推論modeの関数を作る

import tqdm
from torch import optim

# 推論モード
def eval_net(net, data_loader, device="cpu"): #GPUある人はgpuに
    #推論モードに
    net.eval()
    ypreds = [] #予測したラベル格納変数
    for x, y in (data_loader):
        # toメソッドでデバイスに転送
        x = x.to(device)
        y = [y.to(device)]
        # 確率が最大のクラスを予測
        # forwardプロパゲーション
        with torch.no_grad():
            _, y_pred = net(x).max(1)
            ypreds.append(y_pred)
            # ミニバッチごとの予測を一つのテンソルに
            y = torch.cat(y)
            ypreds = torch.cat(ypreds)
            # 予測値を計算(正解=予測の要素の和)
            acc = (y == ypreds).float().sum()/len(y)
            return acc.item()


# 訓練モード
def train_net(net, train_loader, test_loader,optimizer_cls=optim.Adam, 
              loss_fn=nn.CrossEntropyLoss(),n_iter=3, device="cpu"):
    train_losses = []
    train_acc = []
    eval_acc = []
    optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):  #4回回す
        runnig_loss = 0.0
        # 訓練モードに
        net.train()
        n = 0
        n_acc = 0

        for i, (xx, yy) in tqdm.tqdm(enumerate(train_loader),
                                     total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            output = net(xx)

            loss = loss_fn(output, yy)
            optimizer.zero_grad()   #optimizerの初期化
            loss.backward()   #損失関数(クロスエントロピー誤差)からバックプロパゲーション
            optimizer.step()

            runnig_loss += loss.item()
            n += len(xx)
            _, y_pred = output.max(1)
            n_acc += (yy == y_pred).float().sum().item()

        train_losses.append(runnig_loss/i)
        # 訓練データの予測精度
        train_acc.append(n_acc / n)
        # 検証データの予測精度
        eval_acc.append(eval_net(net, test_loader, device))

        # このepochでの結果を表示
        print("epoch:",epoch, "train_loss:",train_losses[-1], "train_acc:",train_acc[-1],
              "eval_acc:",eval_acc[-1], flush=True)

まずは学習なしで推論してみる

eval_net(model, test_loader)

ネットワークのランダムパラメータのseed値を固定していないので再現性はなくランダムに変わるが、自分の環境では学習前のスコアは0.0799999982って感じになった。

学習させる

先ほど作成した関数を使って学習

train_net(model, train_loader, test_loader)

最終的に予測精度が0.98000001907くらいになった。えっ、精度高すぎね。精度良すぎてあってるか不安になる...

実際に1つ推論してみる

学習させたモデルにデータを1つ入れてラベルを予測してみる。

data = train_dataset.__getitem__(0)[0].reshape(1, 3, 28, 28) #リサイズ(データローダーのサイズに注意)
print("ラベル",train_dataset.__getitem__(0)[1].data)
model.eval()
output = model(data)
print(output.size())
output

実行結果
キafdfafdaャプチャ.PNG
しっかりインデックスが5のスコアが一番高くなっていて予測できていることがわかる。

やっと、モデルの作成と学習が終了!!

モデルを保存する

アンドロイドで使うためにモデルを保存する

# モデルの保存
model.eval()
#サンプル入力サイズ
example = torch.rand(1, 3, 28, 28)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("./CNNModel.pt")
print(model)

おわり

とりあえずこれで[ネットワーク作成編] 終了!! 次は作ったモデルをアンドロイドに実装していく。
PyTorch Mobileでテンソルに変換するときにRGBのテンソルになり、グレースケールにできなかったので、MNISTをわざわざRGBに変換したりして、結構面倒な処理が多くなった。
その影響でMNISTのデータセットがそのまま使えず自作のデータセット、データローダーを使わなきゃいけなくなった。まあ、グレースケールとか商用レベルではほとんど使えないんだろうけど。
あと、適当に作ったCNNネットワークだったが意外と精度高くなっておどろいた、さすがはCNN
一応Githubあげてます。

今回のコード Github: https://github.com/SY-BETA/CNN_PyTorch

今回の作成した学習済みモデル(.py) : https://github.com/SY-BETA/CNN_PyTorch/blob/master/CNNModel.pt

それではアンドロイド実装編へレッツゴー
アンドロイドで画面に書いた数字を判別する画像認識アプリを作る(PyTorch Mobile)[アンドロイド実装編]

5
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
8