1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

このチュートリアルでは、vikit-learnを使用して画像分類器を訓練する方法を学びます。この実践的な例では、猫と犬の画像が含まれているOxfordIIITPetデータセットを使用します。

vikit-learnツールのインストール

pipを使用して、GitHubから直接vikit-learnをダウンロードしてインストールできます:

pip install git+https://github.com/bxt-kk/vikit-learn.git

トレーニングスクリプトの作成

モデルを訓練するためのスクリプトコードを書く必要があります。

1. vikit-learnpytorchから必要なパッケージをインポート

import torch
from torch.utils.data import DataLoader

from vklearn.trainer.trainer import Trainer
from vklearn.trainer.tasks import Classification as Task
from vklearn.models.trimnetclf import TrimNetClf as Model
from vklearn.datasets.oxford_iiit_pet import OxfordIIITPet
  • Trainer: トレーニングパラメータを設定し、トレーニングプロセスを実行するための一般的なトレーニングツールです。
  • Classification: 分類タスクに関連するトレーニングパラメータを指定します。
  • TrimNetClf: vikit-learnに組み込まれた分類器モデルです。
  • OxfordIIITPet: vikit-learnに組み込まれたデータセットツールです。

2. トレーニングデータの準備

dataset_root = '/kaggle/working/OxfordIIITPet'
dataset_type = 'binary-category'

train_transforms, test_transforms = Model.get_transforms()

train_data = OxfordIIITPet(
    dataset_root,
    split='trainval',
    target_types=dataset_type,
	download=False,
    transforms=train_transforms)
test_data = OxfordIIITPet(
    dataset_root,
    split='test',
    target_types=dataset_type,
    transforms=test_transforms)

まず、dataset_rootでデータの場所を指定する必要があります。その後、dataset_type = 'binary-category'でデータのタイプを指定し、猫と犬の二値分類データを意味します。また、データをトレーニングセット(split='trainval')とテストセット(split='test')に分割します。

注意: ローカルディレクトリにデータがない場合は、インターネットからデータをダウンロードするためにdownloadTrueに設定する必要があります。

batch_size = 128

train_loader = DataLoader(
    train_data, batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4)
test_loader = DataLoader(
    test_data, batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=4)

print(len(train_loader))

pytorchが提供するデータ読み込みツールDataLoaderを使用してデータを読み込みます。ここでは、batch_size = 128に設定しています。

3. モデルとトレーニングタスクの作成

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model  = Model(categories=train_data.bin_classes)
task   = Task(model, device)

TrimNetClfクラスを使用してモデルを作成します。ここで、モデルには分類カテゴリの数とその名前を指定する必要があります。そのために、train_data.bin_classesの値をモデルのcategoriesパラメータとして使用します。次に、モデルオブジェクトmodelとデバイスオブジェクトdeviceを使用してトレーニングタスクオブジェクトを作成します:task = Task(model, device)

4. トレーナーの初期化

trainer = Trainer(
    task,
    output='/kaggle/working/catdog-clf',
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=20,
    lr=1e-3,
    lrf=0.2,
    show_step=50,
    save_epoch=5)

trainer.initialize()

トレーナーのパラメータを設定することで、モデル訓練のためのトレーナーを作成できます。トレーナーオブジェクトを作成した後、trainer.initialize()メソッドを使用して初期化する必要があります。

トレーナーには以下のパラメータを設定します:

  • task: トレーニングタスクを指定します
  • output: チェックポイントとログを保存するためのトレーニングデータの出力パスを設定します
  • train_loader: トレーニングセットのローダーを指定します
  • test_loader: テストセットのローダーを指定します
  • epochs: トレーニングエポックの総数を設定します
  • lr: 学習率を設定します
  • lrf: 学習率の減衰係数を設定します
  • show_step: トレーニング状況を表示する頻度を設定します
  • save_epoch: チェックポイントを保存する頻度を設定します

5. トレーニングタスクの実行

最後に、以下のコードを使用してモデルのトレーニングを開始します:

trainer.fit()

モデルのトレーニングが完了すると、トレーナーに指定した出力パスのlogsサブディレクトリでトレーニングログを見ることができます。





ログに加えて、以下のチェックポイントファイルも表示されます:

- catdog-clf-4.pt
- catdog-clf-9.pt
- catdog-clf-14.pt
- catdog-clf-19.pt
- catdog-clf-best.pt

通常、テストセットの評価指標で最高スコアを得たチェックポイントであるbest.ptで終わるものを使用します。


画像分類器の使用

画像分類器のトレーニングが完了したら、トレーニングされた分類器を使用して画像を自動的に分類できます。

1. まず、必要なパッケージをインポートします

import matplotlib.pyplot as plt
from PIL import Image

from vklearn.models.trimnetclf import TrimNetClf as Model
from vklearn.pipelines.classifier import Classifier as Pipeline

from vklearn.pipelines.classifier import Classifierは、モデルの呼び出しを大幅に簡素化するパイプラインツールClassifierをインポートします。

2. モデルクラスとチェックポイントファイルを指定して分類器を作成

pipeline = Pipeline.load_from_state(
    Model, '???/catdog-clf-best.pt')

注意: '???/catdog-clf-best.pt'をコンピュータ上のチェックポイントファイルへの実際のパスに置き換えることを忘れないでください。

3. モデルを開いて分類予測を行い、結果を視覚化

一連の準備が完了したら、次のコードを使用して分類を行うことができます:

img = Image.open('??your image path??')
result = pipeline(img)
fig = plt.figure()
pipeline.plot_result(img, result, fig)
plt.show()

上記のコードを使用して画像を開き(img = Image.open('??your image path??'))、分類予測(result = pipeline(img))を行い、予測結果を視覚化します:

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?