はじめてのPyTorchで、RasNetみたいなAIモデルをゼロからつくってみよう
最近、機械学習の勉強をはじめまして、PyTorch提供するモデルであるResNet等の画像分類AIをtorchvision
経由で利用していたのですが(下記記事)、 ゼロからモデルを組み上げる場合どうすれば良いのだろう と思い、下記書籍で勉強しはじめました。今回の記事は、書籍の抜粋に近く恐縮なのですが、 ゼロからPython+PyTorchで画像分類をするモデルを定義し、その内容を整理するもの としました。
定義済みのモデルをまずは動かしてみたい、という方は下記をご参照ください。
参考にした書籍
書籍の方が正確で読みやすいと思いますので、ご興味あれば是非ご購読ください。
PyTorch実践入門 ~ ディープラーニングの基礎から実装へ
本記事により実現できる物体検出モデル
本記事で扱うモデルを実装することにより、 1000epochsのイテレーションの結果、学習用データセットに対しては84%、検証用データセットに対しても72%程度の正答率を実現できる物体検出器を構築 することができます。
2024-12-16 23:52:07.780146 epoch 980, Training loss 0.8238151277727483
2024-12-16 23:53:03.873066 epoch 985, Training loss 0.8236197303323185
2024-12-16 23:53:59.997969 epoch 990, Training loss 0.8230565896119608
2024-12-16 23:54:56.139902 epoch 995, Training loss 0.8226507447869577
2024-12-16 23:55:50.857574 epoch 1000, Training loss 0.8289129080065071
Accuracy train : 0.84
Accuracy val : 0.72
ソースコードと、その実行方法と結果
早速ですが、 ソースコード全文は以下となります。 実行にはPython仮想環境を準備し、必要なパッケージをあらかじめインストールしておく必要があります。
ソースコード全文
下記のソースコードが、 AIのモデルを定義し、モデルを織り成すハイパーパラメータを定義し、そのモデルを画像分類セットであるCIFER10で学習し、学習結果をCIFER10の検証セットで評価し、学習結果のモデルを保存するまで のソースコード全文です。実行する際は、下記のソースコードに加えて import
from
節で利用しているパッケージを仮想環境へインストールする必要があります。これについては以降の ソースコードの実行方法 をご参照ください。まずは下記ソースコードを 「sample_conv.py」 という名前で保存しておきましょう。
CIFIR10は、32px x 32pxの非常に小さく低解像度な画像を、10種類のラベルデータに分類するデータセットです。 今回のように小さな規模のモデルを体験する のに適しています。 実用に足る解像度のモデルを開発する際には、幅広いバリエーションを扱う場合は1000種類の分類をもつ 224px x 224px の画像データセット「Imagenet」の方が向いています。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchinfo import summary
from matplotlib import pyplot as plt
import datetime
########
# Hyper Parameters
########
cifer10_class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
data_path = "./data-universioned/"
weight_path = './run/cifer10_sample.pt'
batch_size = 128
n_blocks = 10
learning_rate = 1e-2
l2_lambda = 1e-2
n_epochs = 1000
# Set Accelerator
if torch.cuda.is_available():
device = 'cuda'
else:
if torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
print("====================")
print(".to(device='{0}')".format(device))
print("====================")
########
# PIL images
########
cifer10_train = datasets.CIFAR10(data_path, train=True, download=True)
cifer10_val = datasets.CIFAR10(data_path, train=False, download=True)
# Tensor images
tensor_cifer10_train = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())
tensor_cifer10_val = datasets.CIFAR10(data_path, train=False, download=False, transform=transforms.ToTensor())
# calc mean and std
# imgs = torch.stack([img_t for img_t, _ in tensor_cifer10_train], dim=3)
# imgs_mean = imgs.view(3, -1).mean(dim=1)
# print("datasets mean:", imgs_mean)
# imgs_std = imgs.view(3, -1).std(dim=1)
# print("datasets std:", imgs_std)
### mean: tensor([0.4914, 0.4822, 0.4465])
### std: tensor([0.2470, 0.2435, 0.2616])
# Normalized images
transformed_cifer10_train = datasets.CIFAR10(data_path,
train=True, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
transformed_cifer10_val = datasets.CIFAR10(data_path,
train=False, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
########
# DataLoader
########
train_loader = torch.utils.data.DataLoader(transformed_cifer10_train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(transformed_cifer10_val, batch_size=batch_size, shuffle=True)
########
# define Models (Sub-Net and Net)
########
class ResBlock(nn.Module):
def __init__(self, n_chans):
super(ResBlock, self).__init__()
# create layers
self.n_chans1 = n_chans
self.conv1 = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
self.conv2 = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
self.conv3 = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
self.conv_skip = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
self.batch_norm1 = nn.BatchNorm2d(num_features=n_chans)
self.batch_norm2 = nn.BatchNorm2d(num_features=n_chans)
self.batch_norm3 = nn.BatchNorm2d(num_features=n_chans)
# init weight and bias
torch.nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
torch.nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
torch.nn.init.kaiming_normal_(self.conv3.weight, nonlinearity='relu')
torch.nn.init.kaiming_normal_(self.conv_skip.weight, nonlinearity='relu')
torch.nn.init.constant_(self.batch_norm1.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm1.bias)
torch.nn.init.constant_(self.batch_norm2.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm2 .bias)
torch.nn.init.constant_(self.batch_norm3.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm3.bias)
def forward(self, x):
# skip path
input_1 = x.clone()
# main path
out = self.conv1(x)
out = self.batch_norm1(out)
out = torch.relu(out)
out = self.conv2(x)
out = self.batch_norm2(out)
out = torch.relu(out)
out = self.conv3(x)
out = self.batch_norm3(out)
# convine
out += input_1
out = torch.relu(out)
return out
class Net(nn.Module):
def __init__(self, n_chans1=32, n_blocks=10, n_out=10):
super().__init__()
self.n_chans1 = n_chans1
self.n_out = n_out
self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
self.resblocks = nn.Sequential()
for idx_block in range(1, n_blocks + 1):
self.resblocks.add_module("idx{}".format(idx_block), ResBlock(n_chans=n_chans1))
self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
self.fc2 = nn.Linear(32, n_out)
def forward(self, x):
out = F.max_pool2d(torch.relu(self.conv1(x)), 2)
out = self.resblocks(out)
out = F.max_pool2d(out, 2)
out = out.view(-1, 8 * 8 * self.n_chans1)
out = torch.relu(self.fc1(out))
out = self.fc2(out)
return out
########
# def Training
########
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
for epoch in range(1, n_epochs+1):
loss_train = 0.0
if epoch == 1:
print('{} epoch 0, Starting ...'.format(
datetime.datetime.now()))
for imgs, labels in train_loader:
imgs = imgs.to(device=device)
labels = labels.to(device=device)
outputs = model(imgs)
loss = loss_fn(outputs, labels)
l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
loss = loss + l2_lambda * l2_norm
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_train += loss.item()
if epoch == 1 or epoch % 5 == 0:
print('{} epoch {}, Training loss {}'.format(
datetime.datetime.now(),
epoch, loss_train / len(train_loader)
))
########
# def Validation
########
def validate(device, model, train_loader, val_loader):
for name, loader in [("train", train_loader), ("val", val_loader)]:
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs = imgs.to(device=device)
labels = labels.to(device=device)
outputs = model(imgs)
_, predicted = torch.max(outputs, dim=1)
total += labels.shape[0]
correct += int((predicted == labels).sum())
print("Accuracy {} : {:.2f}".format(name, correct/total))
########
# exec learinig and validation
########
# learning params and define model
model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
summary(model, input_size=(batch_size,
tensor_cifer10_train[0][0].shape[0],
tensor_cifer10_train[0][0].shape[1],
tensor_cifer10_train[0][0].shape[2]),
device=device,
depth=3)
model = model.to(device=device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()
# training
model.train()
training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader)
torch.save(model.state_dict(), weight_path)
# validation
loaded_model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
loaded_model = loaded_model.to(device=device)
loaded_model.load_state_dict(torch.load(weight_path, ))
loaded_model.eval()
validate(device, loaded_model, train_loader, val_loader)
print("quit()")
quit()
ソースコードの実行方法
上記のスクリプトを 「sample_conv.py」 という名前で保存し、下記の手順によりスクリプトを実行します。まず、Pythonの仮想環境をAnacondaのconda create
コマンドにより作成し、次に、必要なパッケージをpip
コマンドによりインストールし、最後にpython
コマンドによりスクリプトを実行します。
# Python3.12ベースの仮想環境をAnacondaにより作成する
$ conda create -n learning_conv python=3.12
# Proceed ([y]/n)? y
# To activate this environment, use
#
# $ conda activate learning_conv
# ...
# 仮想環境を有効化する
$ conda activate learning_conv
(learning_conv) $ python --version
# Python 3.12.8
# 必要なパッケージをインストールする
(learning_conv) $ pip install torch torchvision torchaudio torchinfo
(learning_conv) $ pip install matplotlib datetime
# スクリプトを実行する
(learning_conv) $ python sample_conv.py
####
# 実行結果が表示される
####
実行結果
スクリプトの実行に成功すると、下記のような結果を得ることができます。まず、どのデバイスの上でモデルを実行するか(mps
、cuda
、cpu
)の選択が行われ、続いて学習と評価に利用するデータセットのダウンロードが実行されます。その後、定義したモデルのサマリーが出力され、モデルの学習が開始されます。最後に、学習済みのモデルによる推論により、学習用データセットと検証用データセットに対してどの程度の正答率となるかを表示し終了となります。 今回の例では学習用データセットに対して84%は、検証用データセットに対しては72%の正答率でした。 モデルの組み方や学習方法によっては、調整したパラメータが学習用データセットに特化しすぎて、検証用データセットの正答率が低いオーバーフィッティングの状態となることがありますので、必ず、学習用データに含まれていないデータセットで検証する必要があります。
====================
.to(device='mps')
====================
Files already downloaded and verified
Files already downloaded and verified
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Net [128, 10] --
├─Conv2d: 1-1 [128, 32, 32, 32] 896
├─Sequential: 1-2 [128, 32, 16, 16] --
│ └─ResBlock: 2-1 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-1 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-2 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-3 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-4 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-5 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-6 [128, 32, 16, 16] 64
│ └─ResBlock: 2-2 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-7 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-8 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-9 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-10 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-11 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-12 [128, 32, 16, 16] 64
│ └─ResBlock: 2-3 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-13 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-14 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-15 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-16 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-17 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-18 [128, 32, 16, 16] 64
│ └─ResBlock: 2-4 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-19 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-20 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-21 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-22 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-23 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-24 [128, 32, 16, 16] 64
│ └─ResBlock: 2-5 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-25 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-26 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-27 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-28 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-29 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-30 [128, 32, 16, 16] 64
│ └─ResBlock: 2-6 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-31 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-32 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-33 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-34 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-35 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-36 [128, 32, 16, 16] 64
│ └─ResBlock: 2-7 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-37 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-38 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-39 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-40 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-41 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-42 [128, 32, 16, 16] 64
│ └─ResBlock: 2-8 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-43 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-44 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-45 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-46 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-47 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-48 [128, 32, 16, 16] 64
│ └─ResBlock: 2-9 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-49 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-50 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-51 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-52 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-53 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-54 [128, 32, 16, 16] 64
│ └─ResBlock: 2-10 [128, 32, 16, 16] 1,024
│ │ └─Conv2d: 3-55 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-56 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-57 [128, 32, 16, 16] 9,216
│ │ └─BatchNorm2d: 3-58 [128, 32, 16, 16] 64
│ │ └─Conv2d: 3-59 [128, 32, 16, 16] 1,024
│ │ └─BatchNorm2d: 3-60 [128, 32, 16, 16] 64
├─Linear: 1-3 [128, 32] 65,568
├─Linear: 1-4 [128, 10] 330
==========================================================================================
Total params: 191,594
Trainable params: 191,594
Non-trainable params: 0
Total mult-adds (G): 3.82
==========================================================================================
Input size (MB): 1.57
Forward/backward pass size (MB): 536.91
Params size (MB): 0.73
Estimated Total Size (MB): 539.21
==========================================================================================
2024-12-16 20:51:35.152977 epoch 0, Starting ...
2024-12-16 20:51:45.996456 epoch 1, Training loss 27.925368345606966
2024-12-16 20:52:29.703423 epoch 5, Training loss 15.228352090586787
2024-12-16 20:53:25.581729 epoch 10, Training loss 7.535380253706442
2024-12-16 20:54:21.600999 epoch 15, Training loss 4.021200152614233
2024-12-16 20:55:17.960491 epoch 20, Training loss 2.3994958912929913
...
2024-12-16 23:52:07.780146 epoch 980, Training loss 0.8238151277727483
2024-12-16 23:53:03.873066 epoch 985, Training loss 0.8236197303323185
2024-12-16 23:53:59.997969 epoch 990, Training loss 0.8230565896119608
2024-12-16 23:54:56.139902 epoch 995, Training loss 0.8226507447869577
2024-12-16 23:55:50.857574 epoch 1000, Training loss 0.8289129080065071
Accuracy train : 0.84
Accuracy val : 0.72
ソースコードの抜粋と解説
それではモデルを構成、学習、検証するためのソースコードの各部を見ていきましょう。
モデルを可視化する
まず、定義したモデル(nn.Module
を継承したクラス)はtorchinfo
のsummary
メソッドにより可視化することができます。可視化には、 モデル に加えて、 バッチサイズ (1回の推論の入力データの組:今回は128枚を同時に処理します、メモリの容量などに合わせて調整ください)、 入力データの形状 の指定が必要です。今回はCIFER-10という RGB3ch(tensor_cifer10_train[0][0].shape[0]) x 32pix(tensor_cifer10_train[0][0].shape[1]) x 32pix(tensor_cifer10_train[0][0].shape[2]) のデータを処理します。最後に、 モデルをどの解像度まで表現するかを「depth」 で指定しましょう。
この可視化により、パラメータの個数、メモリの消費量、モデルを構成する各層の入出力データの形状を確認することができます。
### 抜粋
# learning params and define model
model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
summary(model, input_size=(batch_size,
tensor_cifer10_train[0][0].shape[0],
tensor_cifer10_train[0][0].shape[1],
tensor_cifer10_train[0][0].shape[2]),
device=device,
depth=3)
これにより、データ量などが下記のように表示されます。
### 抜粋
==========================================================================================
Total params: 191,594
Trainable params: 191,594
Non-trainable params: 0
Total mult-adds (G): 3.82
==========================================================================================
Input size (MB): 1.57
Forward/backward pass size (MB): 536.91
Params size (MB): 0.73
Estimated Total Size (MB): 539.21
==========================================================================================
モデルを実行するデバイスを選択する
近年のAIモデルは非常に巨大(多層)となっており、調整するパラメータも大量です。そのため、Modelと入力データをCUDA(NVIDIA)やMPS(Apple)等のアクセラレータに載せて演算することが一般的です。 下記のソースコードにより、アクセラレータを利用できるか確認し、.to(device=*)
のメソッドに指定するデバイス名を決定しましょう。 なお、NVIDIAのGPU等はワークステーションに複数搭載できますので、より複雑なモデルを実行する際にはcuda:0
やcuda:1
など細かいスケジューリングをすることも可能です。
### 抜粋
if torch.cuda.is_available():
device = 'cuda'
else:
if torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
### 抜粋
model = model.to(device=device) # モデルをアクセラレータに載せる
### 抜粋
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
for epoch in range(1, n_epochs+1):
...
for imgs, labels in train_loader:
imgs = imgs.to(device=device) # 入力データをアクセラレータに載せる
labels = labels.to(device=device) # ラベルデータをアクセラレータに載せる
outputs = model(imgs) # アクセラレータ上で計算する
torchvisionのdatasetsとDataLoaderにより利用するデータを整備する
PyTorchにより定義されたAIのモデルを実行する際、 1回の推論で1個のデータのみを扱うことはなく、バッチと呼ばれる単位の複数データをモデルに与えることが一般的です。 今回の例では128個のデータを同時に処理しています。このようにデータセットから複数のデータを読み出すことを、PyTorchでは datasetsによりデータ一式を定義 し、 DataLoaderにより読み出すストリームを定義 します。
### 抜粋
transformed_cifer10_train = datasets.CIFAR10(data_path, ...
### 抜粋
train_loader = torch.utils.data.DataLoader(transformed_cifer10_train, batch_size=batch_size, ...
また、PyTorchのdatasetsには、データを読み出す際にデータを加工するtransformsという処理を指定することができます。transformsは 「Compose」でどのようにデータを加工するかのリストを定義 できます。今回はComposeの内容として 「ToTensor()」によるRGB256諧調のデータを0〜1のテンソルへ正規化する変換と、「Normalize()」による平均値と標準偏差を使ったメリハリのあるデータの生成 を定義しています。
### 抜粋
transformed_cifer10_train = datasets.CIFAR10(data_path,
train=True, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
なお、Normalizeのパラメータは下記コードにより事前に生成しました。
### Tensor images
# tensor_cifer10_train = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())
# tensor_cifer10_val = datasets.CIFAR10(data_path, train=False, download=False, transform=transforms.ToTensor())
### calc mean and std
# imgs = torch.stack([img_t for img_t, _ in tensor_cifer10_train], dim=3)
# imgs_mean = imgs.view(3, -1).mean(dim=1)
# print("datasets mean:", imgs_mean)
# imgs_std = imgs.view(3, -1).std(dim=1)
# print("datasets std:", imgs_std)
### mean: tensor([0.4914, 0.4822, 0.4465])
### std: tensor([0.2470, 0.2435, 0.2616])
モデルをnn.Moduleにより定義する
PyTorchで扱うモデルは、nn.Module
を継承したクラスとして定義します。ここで定義するクラスでは、計算処理の流れをforward(self, ...)
で定義し(例:output = gx(fx(input)))、処理に利用するニューロンを__init__(self, ...)
で生成します。生成されたニューロンが持つ演算のパラメータは学習ごとに微調整され、学習が終わった際に適切なパラメータを保持した関数になります。
### 抜粋
class Net(nn.Module):
def __init__(self, n_chans1=32, n_blocks=10, n_out=10):
super().__init__()
self.n_chans1 = n_chans1
self.n_out = n_out
self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
self.resblocks = nn.Sequential()
for idx_block in range(1, n_blocks + 1):
self.resblocks.add_module("idx{}".format(idx_block), ResBlock(n_chans=n_chans1))
self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
self.fc2 = nn.Linear(32, n_out)
def forward(self, x):
out = F.max_pool2d(torch.relu(self.conv1(x)), 2)
out = self.resblocks(out)
out = F.max_pool2d(out, 2)
out = out.view(-1, 8 * 8 * self.n_chans1)
out = torch.relu(self.fc1(out))
out = self.fc2(out)
return out
また、ネットワークを定義する際は、nn.Module
で事前に定義したネットワークを再利用することができます。下記の定義では__init__
にて、事前にnn.Module
を定義したResBlock
クラスを生成し、nn.Sequential()
で生成した処理列にadd_module
で結合しています。これにより今回のネットワークは10層のResBlockから構成されます。
Net [128, 10] --
├─Conv2d: 1-1 [128, 32, 32, 32] 896
├─Sequential: 1-2 [128, 32, 16, 16] --
│ └─ResBlock: 2-1 [128, 32, 16, 16] --
│ └─ResBlock: 2-2 [128, 32, 16, 16] --
│ └─ResBlock: 2-3 [128, 32, 16, 16] --
│ └─ResBlock: 2-4 [128, 32, 16, 16] --
│ └─ResBlock: 2-5 [128, 32, 16, 16] --
│ └─ResBlock: 2-6 [128, 32, 16, 16] --
│ └─ResBlock: 2-7 [128, 32, 16, 16] --
│ └─ResBlock: 2-8 [128, 32, 16, 16] --
│ └─ResBlock: 2-9 [128, 32, 16, 16] --
│ └─ResBlock: 2-10 [128, 32, 16, 16] --
├─Linear: 1-3 [128, 32] 65,568
├─Linear: 1-4 [128, 10] 330
サブネットをnn.Moduleにより定義してネットに組み込む
上記で利用したネットワークResBlock
は下記のようにnn.Module
を継承したクラスにより定義できます。このクラスも呼び出し元クラスNet
の__init__
でインスタンス化された後に、学習フェーズでパラメータが調整されます。ここで定義するサブネットは RasNetらしく、畳み込み層(Conv2d)、バッチ正規化層(BatchNorm2d)、活性化層(ReLu)を3つ重ね、最後に入力を加算 します。
このようにサブネットワークを定義することは、メインのネットワークを定義する際にサブネットワークの繰り返しを活用できますので、 複雑なネットワークや深いネットワークの定義をする際に有用です。
### 抜粋
class ResBlock(nn.Module):
def __init__(self, n_chans):
super(ResBlock, self).__init__()
# create layers
self.n_chans1 = n_chans
self.conv1 = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
self.conv2 = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
self.conv3 = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
self.conv_skip = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
self.batch_norm1 = nn.BatchNorm2d(num_features=n_chans)
self.batch_norm2 = nn.BatchNorm2d(num_features=n_chans)
self.batch_norm3 = nn.BatchNorm2d(num_features=n_chans)
# init weight and bias
torch.nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
torch.nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
torch.nn.init.kaiming_normal_(self.conv3.weight, nonlinearity='relu')
torch.nn.init.kaiming_normal_(self.conv_skip.weight, nonlinearity='relu')
torch.nn.init.constant_(self.batch_norm1.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm1.bias)
torch.nn.init.constant_(self.batch_norm2.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm2 .bias)
torch.nn.init.constant_(self.batch_norm3.weight, 0.5)
torch.nn.init.zeros_(self.batch_norm3.bias)
def forward(self, x):
# skip path
input_1 = x.clone()
# main path
out = self.conv1(x)
out = self.batch_norm1(out)
out = torch.relu(out)
out = self.conv2(x)
out = self.batch_norm2(out)
out = torch.relu(out)
out = self.conv3(x)
out = self.batch_norm3(out)
# convine
out += input_1
out = torch.relu(out)
return out
モデルの学習を実施する
上記の内容でネットワークを定義できましたので、学習を行いましょう。学習は下記のソースコードにより実現することができます。まず、ネットワークをmodel.train()
で学習モードに切り替えます。そして学習した後はパラメータを今後に活かすため、torch.save(model.state_dict, ...)
を利用してモデルのパラメータをファイルへ保存します(※ここで保存したパラメータは推論時にロードして利用しています)。
### 抜粋
model.train()
training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader)
torch.save(model.state_dict(), weight_path)
仮の回答を得て、正答との差分lossを算出する
学習ではまず、DataLoaderにより定義した train_loader
からバッチサイズ幅のデータを imgs
labels
を取り出し、これをモデルへ入力して output
を得ます。ここでモデルはパラメータ調整されていませんので output
と正答である labels
は一致しません。
### 抜粋
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
for epoch in range(1, n_epochs+1):
loss_train = 0.0
if epoch == 1:
...
for imgs, labels in train_loader:
imgs = imgs.to(device=device)
labels = labels.to(device=device)
outputs = model(imgs)
...
if epoch == 1 or epoch % 5 == 0:
...
そこで期待するモデルを実現するために outputs
と labels
の差分から算出した loss
を最小化するようにパラメータの調整を行います。lossの計算には損失関数として定義したnn.CrossEntropyLoss()
で算出し、パラメータはoptimizer
として選択したoptim.SGD(model.parameters(), ...)
により調整します。optimizerには、どのくらいの刻みで調整するかを決めるlr=learning_rate
を指定します。
各バッチ処理ごとにloss_fn
によりloss
を計算し、optimizerの保持する勾配をoptimizer.zero_grad()
により初期化、loss.backward()
により誤差の逆伝搬を実施し、その結果をoptimizer.step()
によりモデルのパラメータへ反映します。
### 抜粋
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
for epoch in range(1, n_epochs+1):
...
loss_train = 0.0
for imgs, labels in train_loader:
imgs = imgs.to(device=device)
labels = labels.to(device=device)
outputs = model(imgs)
loss = loss_fn(outputs, labels)
...
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_train += loss.item()
...
プログラムの下記箇所はL2正則化と呼ばれる最適化のテクニックを意味しています。L2正則化についての解説は様々なところで行われていますので、ここではソースの紹介のみに留まりますが、正答に大きく寄与するパラメータを選定するものと理解しておけば良いかと思います。
### 抜粋
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
for epoch in range(1, n_epochs+1):
...
loss = loss_fn(outputs, labels)
l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
loss = loss + l2_lambda * l2_norm
...
学習済みのモデルをロードする
以上で学習した結果は保存されていますので、新しくネットワークを生成して、パラメータをロードして評価用のネットワークを作成してみましょう。 本来は新しくネットワークを使うことなく、学習したネットワークをそのまま評価に使えるのですが、今回はsaveとloadが出来ているかを確認するために、この処理を追加しています。 モデルのロードの際には、セーブしたモデルと同じクラスをインスタンス化する必要がありますので、モデルの定義は保持しておくようにしましょう。
##### モデルをセーブする
### 抜粋
model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
# ...
model.train()
# ...
torch.save(model.state_dict(), weight_path)
##### モデルをロードする
### 抜粋
loaded_model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
# ...
loaded_model.load_state_dict(torch.load(weight_path, weights_only=True))
l
学習済みのモデルを評価する
評価をする際はパラメータをもうチューニングしないことを意味する model.eval()
をまず呼び出します。DataLoaderを利用して学習用のデータセットと評価用のデータセットからそれぞれデータを抽出し、正答数をカウントします。 今回のモデルは入力に対して10分類のいずれである確率を出力しますので、確率が最も高いインデックスと、正答のインデックスが一致しているかを判定し、正答数を加算します。
### 抜粋
loaded_model.eval()
validate(device, loaded_model, train_loader, val_loader)
### 抜粋
def validate(device, model, train_loader, val_loader):
for name, loader in [("train", train_loader), ("val", val_loader)]:
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs = imgs.to(device=device)
labels = labels.to(device=device)
outputs = model(imgs)
_, predicted = torch.max(outputs, dim=1)
total += labels.shape[0]
correct += int((predicted == labels).sum())
print("Accuracy {} : {:.2f}".format(name, correct/total))
評価方法のさらなる解説については、下記の記事が適しています。
以上が、ResNetみたいな(畳み込み層とスキップを備えたサブネットワークの多層構造からなる)モデルの組み方です。 商業利用に耐え得るモデルを組み上げる際は、もっと計画的かつ、論文等の調査を行い、学習に用いるハイパーパラメータ(学習率やL1/L2正則化の選定など)を決定する必要があります。 が、今回の記事でソースコードを書くイメージや、実装規模のイメージが付きましたら幸いです。
是非、みなさまの開発にお役立てください。