0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ResNetみたいなモデルをPyTorchでゼロから定義してCIFER-10の画像を分類するソースコードとその解説

Last updated at Posted at 2024-12-16

はじめてのPyTorchで、RasNetみたいなAIモデルをゼロからつくってみよう

最近、機械学習の勉強をはじめまして、PyTorch提供するモデルであるResNet等の画像分類AIをtorchvision経由で利用していたのですが(下記記事)、 ゼロからモデルを組み上げる場合どうすれば良いのだろう と思い、下記書籍で勉強しはじめました。今回の記事は、書籍の抜粋に近く恐縮なのですが、 ゼロからPython+PyTorchで画像分類をするモデルを定義し、その内容を整理するもの としました。


定義済みのモデルをまずは動かしてみたい、という方は下記をご参照ください。


参考にした書籍

書籍の方が正確で読みやすいと思いますので、ご興味あれば是非ご購読ください。

PyTorch実践入門 ~ ディープラーニングの基礎から実装へ

image0.jpeg

本記事により実現できる物体検出モデル

本記事で扱うモデルを実装することにより、 1000epochsのイテレーションの結果、学習用データセットに対しては84%、検証用データセットに対しても72%程度の正答率を実現できる物体検出器を構築 することができます。

2024-12-16 23:52:07.780146 epoch 980, Training loss 0.8238151277727483
2024-12-16 23:53:03.873066 epoch 985, Training loss 0.8236197303323185
2024-12-16 23:53:59.997969 epoch 990, Training loss 0.8230565896119608
2024-12-16 23:54:56.139902 epoch 995, Training loss 0.8226507447869577
2024-12-16 23:55:50.857574 epoch 1000, Training loss 0.8289129080065071
Accuracy train : 0.84
Accuracy val : 0.72

ソースコードと、その実行方法と結果

早速ですが、 ソースコード全文は以下となります。 実行にはPython仮想環境を準備し、必要なパッケージをあらかじめインストールしておく必要があります。

ソースコード全文

下記のソースコードが、 AIのモデルを定義し、モデルを織り成すハイパーパラメータを定義し、そのモデルを画像分類セットであるCIFER10で学習し、学習結果をCIFER10の検証セットで評価し、学習結果のモデルを保存するまで のソースコード全文です。実行する際は、下記のソースコードに加えて import from 節で利用しているパッケージを仮想環境へインストールする必要があります。これについては以降の ソースコードの実行方法 をご参照ください。まずは下記ソースコードを 「sample_conv.py」 という名前で保存しておきましょう。

CIFIR10は、32px x 32pxの非常に小さく低解像度な画像を、10種類のラベルデータに分類するデータセットです。 今回のように小さな規模のモデルを体験する のに適しています。 実用に足る解像度のモデルを開発する際には、幅広いバリエーションを扱う場合は1000種類の分類をもつ 224px x 224px の画像データセット「Imagenet」の方が向いています。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchinfo import summary
from matplotlib import pyplot as plt
import datetime

########
# Hyper Parameters
########
cifer10_class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
data_path = "./data-universioned/"
weight_path = './run/cifer10_sample.pt'
batch_size = 128
n_blocks = 10
learning_rate = 1e-2
l2_lambda = 1e-2
n_epochs = 1000

# Set Accelerator
if torch.cuda.is_available():
    device = 'cuda'
else:
    if torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'

print("====================")
print(".to(device='{0}')".format(device))
print("====================")

########
# PIL images
########
cifer10_train = datasets.CIFAR10(data_path, train=True, download=True)
cifer10_val = datasets.CIFAR10(data_path, train=False, download=True)

# Tensor images
tensor_cifer10_train = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())
tensor_cifer10_val = datasets.CIFAR10(data_path, train=False, download=False, transform=transforms.ToTensor())
# calc mean and std
# imgs = torch.stack([img_t for img_t, _ in tensor_cifer10_train], dim=3)
# imgs_mean = imgs.view(3, -1).mean(dim=1)
# print("datasets mean:", imgs_mean)
# imgs_std = imgs.view(3, -1).std(dim=1)
# print("datasets std:", imgs_std)
### mean: tensor([0.4914, 0.4822, 0.4465])
### std: tensor([0.2470, 0.2435, 0.2616])

# Normalized images
transformed_cifer10_train = datasets.CIFAR10(data_path, 
                                             train=True, download=False, 
                                             transform=transforms.Compose([
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.4915, 0.4823, 0.4468),
                                                                     (0.2470, 0.2435, 0.2616))                                                 
                                             ]))
transformed_cifer10_val = datasets.CIFAR10(data_path, 
                                             train=False, download=False, 
                                             transform=transforms.Compose([
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.4915, 0.4823, 0.4468),
                                                                     (0.2470, 0.2435, 0.2616))                                                 
                                             ]))

########
# DataLoader
########
train_loader = torch.utils.data.DataLoader(transformed_cifer10_train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(transformed_cifer10_val, batch_size=batch_size, shuffle=True)

########
# define Models (Sub-Net and Net)
########
class ResBlock(nn.Module):
    def __init__(self, n_chans):
        super(ResBlock, self).__init__()
        # create layers
        self.n_chans1 = n_chans
        self.conv1 = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
        self.conv3 = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
        self.conv_skip = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(num_features=n_chans)
        self.batch_norm2 = nn.BatchNorm2d(num_features=n_chans)
        self.batch_norm3 = nn.BatchNorm2d(num_features=n_chans)
        # init weight and bias
        torch.nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv3.weight, nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv_skip.weight, nonlinearity='relu')
        torch.nn.init.constant_(self.batch_norm1.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm1.bias)
        torch.nn.init.constant_(self.batch_norm2.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm2 .bias)
        torch.nn.init.constant_(self.batch_norm3.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm3.bias)

    def forward(self, x):
        # skip path
        input_1 = x.clone()
        # main path
        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = torch.relu(out)
        out = self.conv2(x)
        out = self.batch_norm2(out)
        out = torch.relu(out)
        out = self.conv3(x)
        out = self.batch_norm3(out)
        # convine
        out += input_1
        out = torch.relu(out)
        return out

class Net(nn.Module):
    def __init__(self, n_chans1=32, n_blocks=10, n_out=10):
        super().__init__()
        self.n_chans1 = n_chans1
        self.n_out = n_out
        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
        self.resblocks = nn.Sequential()
        for idx_block in range(1, n_blocks + 1):
            self.resblocks.add_module("idx{}".format(idx_block), ResBlock(n_chans=n_chans1))
        self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
        self.fc2 = nn.Linear(32, n_out)
    
    def forward(self, x):
        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)
        out = self.resblocks(out)
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 8 * 8 * self.n_chans1)
        out = torch.relu(self.fc1(out))
        out = self.fc2(out)
        return out

########
# def Training
########
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
    for epoch in range(1, n_epochs+1):
        loss_train = 0.0
        if epoch == 1:
            print('{} epoch 0, Starting ...'.format(
                datetime.datetime.now()))

        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)
            labels = labels.to(device=device)
            outputs = model(imgs)

            loss = loss_fn(outputs, labels)
            l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
            loss = loss + l2_lambda * l2_norm

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()

        if epoch == 1 or epoch % 5 == 0:
            print('{} epoch {}, Training loss {}'.format(
                datetime.datetime.now(),
                epoch, loss_train / len(train_loader)
            ))
            
########
# def Validation
########
def validate(device, model, train_loader, val_loader):
    for name, loader in [("train", train_loader), ("val", val_loader)]:
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in loader:
                imgs = imgs.to(device=device)
                labels = labels.to(device=device)
                outputs = model(imgs)
                _, predicted = torch.max(outputs, dim=1)
                total += labels.shape[0]
                correct += int((predicted == labels).sum())
            print("Accuracy {} : {:.2f}".format(name, correct/total))

########
# exec learinig and validation
########
# learning params and define model
model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
summary(model, input_size=(batch_size, 
                           tensor_cifer10_train[0][0].shape[0], 
                           tensor_cifer10_train[0][0].shape[1], 
                           tensor_cifer10_train[0][0].shape[2]), 
                           device=device,
                           depth=3)
model = model.to(device=device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# training
model.train()
training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader)
torch.save(model.state_dict(), weight_path)

# validation
loaded_model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
loaded_model = loaded_model.to(device=device)
loaded_model.load_state_dict(torch.load(weight_path, ))
loaded_model.eval()
validate(device, loaded_model, train_loader, val_loader)

print("quit()")
quit()

ソースコードの実行方法

上記のスクリプトを 「sample_conv.py」 という名前で保存し、下記の手順によりスクリプトを実行します。まず、Pythonの仮想環境をAnacondaのconda createコマンドにより作成し、次に、必要なパッケージをpipコマンドによりインストールし、最後にpythonコマンドによりスクリプトを実行します。

# Python3.12ベースの仮想環境をAnacondaにより作成する
$ conda create -n learning_conv python=3.12
# Proceed ([y]/n)? y
# To activate this environment, use
#
#     $ conda activate learning_conv
# ...

# 仮想環境を有効化する
$ conda activate learning_conv
(learning_conv) $ python --version
# Python 3.12.8

# 必要なパッケージをインストールする
(learning_conv) $ pip install torch torchvision torchaudio torchinfo           
(learning_conv) $ pip install matplotlib datetime

# スクリプトを実行する
(learning_conv) $ python sample_conv.py 
####
# 実行結果が表示される
####

実行結果

スクリプトの実行に成功すると、下記のような結果を得ることができます。まず、どのデバイスの上でモデルを実行するか(mpscudacpu)の選択が行われ、続いて学習と評価に利用するデータセットのダウンロードが実行されます。その後、定義したモデルのサマリーが出力され、モデルの学習が開始されます。最後に、学習済みのモデルによる推論により、学習用データセットと検証用データセットに対してどの程度の正答率となるかを表示し終了となります。 今回の例では学習用データセットに対して84%は、検証用データセットに対しては72%の正答率でした。 モデルの組み方や学習方法によっては、調整したパラメータが学習用データセットに特化しすぎて、検証用データセットの正答率が低いオーバーフィッティングの状態となることがありますので、必ず、学習用データに含まれていないデータセットで検証する必要があります。

====================
.to(device='mps')
====================
Files already downloaded and verified
Files already downloaded and verified
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Net                                      [128, 10]                 --
├─Conv2d: 1-1                            [128, 32, 32, 32]         896
├─Sequential: 1-2                        [128, 32, 16, 16]         --
│    └─ResBlock: 2-1                     [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-1                  [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-2             [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-3                  [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-4             [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-5                  [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-6             [128, 32, 16, 16]         64
│    └─ResBlock: 2-2                     [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-7                  [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-8             [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-9                  [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-10            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-11                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-12            [128, 32, 16, 16]         64
│    └─ResBlock: 2-3                     [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-13                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-14            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-15                 [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-16            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-17                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-18            [128, 32, 16, 16]         64
│    └─ResBlock: 2-4                     [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-19                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-20            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-21                 [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-22            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-23                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-24            [128, 32, 16, 16]         64
│    └─ResBlock: 2-5                     [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-25                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-26            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-27                 [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-28            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-29                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-30            [128, 32, 16, 16]         64
│    └─ResBlock: 2-6                     [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-31                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-32            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-33                 [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-34            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-35                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-36            [128, 32, 16, 16]         64
│    └─ResBlock: 2-7                     [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-37                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-38            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-39                 [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-40            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-41                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-42            [128, 32, 16, 16]         64
│    └─ResBlock: 2-8                     [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-43                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-44            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-45                 [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-46            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-47                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-48            [128, 32, 16, 16]         64
│    └─ResBlock: 2-9                     [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-49                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-50            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-51                 [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-52            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-53                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-54            [128, 32, 16, 16]         64
│    └─ResBlock: 2-10                    [128, 32, 16, 16]         1,024
│    │    └─Conv2d: 3-55                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-56            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-57                 [128, 32, 16, 16]         9,216
│    │    └─BatchNorm2d: 3-58            [128, 32, 16, 16]         64
│    │    └─Conv2d: 3-59                 [128, 32, 16, 16]         1,024
│    │    └─BatchNorm2d: 3-60            [128, 32, 16, 16]         64
├─Linear: 1-3                            [128, 32]                 65,568
├─Linear: 1-4                            [128, 10]                 330
==========================================================================================
Total params: 191,594
Trainable params: 191,594
Non-trainable params: 0
Total mult-adds (G): 3.82
==========================================================================================
Input size (MB): 1.57
Forward/backward pass size (MB): 536.91
Params size (MB): 0.73
Estimated Total Size (MB): 539.21
==========================================================================================
2024-12-16 20:51:35.152977 epoch 0, Starting ...
2024-12-16 20:51:45.996456 epoch 1, Training loss 27.925368345606966
2024-12-16 20:52:29.703423 epoch 5, Training loss 15.228352090586787
2024-12-16 20:53:25.581729 epoch 10, Training loss 7.535380253706442
2024-12-16 20:54:21.600999 epoch 15, Training loss 4.021200152614233
2024-12-16 20:55:17.960491 epoch 20, Training loss 2.3994958912929913
...
2024-12-16 23:52:07.780146 epoch 980, Training loss 0.8238151277727483
2024-12-16 23:53:03.873066 epoch 985, Training loss 0.8236197303323185
2024-12-16 23:53:59.997969 epoch 990, Training loss 0.8230565896119608
2024-12-16 23:54:56.139902 epoch 995, Training loss 0.8226507447869577
2024-12-16 23:55:50.857574 epoch 1000, Training loss 0.8289129080065071
Accuracy train : 0.84
Accuracy val : 0.72

ソースコードの抜粋と解説

それではモデルを構成、学習、検証するためのソースコードの各部を見ていきましょう。

モデルを可視化する

まず、定義したモデル(nn.Moduleを継承したクラス)はtorchinfosummaryメソッドにより可視化することができます。可視化には、 モデル に加えて、 バッチサイズ (1回の推論の入力データの組:今回は128枚を同時に処理します、メモリの容量などに合わせて調整ください)、 入力データの形状 の指定が必要です。今回はCIFER-10という RGB3ch(tensor_cifer10_train[0][0].shape[0]) x 32pix(tensor_cifer10_train[0][0].shape[1]) x 32pix(tensor_cifer10_train[0][0].shape[2]) のデータを処理します。最後に、 モデルをどの解像度まで表現するかを「depth」 で指定しましょう。

この可視化により、パラメータの個数、メモリの消費量、モデルを構成する各層の入出力データの形状を確認することができます。

### 抜粋
# learning params and define model
model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
summary(model, input_size=(batch_size, 
                           tensor_cifer10_train[0][0].shape[0], 
                           tensor_cifer10_train[0][0].shape[1], 
                           tensor_cifer10_train[0][0].shape[2]), 
                           device=device,
                           depth=3)

これにより、データ量などが下記のように表示されます。

### 抜粋
==========================================================================================
Total params: 191,594
Trainable params: 191,594
Non-trainable params: 0
Total mult-adds (G): 3.82
==========================================================================================
Input size (MB): 1.57
Forward/backward pass size (MB): 536.91
Params size (MB): 0.73
Estimated Total Size (MB): 539.21
==========================================================================================

モデルを実行するデバイスを選択する

近年のAIモデルは非常に巨大(多層)となっており、調整するパラメータも大量です。そのため、Modelと入力データをCUDA(NVIDIA)やMPS(Apple)等のアクセラレータに載せて演算することが一般的です。 下記のソースコードにより、アクセラレータを利用できるか確認し、.to(device=*) のメソッドに指定するデバイス名を決定しましょう。 なお、NVIDIAのGPU等はワークステーションに複数搭載できますので、より複雑なモデルを実行する際にはcuda:0cuda:1など細かいスケジューリングをすることも可能です。

### 抜粋
if torch.cuda.is_available():
    device = 'cuda'
else:
    if torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'

### 抜粋
model = model.to(device=device)               # モデルをアクセラレータに載せる
### 抜粋
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
    for epoch in range(1, n_epochs+1):
        ...
        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)     # 入力データをアクセラレータに載せる
            labels = labels.to(device=device) # ラベルデータをアクセラレータに載せる
            outputs = model(imgs)             # アクセラレータ上で計算する

torchvisionのdatasetsとDataLoaderにより利用するデータを整備する

PyTorchにより定義されたAIのモデルを実行する際、 1回の推論で1個のデータのみを扱うことはなく、バッチと呼ばれる単位の複数データをモデルに与えることが一般的です。 今回の例では128個のデータを同時に処理しています。このようにデータセットから複数のデータを読み出すことを、PyTorchでは datasetsによりデータ一式を定義 し、 DataLoaderにより読み出すストリームを定義 します。

### 抜粋
transformed_cifer10_train = datasets.CIFAR10(data_path, ...
### 抜粋
train_loader = torch.utils.data.DataLoader(transformed_cifer10_train, batch_size=batch_size, ...

また、PyTorchのdatasetsには、データを読み出す際にデータを加工するtransformsという処理を指定することができます。transformsは 「Compose」でどのようにデータを加工するかのリストを定義 できます。今回はComposeの内容として 「ToTensor()」によるRGB256諧調のデータを0〜1のテンソルへ正規化する変換と、「Normalize()」による平均値と標準偏差を使ったメリハリのあるデータの生成 を定義しています。

### 抜粋
transformed_cifer10_train = datasets.CIFAR10(data_path, 
                                             train=True, download=False, 
                                             transform=transforms.Compose([
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.4915, 0.4823, 0.4468),
                                                                     (0.2470, 0.2435, 0.2616))                                                 
                                             ]))

なお、Normalizeのパラメータは下記コードにより事前に生成しました。

### Tensor images
# tensor_cifer10_train = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())
# tensor_cifer10_val = datasets.CIFAR10(data_path, train=False, download=False, transform=transforms.ToTensor())
### calc mean and std
# imgs = torch.stack([img_t for img_t, _ in tensor_cifer10_train], dim=3)
# imgs_mean = imgs.view(3, -1).mean(dim=1)
# print("datasets mean:", imgs_mean)
# imgs_std = imgs.view(3, -1).std(dim=1)
# print("datasets std:", imgs_std)
### mean: tensor([0.4914, 0.4822, 0.4465])
### std: tensor([0.2470, 0.2435, 0.2616])

モデルをnn.Moduleにより定義する

PyTorchで扱うモデルは、nn.Moduleを継承したクラスとして定義します。ここで定義するクラスでは、計算処理の流れをforward(self, ...)で定義し(例:output = gx(fx(input)))、処理に利用するニューロンを__init__(self, ...)で生成します。生成されたニューロンが持つ演算のパラメータは学習ごとに微調整され、学習が終わった際に適切なパラメータを保持した関数になります。

### 抜粋
class Net(nn.Module):
    def __init__(self, n_chans1=32, n_blocks=10, n_out=10):
        super().__init__()
        self.n_chans1 = n_chans1
        self.n_out = n_out
        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
        self.resblocks = nn.Sequential()
        for idx_block in range(1, n_blocks + 1):
            self.resblocks.add_module("idx{}".format(idx_block), ResBlock(n_chans=n_chans1))
        self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
        self.fc2 = nn.Linear(32, n_out)
    
    def forward(self, x):
        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)
        out = self.resblocks(out)
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 8 * 8 * self.n_chans1)
        out = torch.relu(self.fc1(out))
        out = self.fc2(out)
        return out

また、ネットワークを定義する際は、nn.Moduleで事前に定義したネットワークを再利用することができます。下記の定義では__init__にて、事前にnn.Moduleを定義したResBlockクラスを生成し、nn.Sequential()で生成した処理列にadd_moduleで結合しています。これにより今回のネットワークは10層のResBlockから構成されます。

スクリーンショット 2024-12-16 18.28.53.png

Net                                      [128, 10]                 --
├─Conv2d: 1-1                            [128, 32, 32, 32]         896
├─Sequential: 1-2                        [128, 32, 16, 16]         --
│    └─ResBlock: 2-1                     [128, 32, 16, 16]         --
│    └─ResBlock: 2-2                     [128, 32, 16, 16]         --
│    └─ResBlock: 2-3                     [128, 32, 16, 16]         --
│    └─ResBlock: 2-4                     [128, 32, 16, 16]         --
│    └─ResBlock: 2-5                     [128, 32, 16, 16]         --
│    └─ResBlock: 2-6                     [128, 32, 16, 16]         --
│    └─ResBlock: 2-7                     [128, 32, 16, 16]         --
│    └─ResBlock: 2-8                     [128, 32, 16, 16]         --
│    └─ResBlock: 2-9                     [128, 32, 16, 16]         --
│    └─ResBlock: 2-10                    [128, 32, 16, 16]         --
├─Linear: 1-3                            [128, 32]                 65,568
├─Linear: 1-4                            [128, 10]                 330

サブネットをnn.Moduleにより定義してネットに組み込む

上記で利用したネットワークResBlockは下記のようにnn.Moduleを継承したクラスにより定義できます。このクラスも呼び出し元クラスNet__init__でインスタンス化された後に、学習フェーズでパラメータが調整されます。ここで定義するサブネットは RasNetらしく、畳み込み層(Conv2d)、バッチ正規化層(BatchNorm2d)、活性化層(ReLu)を3つ重ね、最後に入力を加算 します。

このようにサブネットワークを定義することは、メインのネットワークを定義する際にサブネットワークの繰り返しを活用できますので、 複雑なネットワークや深いネットワークの定義をする際に有用です。

### 抜粋
class ResBlock(nn.Module):
    def __init__(self, n_chans):
        super(ResBlock, self).__init__()
        # create layers
        self.n_chans1 = n_chans
        self.conv1 = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
        self.conv3 = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
        self.conv_skip = nn.Conv2d(n_chans, n_chans, kernel_size=1, padding=0, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(num_features=n_chans)
        self.batch_norm2 = nn.BatchNorm2d(num_features=n_chans)
        self.batch_norm3 = nn.BatchNorm2d(num_features=n_chans)
        # init weight and bias
        torch.nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv3.weight, nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv_skip.weight, nonlinearity='relu')
        torch.nn.init.constant_(self.batch_norm1.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm1.bias)
        torch.nn.init.constant_(self.batch_norm2.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm2 .bias)
        torch.nn.init.constant_(self.batch_norm3.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm3.bias)

    def forward(self, x):
        # skip path
        input_1 = x.clone()
        # main path
        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = torch.relu(out)
        out = self.conv2(x)
        out = self.batch_norm2(out)
        out = torch.relu(out)
        out = self.conv3(x)
        out = self.batch_norm3(out)
        # convine
        out += input_1
        out = torch.relu(out)
        return out

モデルの学習を実施する

上記の内容でネットワークを定義できましたので、学習を行いましょう。学習は下記のソースコードにより実現することができます。まず、ネットワークをmodel.train()で学習モードに切り替えます。そして学習した後はパラメータを今後に活かすため、torch.save(model.state_dict, ...)を利用してモデルのパラメータをファイルへ保存します(※ここで保存したパラメータは推論時にロードして利用しています)。

### 抜粋
model.train()
training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader)
torch.save(model.state_dict(), weight_path)

仮の回答を得て、正答との差分lossを算出する

学習ではまず、DataLoaderにより定義した train_loader からバッチサイズ幅のデータを imgs labels を取り出し、これをモデルへ入力して output を得ます。ここでモデルはパラメータ調整されていませんので output と正答である labels は一致しません。

### 抜粋
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
    for epoch in range(1, n_epochs+1):
        loss_train = 0.0
        if epoch == 1:
            ...

        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)
            labels = labels.to(device=device)
            outputs = model(imgs)

            ...
            
        if epoch == 1 or epoch % 5 == 0:
            ...

そこで期待するモデルを実現するために outputslabels の差分から算出した loss を最小化するようにパラメータの調整を行います。lossの計算には損失関数として定義したnn.CrossEntropyLoss()で算出し、パラメータはoptimizerとして選択したoptim.SGD(model.parameters(), ...)により調整します。optimizerには、どのくらいの刻みで調整するかを決めるlr=learning_rateを指定します。

各バッチ処理ごとにloss_fnによりlossを計算し、optimizerの保持する勾配をoptimizer.zero_grad()により初期化、loss.backward()により誤差の逆伝搬を実施し、その結果をoptimizer.step()によりモデルのパラメータへ反映します。

### 抜粋
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
    for epoch in range(1, n_epochs+1):
        ...
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)
            labels = labels.to(device=device)
            outputs = model(imgs)

            loss = loss_fn(outputs, labels)
            ...
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()
        ...

プログラムの下記箇所はL2正則化と呼ばれる最適化のテクニックを意味しています。L2正則化についての解説は様々なところで行われていますので、ここではソースの紹介のみに留まりますが、正答に大きく寄与するパラメータを選定するものと理解しておけば良いかと思います。

### 抜粋
def training_loop(device, n_epochs, optimizer, model, loss_fn, l2_lambda, train_loader):
    for epoch in range(1, n_epochs+1):
        ...
            loss = loss_fn(outputs, labels)
            l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
            loss = loss + l2_lambda * l2_norm
        ...

学習済みのモデルをロードする

以上で学習した結果は保存されていますので、新しくネットワークを生成して、パラメータをロードして評価用のネットワークを作成してみましょう。 本来は新しくネットワークを使うことなく、学習したネットワークをそのまま評価に使えるのですが、今回はsaveとloadが出来ているかを確認するために、この処理を追加しています。 モデルのロードの際には、セーブしたモデルと同じクラスをインスタンス化する必要がありますので、モデルの定義は保持しておくようにしましょう。

##### モデルをセーブする
### 抜粋
model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
# ...
model.train()
# ...
torch.save(model.state_dict(), weight_path)

##### モデルをロードする
### 抜粋
loaded_model = Net(n_chans1=32, n_blocks=n_blocks, n_out=len(cifer10_class_names))
# ...
loaded_model.load_state_dict(torch.load(weight_path, weights_only=True))
l

学習済みのモデルを評価する

評価をする際はパラメータをもうチューニングしないことを意味する model.eval() をまず呼び出します。DataLoaderを利用して学習用のデータセットと評価用のデータセットからそれぞれデータを抽出し、正答数をカウントします。 今回のモデルは入力に対して10分類のいずれである確率を出力しますので、確率が最も高いインデックスと、正答のインデックスが一致しているかを判定し、正答数を加算します。

### 抜粋
loaded_model.eval()
validate(device, loaded_model, train_loader, val_loader)

### 抜粋
def validate(device, model, train_loader, val_loader):
    for name, loader in [("train", train_loader), ("val", val_loader)]:
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in loader:
                imgs = imgs.to(device=device)
                labels = labels.to(device=device)
                outputs = model(imgs)
                _, predicted = torch.max(outputs, dim=1)
                total += labels.shape[0]
                correct += int((predicted == labels).sum())
            print("Accuracy {} : {:.2f}".format(name, correct/total))

評価方法のさらなる解説については、下記の記事が適しています。


以上が、ResNetみたいな(畳み込み層とスキップを備えたサブネットワークの多層構造からなる)モデルの組み方です。 商業利用に耐え得るモデルを組み上げる際は、もっと計画的かつ、論文等の調査を行い、学習に用いるハイパーパラメータ(学習率やL1/L2正則化の選定など)を決定する必要があります。 が、今回の記事でソースコードを書くイメージや、実装規模のイメージが付きましたら幸いです。

是非、みなさまの開発にお役立てください。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?