10
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorchを用いたCNNによる画像認識入門

Last updated at Posted at 2022-12-10

はじめに

こんにちは! myjlab advent calendar 2022 11日目の記事です。

昨日は@yuzu_mさんのOpenCVで画像処理(平滑化)という記事でした。是非ご覧ください!

11日目の今日は、『CNN(畳み込みニューラルネットワーク)』についてです。実際にPyTorchを用いて、PyTorchチュートリアルの実装も行なってみました。よろしくお願いします。

余談

ちなみにこの記事のタイトルは、今話題のOpenAIが発表したChatGPTに考えてもらいました。
(ChatGPTについては、@yoshiki495 君が6日目の記事に書いていたのでぜひ)
image1.png
また、『バズりそうな』を付け加えてみました。採用はしませんでしたが、確かにバズりそうなタイトルで技術の進化を感じました。image2.png
それでは本題に入ります!

CNN(Convolutional Neural Network)とは?

画像処理分野において多く用いられるニューラルネットワークで、日本語では『畳み込みニューラルネットワーク』とも呼ばれるものです。数値情報として3次元の情報をもつ画像データに適したモデルで、畳み込み(フィルタを用いて画像から特徴を抽出する操作)、プーリング(画像サイズの圧縮)の2つに特徴があります。

より詳しく、正確に知りたい方は、Convolutional Neural Networkとは何なのかという記事がわかり易かったので是非。

作るもの

PyTorchで、CNN(Convolutional Neural Network)を用いて、CIFAR-10の画像を分類する

CIFAR-10とは?

ラベル「0」: airplane(飛行機)
ラベル「1」: automobile(自動車)
ラベル「2」: bird(鳥)
ラベル「3」: cat(猫)
ラベル「4」: deer(鹿)
ラベル「5」: dog(犬)
ラベル「6」: frog(カエル)
ラベル「7」: horse(馬)
ラベル「8」: ship(船)
ラベル「9」: truck(トラック)
という10種類の「物体カラー写真」(乗り物や動物など)の画像データセットである。
CIFAR-10データセット全体は、5万枚の訓練データ用(画像とラベル)1万枚のテストデータ用(画像とラベル)合計6万枚で構成される
参照元:CIFAR-10:物体カラー写真(乗り物や動物など)の画像データセット

環境・使ったもの

  • Google Colabolatory
  • Python (PyTorch,matplotlib,numpy)
  • CIFAR-10

実装

必要なもののインポート

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

データの下処理

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

# 訓練用データ(50000件の訓練用データ)を保存
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# テスト用データ(10000件のテスト用データ)を保存
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# 分類の定義
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

今回使うデータは既に、訓練用データと、テスト用データが分かれています(train=False train=Trueで使い分け可能)

batch_size(バッチサイズ)とは、1回に何個のdataを使用するかを表しています。

CIFAR-10の画像と、ラベルを見てみる

# 画像の表示
def imshow(img):
    img = img / 2 + 0.5   
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 学習画像をランダムに取得
dataiter = iter(trainloader)
images, labels = next(dataiter)
# 画像の表示
imshow(torchvision.utils.make_grid(images))
# バッチサイズの表示
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

このように画像とラベルが表示されます。32×32 pixelなので結構ぼやけてます。
image3.png

モデルの定義

class Net(nn.Module):
    # 利用するレイヤーや初期設定したい内容の記述
    def __init__(self):
        super().__init__()
        # 畳み込み層(入力チャネル数,出力チャネル数,カーネルサイズ)
        self.conv1 = nn.Conv2d(3, 6, 5) 

        # Maxプーリング層(カーネルサイズ,カーネルサイズ)
        self.pool = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(6, 16, 5)
        # 線形層(入力チャネル数,出力チャネル数)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # 畳み込み層は三次元のため、バラバラに分解し、一元化し線形層に入力
        x = torch.flatten(x, 1) 
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()

実際の学習部分

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

torch.optimに定義されている最適化アルゴリズムの中で、今回は、optim.SGDを使用しています。他には、Adamなどがあります。

for epoch in range(2):  # データセットを複数回ループ
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # inputを取得する。dataは[inputs, labels]のリスト
        inputs, labels = data
        # モデルに保存されているパラメータ勾配をゼロにする
        optimizer.zero_grad()
        # 予測結果を出力
        outputs = net(inputs)
        # 誤差関数から、誤差を計測
        loss = criterion(outputs, labels)
        # 誤差をフィードバック
        loss.backward()
        optimizer.step()
        # 統計を出力
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

検証

全体の精度

# テスト画像による精度の検証
correct = 0
total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
# 10000枚のテスト画像に対するCNNの精度
print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

このような結果になりました。

クラスごとの精度の出力

# 各クラスの予測数のカウント
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)

        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    # クラスごとの精度
    print(f'Accuracy for class:{classname:5s} is {accuracy:.1f} %')

こちらはこのような結果になりました。

追加 (エポック数を変えてみた)

for epoch in range(2):を、for epoch in range(10):に変更したところ全体のAccuracyは

クラスごとのAcccracyは

となり、精度が上がりました。他にも最適化アルゴリズムをAdamにしたり、モデルそのものを変えたりと色々試してみるのが良さそうです。

終わりに

今日までのmyjlab advent calendar 2022 の同期の記事が、なぜか機械学習に関わるものばかりだったので私も機械学習系の記事を書くことにしました。本題のCNNですら記事を書くまで聞いたことがある程度しか知らなかったので、初心者の記事になってしまいましたが、面白かったのでより深く学んでみようと思います。

何か訂正等あれば、教えていただけると嬉しいです。

参考・出典

10
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?