1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorchでCIFAR10画像認識モデルの実験してみる

Last updated at Posted at 2024-10-12

はじめに

PyTorchを使っていろんなパターンのCIFAR10画像認識モデルの実験をし、モデルの構築法や各種手法を学んだので、備忘録として書いておく。

一部のコードの解説をしていきます。

目次

1.CIFAR10の仕様確認
2.正規化
3.実験
4.最後に

1.CIFAR10の仕様確認

CIFAR10のダウンロードを行います。その際にテンソルへ変換しています。

from torchvision import datasets, transforms
data_path = "./data"
cifar10 = datasets.CIFAR10(data_path, train=True, download=True, transform=transforms.ToTensor())
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True, transform=transforms.ToTensor())

画像枚数を確認します。

len(cifar10), len(cifar10_val)
#実行結果(50000, 10000)

クラスの確認をします。今回はこの10分類を行います。

cifar10.classes
#実行結果
['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

データセットの画像を表示してみます。num_imagesで表示枚数を変えられます。

import math
num_images = 10
imgs, labels = [], []
for i in range(num_images):
    img, label = cifar10[i]
    imgs.append(img)
    labels.append(label)
cols = math.ceil(math.sqrt(num_images))
rows = math.ceil(num_images / cols)  

fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2))
axes = axes.flatten()

for i, (img, label) in enumerate(zip(imgs, labels)):
    axes[i].imshow(img.permute(1, 2, 0))
    axes[i].set_title(cifar10.classes[label])
    axes[i].axis('off')

for j in range(i + 1, rows * cols):
    axes[j].axis('off')

plt.tight_layout()
plt.show()

image.png
テンソルの次元数を確認します。

imgs = torch.stack([img for img, _ in cifar10], dim=3)
imgs.shape
#実行結果
torch.Size([3, 32, 32, 50000])

正規化のために平均と標準偏差を計算します。

imgs.view(3, -1).mean(dim=1), imgs.view(3, -1).std(dim=1)
#実行結果
(tensor([0.4914, 0.4822, 0.4465]), tensor([0.2470, 0.2435, 0.2616]))

2.正規化

正規化をして平均0標準偏差1の分布にします。これによって活性化関数を効果的に扱えます。

cifar10 = datasets.CIFAR10(data_path, train=True, download=False, 
                           transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                            ]))
cifar10_val = datasets.CIFAR10(data_path, train=False, download=False,
                                 transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                             ]))

3.実験

まずは全結合層のみのモデルで実験してみます。ハイパーパラメータは固定します。

学習率 エポック数 バッチサイズ
1e-2 100 64
モデル名 構造 層数
Model_2L 全結合 2層
Model_4L 全結合 4層

Model_2L

class Model_2L(nn.Module):
    def __init__(self, learning_rate=1e-2, name="model_2L"):
        super(Model_2L, self).__init__()
        self.name = name
        self.network = nn.Sequential(
            nn.Linear(3072, 512),
            nn.Tanh(),
            nn.Linear(512, 10)
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.network(x)

Model_4L

class Model_4L(nn.Module):
    def __init__(self, learning_rate=1e-2, name="model_4L"):
        super(Model_4L, self).__init__()
        self.name = name
        self.network = nn.Sequential(
            nn.Linear(3072, 2048),
            nn.Tanh(),
            nn.Linear(2048, 1024),
            nn.Tanh(),
            nn.Linear(1024, 512),
            nn.Tanh(),
            nn.Linear(512, 10)
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):

学習用の関数

train_loader = torch.utils.data.DataLoader(cifar10, batch_size=64, shuffle=True)
def train_model(model, n_epochs, is_conv=False):
    losses = []
    for epoch in range(n_epochs):
        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)
            labels = labels.to(device=device)
            if is_conv:
                outputs = model(imgs)
            else:
                outputs = model(imgs.view(imgs.shape[0], -1))
            loss = model.loss_fn(outputs, labels.to(device=device))
            model.optimizer.zero_grad()
            loss.backward()
            model.optimizer.step()
        
        losses.append(float(loss))
        print(f"Epoch {epoch}, Loss {float(loss)}")
    return losses

評価用の関数

val_loader = torch.utils.data.DataLoader(cifar10_val, batch_size=64, shuffle=False)
def evaluate_accuracy(model, is_conv=False):
    model_name = getattr(model, 'name', model.__class__.__name__)
    with torch.no_grad():
        correct = 0
        total = 0
        for imgs, labels in val_loader:
            imgs = imgs.to(device=device)
            labels = labels.to(device=device)
            if is_conv:
                outputs = model(imgs)
            else:
                outputs = model(imgs.view(imgs.shape[0], -1))
            _, predicted = torch.max(outputs, dim=1)
            total += labels.shape[0]
            correct += int((predicted == labels).sum())
        accuracy = correct / total
        print(f'{model_name} val_Accuracy: {accuracy}')

        correct = 0
        total = 0
        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)
            labels = labels.to(device=device)
            if is_conv:
                outputs = model(imgs)
            else:
                outputs = model(imgs.view(imgs.shape[0], -1))
            _, predicted = torch.max(outputs, dim=1)
            total += labels.shape[0]
            correct += int((predicted == labels).sum())
        accuracy = correct / total
        print(f'{model_name} train_Accuracy: {accuracy}')

結果

モデル名 構造 層数 評価用正解率 訓練用正解率
Model_2L 全結合 2層 0.4691 0.99934
Model_4L 全結合 4層 0.4871 1.0

ロス
image.png

評価用データセットの正解率は5割ぐらいで、少し4層のモデルの方が高く、訓練用はほぼ1とかなり過学習になっていることがわかります。ロスが下がりきるまでに学習を止めれば違う結果になるのでしょうが、モデルの構造を変えて過学習を抑えられないか模索してみます。

次は、バッチ正規化を導入してみます。バッチ正規化とは、層の出力を正規化することで活性化関数を有効に使う方法です。入力画像を正規化したのと同じ目的です。
さっきのモデルにバッチ正規化を入れたモデルで学習してみます。

モデル名 構造 層数 バッチ正規化
Model_2L_BN 全結合のみ 2層 あり
Model_4L_BN 全結合のみ 4層 あり

Model_2L_BN

class Model_2L_BN(nn.Module):
    def __init__(self, learning_rate=1e-2, name="model_2L_BN"):
        super(Model_2L_BN, self).__init__()
        self.name = name
        self.network = nn.Sequential(
            nn.Linear(3072, 512),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Linear(512, 10)
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        self.loss_fn = nn.CrossEntropyLoss()
    def forward(self, x):
        return self.network(x)

Model_4L_BN

class Model_4L_BN(nn.Module):
    def __init__(self, learning_rate=1e-2, name="model_4L_BN"):
        super(Model_4L_BN, self).__init__()
        self.name = name
        self.network = nn.Sequential(
            nn.Linear(3072, 2048),
            nn.BatchNorm1d(2048),
            nn.Tanh(),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.Tanh(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Linear(512, 10)
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        self.loss_fn = nn.CrossEntropyLoss()
    def forward(self, x):
        return self.network(x)

結果

モデル名 構造 層数 バッチ正規化 評価用正解率 訓練用正解率
Model_2L_BN 全結合 2層 あり 0.4214 0.83952
Model_4L_BN 全結合 4層 あり 0.507 0.99922

ロス
image.png

バッチ正規化の導入前と比べて、過学習が収まっているように見えます。2層のモデルは評価用正解率が下がっていますがパラメータ数が少ないとかえって汎化性能が下がるからでしょうか。
次はドロップアウトとL2正則化を導入します。前者は確率で層の出力を0にし、後者はパラメータが大きくなるのを抑えることで過学習を防ぎます。ここからは4層のモデルで実験します。

モデル名 構造 層数 バッチ正規化 正則化
Model_4L_BN_DO_05 全結合 2層 あり ドロップアウトP=0.5
Model_4L_BN_ 全結合 4層 あり L2正則化1e-4

Model_4L_BN_DO_05

class Model_4L_BN_DO_05(nn.Module):
    def __init__(self, learning_rate=1e-2, name="model_4L_BN_DO_05"):
        super(Model_4L_BN_DO_05, self).__init__()
        self.name = name
        self.network = nn.Sequential(
            nn.Linear(3072, 2048),
            nn.BatchNorm1d(2048),
            nn.Tanh(),
            nn.Dropout(0.5),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.Tanh(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        self.loss_fn = nn.CrossEntropyLoss()
    def forward(self, x):
        return self.network(x)

Model_4L_L2_1e_4

class Model_4L_BN_L2_1e_4(nn.Module):
    def __init__(self, learning_rate=1e-2, name="model_4L_BN_L2_1e_4"):
        super(Model_4L_BN_L2_1e_4, self).__init__()
        self.name = name
        self.network = nn.Sequential(
            nn.Linear(3072, 2048),
            nn.BatchNorm1d(2048),
            nn.Tanh(),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.Tanh(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Linear(512, 10)
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate, weight_decay=1e-4)
        self.loss_fn = nn.CrossEntropyLoss()
    def forward(self, x):
        return self.network(x)

結果

モデル名 構造 層数 バッチ正規化 正則化 評価用正解率 訓練用正解率
Model_4L_BN_DO_05 全結合 4層 あり ドロップアウト0.5 0.454 0.6086
Model_4L_BN_L2_1e_4 全結合 4層 あり L2正則化1e-4 0.5012 0.99762

ロス
image.png

ドロップアウトのモデルは過学習がかなり抑えられているが、反面正解率が下がり、L2正則化はほぼ変わらない結果になった。ロスをみるとドロップアウトの方は学習が進んでいないように見える。

最後に畳み込みを導入したモデルで実験する。画像は近傍の画素に高い相関があることと、畳み込みは画像中の物体の位置に対する不変性があることから、画像認識タスクに適した構造です。バッチ正規化の有無で実験してみます。

モデル名 構造 層数 バッチ正規化 正則化
Model_Conv_4L 畳み込み 4層 なし なし
Model_Conv_4L_BN 畳み込み 4層 あり なし

Model_Conv_4L

class Model_Conv_4L(nn.Module):
    def __init__(self, learning_rate=1e-2, name="model_Conv_4L"):
        super(Model_Conv_4L, self).__init__()
        self.name = name
        self.network = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 8, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(8 * 8 * 8, 32),
            nn.Tanh(),
            nn.Linear(32, 10)
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        self.loss_fn = nn.CrossEntropyLoss()
    def forward(self, x):
        return self.network(x)

Model_Conv_4L_BN

class Model_Conv_4L_BN(nn.Module):
    def __init__(self, learning_rate=1e-2, name="model_Conv_4L_BN"):
        super(Model_Conv_4L_BN, self).__init__()
        self.name = name
        self.network = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.Tanh(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.Tanh(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(8 * 8 * 8, 32),
            nn.BatchNorm1d(32),
            nn.Tanh(),
            nn.Linear(32, 10)
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        self.loss_fn = nn.CrossEntropyLoss()
    
    def forward(self, x):
        return self.network(x)

結果

モデル名 構造 層数 バッチ正規化 正則化 評価用正解率 訓練用正解率
Model_Conv_4L 畳み込み 4層 なし なし 0.6612 0.7625
Model_Conv_4L_BN 畳み込み 4層 あり なし 0.6494 0.77922

ロス
image.png
畳み込みを入れたモデルはやはりこれまでの全結合モデルよりかなり精度は高くなった。バッチ正規化を入れないほうが良い結果となったので、もっと層を増やさないと意味がないのかもしれない。

これまでの結果をまとめるとこうなる。

モデル名 構造 層数 バッチ正規化 正則化 評価用正解率 訓練用正解率
Model_2L 全結合 2層 なし なし 0.4691 0.99934
Model_4L 全結合 4層 なし なし 0.4871 1.0
Model_2L_BN 全結合 2層 あり なし 0.4214 0.83952
Model_4L_BN 全結合 4層 あり なし 0.507 0.99922
Model_4L_BN_DO_05 全結合 4層 あり ドロップアウト0.5 0.454 0.6086
Model_4L_BN_L2_1e_4 全結合 4層 あり L2正則化1e-4 0.5012 0.99762
Model_Conv_4L 畳み込み 4層 なし なし 0.6612 0.7625
Model_Conv_4L_BN 畳み込み 4層 あり なし 0.6494 0.77922

4.最後に

この実験を通して、PyTorchへの理解が深まった。今後は物体検出を行う予定。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?