##初めに
- PytorchでCNNを構築し,MNISTデータセットを学習した。
- 学習したモデルを読み込んで使うプログラムを組んだ。(意外と記事少なかった)
- モデルの作成は以前記事にしたのでよかったら是非。( こちら 参照)
<対象>
- 機械学習初心者(細かい内容についての解説はしません)
- PyTorch触り始めた方
- アバウトな解説でも耐えられる方
<非対象> - Pytorch詳しい方
- 精度向上したい方
[環境]
Python 3.6.9
torch 1.6.0
numpy 1.16.4
Pillow 6.2.0
##モデルの保存
PATH = "./my_mnist_model.pt"
torch.save(net.state_dict(), PATH)
学習済みモデルを保存する。
torch.save()
の引数をnet.state_dict()
とすることによりネットワーク構造や各レイヤの引数を省いて保存する。これにより保存したモデルの容量を削減することができるらしい。
逆に言うとモデルをロードする側はネットワーク構造を記述する必要がある。
##モデルを利用する
import torch
import torch.nn as nn
import torch.nn.functional as f
import torchvision
from torchvision import datasets, transforms
from PIL import Image, ImageOps
from torchsummary import summary
# モデルの定義
class MyNet(nn.Module):
def __init__(self):
super(MyNet,self).__init__()
self.conv1 = nn.Conv2d(1,32,3,1)
self.conv2 = nn.Conv2d(32,64,3,1)
self.pool = nn.MaxPool2d(2,2)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(12*12*64,128)
self.fc2 = nn.Linear(128,10)
def forward(self,x):
x = self.conv1(x)
x = f.relu(x)
x = self.conv2(x)
x = f.relu(x)
x = self.pool(x)
x = self.dropout1(x)
x = x.view(-1,12*12*64)
x = self.fc1(x)
x = f.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
return f.log_softmax(x, dim=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = 0
model = MyNet().to(device)
print(device)
print(model)
print(summary(model, (1, 28, 28)))
# 学習モデルをロードする
model.load_state_dict(torch.load("my_mnist_model.pt", map_location=lambda storage, loc: storage))
model = model.eval()
# 画像ファイルを読み込む(黒背景, 白文字を想定)
PATH = "mnist/3.jpg"
image = Image.open(PATH)
image = ImageOps.invert(image)
image = image.convert('L').resize((28,28))
# データの前処理の定義(モデル生成の際と同じ平均値と標準偏差で正規化する)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 元のモデルに合わせて次元を追加
image = transform(image).unsqueeze(0)
# 予測を実施
output = model(image.to(device))
_, prediction = torch.max(output, 1)
# 結果を出力
print("{} -> result = ".format(PATH) + str(prediction[0].item()))
mnist/2.jpg -> result = 2
mnist/6.jpg -> result = 6
##まとめ
なんかうまくいってる雰囲気。
PyTorchはOpenCVではなくPillowを使う場合が多いみたい。
モデル自体の精度はあまり高くなかった。過学習してるのかもしれない。