本稿ではPyTorchを利用したCIFAR-10の画像分類を行います。
公式のチュートリアルに沿ってコメントを添えつつ追っていきます。
尚、Pythonと機械学習は超初心者です。
CIFAR-10とは?
機械学習界隈で広く利用されている10ラベルの画像データセットです。
airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck
の10ラベルが用意されています。
環境
- macOS Catalina
- Python 3.7.2
- pip 19.1.1
PyTorchのインストール
公式サイトで各環境に合わせてインストールコマンドを発行してくれます。
私はmacOSなので次を実行してインストールします。
pip install torch torchvision
CNNを実装
必要なライブラリをインポートする
# NumPy、Matplotlib、PyTorchをインポートする
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
CIFAR-10をダウンロードする
# ToTensor:画像のグレースケール化(RGBの0~255を0~1の範囲に正規化)、Normalize:Z値化(RGBの平均と標準偏差を0.5で決め打ちして正規化)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# トレーニングデータをダウンロード
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
# テストデータをダウンロード
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)
CIFAR-10を確認してみる
データを確認する
# 学習用データセット:縦横32ピクセルのRGBの画像が50000枚
print(trainset.data.shape)
(50000, 32, 32, 3)
# テスト用データセット:縦横32ピクセルのRGBの画像が10000枚
print(testset.data.shape)
(10000, 32, 32, 3)
# クラス一覧を確認する
print(trainset.classes)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# classesはよく利用するので別途保持しておく
classes = trainset.classes
公式ドキュメントではairplaneがplane、automobileがcarで再定義されていたけど何故だろう?
画像を表示する
# ダウンロードした画像を表示してみる
def imshow(img):
# 非正規化する
img = img / 2 + 0.5
# torch.Tensor型からnumpy.ndarray型に変換する
print(type(img)) # <class 'torch.Tensor'>
npimg = img.numpy()
print(type(npimg))
# 形状を(RGB、縦、横)から(縦、横、RGB)に変換する
print(npimg.shape)
npimg = np.transpose(npimg, (1, 2, 0))
print(npimg.shape)
# 画像を表示する
plt.imshow(npimg)
plt.show()
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
ネットワークを実装する
# CNNを実装する
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
initで各層を定義しておき、forwardで繋げて利用する。
損失関数・オプティマイザを定義する
# 交差エントロピー
criterion = nn.CrossEntropyLoss()
# 確率的勾配降下法
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
トレーニングする
# トレーニングする
for epoch in range(2):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
# 誤差逆伝播
loss.backward()
optimizer.step()
train_loss = loss.item()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
[1, 2000] loss: 2.164
[1, 4000] loss: 1.863
[1, 6000] loss: 1.683
[1, 8000] loss: 1.603
[1, 10000] loss: 1.525
[1, 12000] loss: 1.470
[2, 2000] loss: 1.415
[2, 4000] loss: 1.369
[2, 6000] loss: 1.363
[2, 8000] loss: 1.333
[2, 10000] loss: 1.314
[2, 12000] loss: 1.317
Finished Training
2000ミニバッチ毎のLossの平均値をログに出力している。
モデルを保存する
# モデルを保存する
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
カレントディレクトリにpth(PyTorch)という拡張子でモデルを保存する。
モデルを使ってみる
# テストデータを読み込んで、画像と正しいラベルを表示する
dataiter = iter(testloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
# 保存したモデルを読み込んで、予測する
net = Net()
net.load_state_dict(torch.load(PATH))
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
GroundTruth: truck cat airplane ship
Predicted: truck horse airplane ship
cat以外は正しく予測できていることが分かります。
print(outputs)
value, predicted = torch.max(outputs, 1)
print(value)
print(predicted)
tensor([[ 0.7114, -2.2724, 0.1225, 0.9470, 2.1940, 1.8655, -2.6655, 4.1646,
-1.1001, -1.6991],
[-2.2453, -4.1017, 1.8291, 3.2079, 1.1242, 3.6712, 1.0010, 1.0489,
-3.2010, -1.9476],
[-3.0669, -3.8900, 0.9312, 3.5649, 2.7791, 1.5095, 2.1216, 1.5274,
-4.3077, -2.2234],
[-2.0948, -3.4640, 2.4833, 2.6210, 4.0590, 1.8350, 0.4924, 0.7212,
-3.5043, -2.4212]], grad_fn=<AddmmBackward>)
tensor([4.1646, 3.6712, 3.5649, 4.0590], grad_fn=<MaxBackward0>)
tensor([7, 5, 3, 4])
torch.maxはoutputsの最大値を返してくれる。
モデルをテストする
correct = 0
total = 0
# 勾配を記憶せず(学習せずに)に計算を行う
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
Accuracy of the network on the 10000 test images: 60 %
10000枚のテストデータに対する正答率は60%であることが分かります。
以下はPython初心者の個人的メモ。
いまいち**(predicted == labels).sum().item()**この書き方が分からなかったので、ログに出して確認してみる。
print(type((predicted == labels)))
print((predicted == labels).dtype)
print(type((predicted == labels).sum()))
print((predicted == labels).sum())
print((predicted == labels).sum().item())
# <class 'torch.Tensor'>
# torch.bool
# <class 'torch.Tensor'>
# tensor(2)
# 2
なるほど。配列の各要素に対して比較し、torch.Tensorに実装されているsum()を使い、trueの合計値を算出。その後torch.Tensorに実装されているitem()を使い合計値をint型の数値としてしている。
numpyで確認するともう少し分かりやすかった。
# numpyで試してみる
a = np.array([1, 2, 3, 4, 5])
b = np.array([1, 2, 0, 4, 5])
print(type((a == b)))
print((a == b))
print((a == b).sum())
print(type((a == b).sum()))
print((a == b).sum().item())
print(type((a == b).sum().item()))
# <class 'numpy.ndarray'>
# [ True True False True True]
# 4
# <class 'numpy.int64'>
# 4
# <class 'int'>
公式を見ると、ndarrayとほぼ同じAPIを使えるので、**sum()やitem()**が使えるんですね。納得。
各ラベル毎の正答率を見てみる
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of airplane : 72 %
Accuracy of automobile : 66 %
Accuracy of bird : 38 %
Accuracy of cat : 58 %
Accuracy of deer : 60 %
Accuracy of dog : 29 %
Accuracy of frog : 73 %
Accuracy of horse : 60 %
Accuracy of ship : 69 %
Accuracy of truck : 73 %
チュートリアルだとこんなものなのかな?