5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

100行でわかるpytorch

Last updated at Posted at 2024-11-06

AIって難しそうで手が出ない、あるいはやろうとしたけどpytorchチュートリアルでよくわからなくて死んだ方へ。

pytorchの最低限が100行でわかる記事です。
今回はアニメ顔とリアル顔を識別するモデルをさくっと書いていきます。
理論とかはすっ飛ばして実装にだけ焦点を当てて解説していきます。

注意

pytorchから直接取ってきている関数等の細かい仕様は公式ページを参照しましょう。クラスや関数の説明はわかりやすい(はず)。

環境

  • python : 3.10.14
  • pytorch : 2.4.1+cu121
  • opencv : 4.10.0
  • numpy : 2.1.1

目次

  1. 全体のコード
  2. モデル
  3. データセット
  4. 学習
  5. テスト
  6. 結果
  7. まとめ

全体のコード

all.py
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optimizers
import glob
import numpy as np
import cv2

class Model(nn.Module):
    def __init__(self, device="cuda:0"):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.linear1 = nn.Linear(in_features=148*148*3, out_features=50)
        self.linear2 = nn.Linear(in_features=50, out_features=1)
        self.act = nn.Sigmoid()
        self.to(device)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear1(x)
        x = nn.functional.normalize(x)
        x = self.linear2(x)
        return self.act(x)

class AnimeRealDataset(Dataset):
    def __init__(self):
        self.data = []
        path = "./anime_real_classification/"
        anime_files = glob.glob(path + 'anime_face_300/*')
        for file in anime_files:
            self.data.append((cv2.imread(file), np.array([0])))
        real_files = glob.glob(path + 'human_face_300/*')
        for file in real_files:
            self.data.append((cv2.imread(file), np.array([1])))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx][0], self.data[idx][1]

def train(model, loader, optimizer, criterion, device="cuda:0"):
    model.train()
    num_epochs = 10
    for epoch in range(num_epochs):
        for input_data, label in loader:
            input_data = input_data.transpose(1, 3).to(device).to(torch.float32)
            label = label.to(torch.float32).to(device)
            optimizer.zero_grad()
            pred = model(input_data)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
    
def test(model, loader, device="cuda:0"):
    model.eval()
    sum = 0
    correct = 0
    for input_data, label in loader:
        input_data = input_data.transpose(1, 3).to(torch.float32).to(device)
        label = label.to(torch.float32).to(device)
        with torch.no_grad():
            pred = model(input_data)
            print(f"label : {label}, pred : {pred}")
            pred_label = 1 if pred.item() >= 0.5 else 0
            label = int(label)
            if label == pred_label:
                correct += 1
            sum += 1
    print(f"correct / sum : {correct} / {sum}")

def main():
    dataset = AnimeRealDataset()
    trainloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)
    testloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
    model = Model()
    criterion = nn.BCELoss()
    optim = optimizers.Adam(model.parameters(), lr=0.01)
    train(model, trainloader, optim, criterion)
    test(model, testloader)

if __name__ == "__main__":
    main()

実は87行です。

モデル

とにかく簡単に書きたいので、畳み込み1層、プーリング1層、線形層2層の小さなモデルです。

model.py
class Model(nn.Module):
    #input = (3, 300, 300)
    def __init__(self, device="cuda:0"):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.linear1 = nn.Linear(in_features=148*148*3, out_features=50)
        self.linear2 = nn.Linear(in_features=50, out_features=1)
        self.act = nn.Sigmoid()
        self.to(device)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear1(x)
        x = nn.functional.normalize(x)
        x = self.linear2(x)
        return self.act(x)

pytorchでモデルを書くときは、

  1. nn.Moduleを親クラスにして
  2. コンストラクタ内で親クラスのコンストラクタを呼んで(super().__init__()のところ)
  3. forward関数を作る

という3つのステップを踏めばできます。

forward関数はモデルにデータを入力するときに呼ばれる関数です。
今回の入力の次元は(バッチ,色,縦,横)の4次元です。
今回に限らずモデルに入ってくるデータの1次元目はバッチを表しているので、flattenやreshapeなどでデータの形をいじるときは基本それより後ろの次元で行います。

データセット

2つのディレクトリに入っている画像を、pytorchで使える形に直す部分です。

dataset.py
class AnimeRealDataset(Dataset):
    def __init__(self):
        self.data = []
        path = "./anime_real_classification/"
        anime_files = glob.glob(path + 'anime_face_300/*')
        for file in anime_files:
            self.data.append((cv2.imread(file), np.array([0])))
        real_files = glob.glob(path + 'human_face_300/*')
        for file in real_files:
            self.data.append((cv2.imread(file), np.array([1])))

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx][0], self.data[idx][1]

データセットを作るには

  1. Datasetを親クラスにして
  2. __len__関数と__getitem__関数を作る

という2つのステップを踏めばできます。

__len__関数はデータの合計数を返す関数で、__getitem__関数は整数を1つ受け取って、それに対応したデータを返す関数です。

__getitem__関数の返り値は後の訓練ループ中でデータを取り出すときに使うものなので、自分が使いやすい形にしてOKです。今回は画像データとラベルにしてあります。

学習

main関数で下準備、train関数で学習を行います。

main.py
def main():
    dataset = AnimeRealDataset()
    trainloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)
    testloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
    model = Model()
    criterion = nn.BCELoss()
    optim = optimizers.Adam(model.parameters(), lr=0.01)
    train(model, trainloader, optim, criterion)
    test(model, testloader)
train.py
def train(model, loader, optimizer, criterion, device="cuda:0"):
    model.train()
    num_epochs = 10
    for epoch in range(num_epochs):
        for input_data, label in loader:
            input_data = input_data.transpose(1, 3).to(device).to(torch.float32)
            label = label.to(torch.float32).to(device)
            optimizer.zero_grad()
            pred = model(input_data)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

main関数

main関数で準備するものは

  1. データローダー
  2. モデル
  3. ロス関数
  4. オプティマイザ

の4つです。

データローダーはさっき作ったデータセットを学習で使いやすい形にしてくれるやつです。バッチサイズごとにデータをまとめてくれます。
テストするときもデータローダーを使うので、予め用意しておきます。基本的には訓練データとテストデータは分けますが、今回は適当なので同じものです。

モデルはさっき作ったやつをそのまま持ってきます。

ロス関数はBCEというものを採用しています。使う値が0~1のときだけ使える関数です。

オプティマイザはAdamです。理由はなんとなく。
Adamに限らずオプティマイザを設定するときは、そのオプティマイザでどのパラメータを動かしますか、というのを入力する必要があります。「このモデルのパラメータ全部やって」というときは「モデルが入った変数の名前.parameters()」をオプティマイザに渡してあげればOKです。

train関数

学習をする関数です。ここでやることは

  1. model.train()をして
  2. エポックの数だけ学習ループを回す

model.trainの呼び出しは今回使ってないような、訓練とテストで動きが変わる層の動きを指定できるのでやっておいた方がいいです。お約束的な感覚かも。

2.の学習ループの中身としては

  1. データローダーからバッチごとのデータを取り出す
  2. データ、ラベルをモデルが扱えるように調整
  3. 勾配を初期化して
  4. モデルにぶち込む
  5. 出てきた結果と答えをロス関数で照らし合わせて
  6. 勾配を計算
  7. パラメータに適用

となります。長いですね。

データローダーからデータを取り出すのはfor文でできます。
for文で作る変数(今回だとinput_dataとlabel)が__getitem__の出力と対応しています。

さっき作ったデータセットのままだと、データの型がintでモデルが食べてくれないのでfloatに変換します。
今回に限らず、入力とラベルは基本的にfloat型です。
また、opencvで画像を読み込むと色が最後の次元に来てしまい、(300, 300, 3)の形でモデルが想定しているものではないので、transpose関数で入れ替えてあげてます。

勾配の初期化は、バッチごとに行わないと過去の学習の影響を受けてしまうのでやっておきましょう。

ここまでで下準備ができたので、モデルにデータを食べてもらいます。
「モデルが入った変数名(input_data)」という風にすると自動的にforward関数にinput_dataを入れてくれます。

モデルの出力とラベルをロス関数に入力すると、その2つを比較してくれます。
backward関数を呼び出すことで、そのロスを計算するときに使ったパラメータの勾配を計算してくれます。

最後に、オプティマイザのstep関数を呼び出してパラメータを更新してもらいます。

テスト

どれだけ学習できたかを確認するフェーズです。

test.py
def test(model, loader, device="cuda:0"):
    model.eval()
    sum = 0
    correct = 0
    for input_data, label in loader:
        input_data = input_data.transpose(1, 3).to(torch.float32).to(device)
        label = label.to(torch.float32).to(device)
        with torch.no_grad():
            pred = model(input_data)
            print(f"label : {label}, pred : {pred}")
            pred_label = 1 if pred.item() >= 0.5 else 0
            label = int(label)
            if label == pred_label:
                correct += 1
            sum += 1
    print(f"correct / sum : {correct} / {sum}")

テスト内でやることは

  1. model.eval()をして
  2. データローダーからデータを取り出して
  3. 入力データとラベルをいい感じに変換して
  4. モデルに入力データを食べさせて
  5. 出力がラベルと合っているかを確認する

という流れになります。

モデルを評価モードにするのは、訓練のときと同じような理由です。お約束の認識でOK。

データローダーからデータを取り出したり、データとラベルを使いやすい形にするのも訓練のときと同じ。

モデルの出力を自分で評価するのが、train関数との大きな違いです。
今回は2値分類なので、ラベルは0か1になっています。
しかし、モデルの出力は最後にシグモイド関数に通したことで0~1の間の数になっています。
そこで今回は、0.5未満だったら0、0.5以上だったら1って言いたかったんだろうなぁと決めてしまいます。

後はこの評価基準に従って、どのくらい正答できたかを見てあげればおしまいです。

結果

テストの出力がドバっと出るようにしてしまったので、訓練時の出力とテストの一部だけ。

output.
Epoch [1/10], Loss: 0.7755
Epoch [2/10], Loss: 0.4105
Epoch [3/10], Loss: 0.3286
Epoch [4/10], Loss: 0.0246
Epoch [5/10], Loss: 0.0176
Epoch [6/10], Loss: 0.0171
Epoch [7/10], Loss: 0.0927
Epoch [8/10], Loss: 0.0356
Epoch [9/10], Loss: 0.0064
Epoch [10/10], Loss: 0.0045


.
.
.
label : tensor([[1.]], device='cuda:0'), pred : tensor([[0.9935]], device='cuda:0')
label : tensor([[0.]], device='cuda:0'), pred : tensor([[0.0632]], device='cuda:0')
label : tensor([[0.]], device='cuda:0'), pred : tensor([[0.0913]], device='cuda:0')
label : tensor([[1.]], device='cuda:0'), pred : tensor([[0.9867]], device='cuda:0')
label : tensor([[0.]], device='cuda:0'), pred : tensor([[0.0055]], device='cuda:0')
correct / sum : 588 / 593

学習が進むごとにロスが下がって、テストではちゃんと予測値(pred)がラベルに近づいていますね。

まとめ

今回は100行でpytorchの基礎を書いてみましたが、どうだったでしょうか。意外と100行以内でも基礎を網羅できるので、pytorch初めて触るよって人には割と良い教材なんじゃないかと自己満足しています。大きいモデルを扱う人も、根本の訓練ループやモデルの書き方は変わらないと思うので、上のコードを自己流にアレンジしてもらえれば嬉しいです。

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?