7
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorchで画像分類モデルのトレーニングを行うまでの基本的な流れ

Last updated at Posted at 2025-04-16

業務でAIに触り始めてある程度経ったので、AIを扱う流れをPyTorchでの画像分類モデルのトレーニングを例に簡単にまとめます。

データセットの準備

データ準備

配布されているデータセットではなく独自に用意したデータを扱う場合、最も面倒で大変な工程だと思います。
画像分類の場合基本的に以下の流れになります。

1. 画像収集と選別

集める仕組みが既に出来ている場合はある程度楽かもしれませんが、何もない状態から始める場合、必要な枚数によってはとても辛い工程になると思います。
Dataset Searchなどで配布データセットを探す、業者に頼るなどして避けられるのであれば避けたいところです。
この時点で明らかにまともに撮れていないものなどは除いておきます。

2. 分類対象部分の切り抜き

分類の対象以外が写り込んでいるとモデルの精度が上がり辛いため、対象部分を切り抜きます。
手作業で行うのは大変なので、物体検知モデルなどを探して自動化した方がいいです。モデルを使って行う場合には正確でない場合も多いため、処理後にきちんと対象が写っている部分を切り抜けているかを確認しておきましょう。

3. アノテーション

収集した画像に、そこに写っているものが何かを示すラベルを付けます。自分は使ったことがありませんが、ラベリングを自動化するようなAWS SageMaker Ground Truthなどのサービスもあるようです。
ここでラベルが間違っているとまともなモデルは出来ないので、ラベルが間違っていないかの確認をしっかり行います。

4. データの偏り確認とトレーニング・検証・テスト用に分割

データセットを以下の三つに分けます。

  • トレーニング用:モデルのパラメータ学習に使用
  • 検証用:各エポック毎にトレーニングでは使っていないデータでの精度確認に使用
  • テスト用:学習完了後の精度確認に使用

これはホールドアウト法と呼ばれる方法でモデル評価を行う場合の分け方で、十分に画像があれば8:1:1くらいを目安に分けるのが一般的です。
また、トレーニング用データセットの各クラスのデータ量が明らかに偏っていると、推論結果がデータ数の多いクラスに偏る可能性があるので、少ないクラスのデータを追加する、できない場合には多いクラスからデータを減らすことで、各クラスのデータ数を大体同じ数にします。一般的かは分かりませんが、私はデータ数の差が1~2割程度になるように調整するようにしています。一応ここで調整せずとも、後のDataLoadersのsampler引数を使って数を調整することもできます。

5. 最終確認

ちゃんと対象物が写っているか、切り抜きはできているか、ラベルは正確か、データに偏りはないかなど、できたデータセットの中身の最終確認を行います。

Datasets

PyTorchではデータセットに関するコードを簡潔にできるようDatasetsDataLoadersがあり、Datasetsでデータの読み込みや前処理などを行い、後述するDataLoadersDatasetsからどのようにデータを取り出すかをコントロールします。
Datasets & DataLoaders

自作データセット用のDatasets

PyTorchで用意されているデータセットではなく、自分で用意した(PyTorch外部で配布されているものを含む)ものを扱う場合、ディレクトリ構造を自由に弄れるのであればImageFolderを使うのが一番簡単だと思います。ImageFolderを使う場合、以下のようなディレクトリ構成で画像を保存します。

image_root
├── label1
│    ├── image1
│    ├── image2
│   ...
└── label2
│   ├── image1
│   ├── image2
│   ...
...

この構造にできるのであれば、コードの方はImageFolderクラスのroot引数にデータセットのルートパスを与えるだけでデータセットを読み込めます。

読み込むデータセットの指定方法を変えたい場合などにはDatasetクラスを継承して__len__()__getitem__()をオーバーライドしてカスタムしていきます。
Writing Custom Datasets, DataLoaders and Transforms

transforms(前処理等)

Datasetsの引数としてtorchvisionの各種transformsを渡し、これによって画像のサイズ調整やスケーリング、PyTorchのTensor型への変換などの前処理、必要に合わせてデータ拡張なども行うことができます。複雑なデータ拡張を行いたい場合には、Albumentationsなどのライブラリを使うこともありますが、基本的なものはtorchvisionのtransformsでも行えます。
以下は基本的に行うことになる前処理の簡単な説明です。

スケーリング

学習効率及び精度を向上させるために、画像の各ピクセルが取る値の範囲を正規化や標準化などで調節します。画像分類ではImageNetデータセットを使って事前学習されたモデルが多い印象で、平均:(0.485, 0.456, 0.406)、標準偏差:(0.229, 0.224, 0.225)といった統計値を使った標準化が使われているケースをよく見かけますが、事前学習に使われているデータセットの統計値での標準化を行うのが基本だと思われます。この標準化を行うtransformsは以下のように記述します。

transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
画像サイズ調整

この工程もアスペクト比を維持するか、サイズをどうするかなど考える必要がありますが、モデルの特性や、事前学習で使用されているデータ、実際にどのサイズで使いたいかなど、都合に合わせて調整します。以下はサイズ:(224, 224)でアスペクト比を維持しない単純なリサイズ例です。

transforms.Resize(size=(224, 224))
Tensorへの変換

PyTorchで扱えるようにTensor型に変換します。torchvisionのtransformsで行うには、ToTensorが使えます。これはTensor型に変換すると共に、[0, 1]へのスケーリングも同時に行います。

transforms.ToTensor()

ここまで紹介した3つの処理を適用するImageFolderは以下のように書けます。

from torchvision import transforms
from torchvision.datasets import  ImageFolder

transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                transforms.Resize(size=(224, 224))
            ])

dataset = ImageFolder(root="image_root", transform=transform)

DataLoaders

データセットからどのデータをどのような順番で、何枚ずつ取り出すかなどを指定できます。以下の引数をよく使います。

dataset (Dataset):

先述のように作成したデータセットを与えます。

batch_size (int, optional):

データセットから何枚ずつ取り出すか(トレーニングの1ステップで使うデータ数)を設定します。例えば500枚のデータをトレーニングに使う場合、batch_size64に設定すると、1エポックあたり9ステップ踏むことになり、9回パラメーター更新が行われるといった感じになります(ミニバッチ勾配降下法)。基本的にメモリが許す限り大きい値を指定します。

shuffle (bool, optional):

データをシャッフルし並び順を変更するかを設定します。これをFalseにすると、毎エポック各バッチに含まれるデータが同じで、バッチの順番も同じデータで学習を行うことになるので、特に理由がなければTrueにしておきます。

sampler (Sampler or Iterable, optional):

同じデータセットをトレーニング用と検証用で使いたい場合や、各クラスに偏りがある場合など、使うデータ数やデータ順をコントロールしたい時などに設定します。shuffleとは併用できません。
Sampler

num_workers (int, optional):

データを並列で読み込む場合にいくつのプロセスを使うかを指定します。

multiprocessing_context (str or multiprocessing.context.BaseContext, optional):

データロード時に使用するマルチプロセッシングのコンテキストを指定するために使用されるそうです。他のOSではよくわかりませんが、macOSで扱う場合にはforkforkserverを指定することが推奨されているようです。Data loader multiprocessing slow on macOS

これらのパラメータを使用し、例えば以下のようにデータローダーを記述できます。

train_loader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    multiprocessing_context="fork"
)

モデル選定

モデルはタスクの種類、サーバー上かモバイル上かなど動作させる端末のスペック、推論速度などを考えて決定します。探し方はいろいろありますが、自分は画像分類 モバイル SOTAのように検索することが多いです。
あとはこちらのPapers With Codeというサイトでは、各タスク毎にSOTAモデルがまとまっているのに加え、大抵そのモデルを実装してあるリポジトリも載っていて便利です。

トレーニング

簡単なトレーニングを行うだけであれば、どのモデルでも以下に書くような基本部分は同じコードになる、と思っています。

# 損失関数と最適化手法の定義
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# エポック数の定義
epochs = 100

for epoch in range(epochs):
    # トレーニング
    train_loss = 0.0
    model.train()
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        # 勾配の初期化
        optimizer.zero_grad()
        # 順伝播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # 逆伝播
        loss.backward()
        # パラメータ更新
        optimizer.step()
        # ロスの累積
        train_loss += loss.item()

    # エポックごとのトレーニングの平均ロスの表示
    print(f'Epoch {epoch+1}, Loss: {train_loss / (i+1)}')
    
    # 検証
    val_loss = 0.0
    model.eval()
    with torch.no_grad(): # 勾配計算無効化
        for i, data in enumerate(val_loader, 0):
            inputs, labels = data
            output = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    # エポックごとの検証の平均ロスの表示
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss / (i+1)}')

損失関数や最適化関数にもいろいろな種類がありますが、画像分類であればCrossEntropyLossAdamの組み合わせでいいと思います。興味がある方はいろいろと試してみて下さい。
また、上記のコードでは単純な損失関数の数値を出力していますが、正解率や適合率、再現率などの評価指標を使って性能を評価するのが一般的です。どのような指標があるか調べて、用途に合ったものを使用してください。

おまけ

最後に必須ではありませんが、性能を高めたりするために行えることをいくつか挙げておきます。

ハイパーパラメーターチューニング

学習率やバッチサイズ、エポック数などのハイパーパラメーターを変えていろいろなバリエーションでトレーニングを行い、より精度の高いモデルを作ることを目的に行います。場合によっては最適化関数や損失関数なども入れ替えて性能が良くなるものを探っていきます。手動、グリッドサーチ、ランダムサーチなど、やり方もいろいろあるので、興味のある方は調べてみて下さい。Optunaのようなハイパーパラメーターチューニング用のフレームワークも存在します。

早期停止(アーリーストッピング)

過学習を防ぐため、検証データセットの目的関数の値が上昇する直前のエポック数を採用します。検証にかかる時間が短いようであれば毎エポック、時間がかかる場合には指定したエポック毎に検証データセットの目的関数の値を確認して、上昇しているようであればトレーニングを止めることで、汎化性能を維持したモデルを作ることができます。

モデル圧縮

トレーニング後にモデルのサイズを小さくするために行います。代表的なものとして、量子化や枝刈り、蒸留などといったものがあります。

7
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?