2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Pytorch 学習済みモデルで識別[MNIST]

Last updated at Posted at 2020-09-03

##初めに

  • PytorchでCNNを構築し,MNISTデータセットを学習した。
  • 学習したモデルを読み込んで使うプログラムを組んだ。(意外と記事少なかった)
  • モデルの作成は以前記事にしたのでよかったら是非。( こちら 参照)

<対象>

  • 機械学習初心者(細かい内容についての解説はしません)
  • PyTorch触り始めた方
  • アバウトな解説でも耐えられる方
    <非対象>
  • Pytorch詳しい方
  • 精度向上したい方

[環境]
Python 3.6.9
torch 1.6.0
numpy 1.16.4
Pillow 6.2.0

##モデルの保存

PATH = "./my_mnist_model.pt"
torch.save(net.state_dict(), PATH)

学習済みモデルを保存する。
torch.save()の引数をnet.state_dict()とすることによりネットワーク構造や各レイヤの引数を省いて保存する。これにより保存したモデルの容量を削減することができるらしい。
逆に言うとモデルをロードする側はネットワーク構造を記述する必要がある。

##モデルを利用する

import torch
import torch.nn as nn
import torch.nn.functional as f
import torchvision
from torchvision import datasets, transforms
from PIL import Image, ImageOps
from torchsummary import summary

# モデルの定義
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.conv1 = nn.Conv2d(1,32,3,1)
        self.conv2 = nn.Conv2d(32,64,3,1)
        self.pool = nn.MaxPool2d(2,2)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(12*12*64,128)
        self.fc2 = nn.Linear(128,10)

    def forward(self,x):
        x = self.conv1(x)
        x = f.relu(x)
        x = self.conv2(x)
        x = f.relu(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = x.view(-1,12*12*64)
        x = self.fc1(x)
        x = f.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        return f.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = 0
model = MyNet().to(device)
print(device)
print(model)
print(summary(model, (1, 28, 28)))

# 学習モデルをロードする
model.load_state_dict(torch.load("my_mnist_model.pt", map_location=lambda storage, loc: storage))
model = model.eval()

# 画像ファイルを読み込む(黒背景, 白文字を想定)
PATH = "mnist/3.jpg"
image = Image.open(PATH)
image = ImageOps.invert(image)
image = image.convert('L').resize((28,28))
# データの前処理の定義(モデル生成の際と同じ平均値と標準偏差で正規化する)
transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))
                                ])

# 元のモデルに合わせて次元を追加
image = transform(image).unsqueeze(0)

# 予測を実施
output = model(image.to(device))
_, prediction = torch.max(output, 1)
# 結果を出力
print("{} -> result = ".format(PATH) + str(prediction[0].item()))

##実行結果
結果①
2.jpg

mnist/2.jpg -> result = 2

結果②
6.jpg

mnist/6.jpg -> result = 6

##まとめ
なんかうまくいってる雰囲気。
PyTorchはOpenCVではなくPillowを使う場合が多いみたい。
モデル自体の精度はあまり高くなかった。過学習してるのかもしれない。

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?