2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

画像認識 AI のファインチューニングしてみた

Last updated at Posted at 2024-09-22

はじめに

以前に、画像を認識する AI プログラムを試してみました。

画像認識 AI モデルを試してみた #LLM - Qiita

また、文章生成 AI のファインチューニングしてみました。

文章生成 AI のファインチューニングしてみた #LLM - Qiita

画像認識 AI モデルもファインチューニングできるはずですね。
以下の記事を参考にして試してみました。

HuggingFace ブログ : 画像分類用 ViT の微調整 – ClassCat® Blog

画像分類 AI モデルを使ってみる

画像分類 AI モデルを使う

以前に画像認識 AI プログラムを試したときは、LLaVaPaliGemma モデルを、Transformers ライブラリの LlavaForConditionalGenerationPaliGemmaForConditionalGeneration クラスを使って利用しました。

参考:生成AI解説: 条件付き生成(Conditional Generation) – PROMPT.JP

今回は、vit-base-patch16-224-in21k モデルを、Transformers ライブラリの ViTForImageClassification クラスを使って利用します。これは、画像データを渡すとカテゴリのラベルを返すものです。

参考:🔰ViTで実装する画像分類入門 - つくもちブログ 〜Python&AIまとめ〜

実行環境を用意する

AI プログラムの実行環境は、高速な計算するために大きなメモリや GPU を使います。そのため高額なマシンが必要になります。
高機能なマシンを時間利用できるクラウドサービスが用意されています。
これまで Google Colab を使ってきました。便利なサービスですが不便なところもあります。
そこで、GPU 搭載+Python+VS Code の開発環境を、Google Cloud の Compute Engine サービスで用意してみました。

GPU 搭載+Python+VS Code の開発環境を作ってみた #VSCode - Qiita

必要なライブラリをインストールする

Transformers ライブラリを使うのでインストールします。

$ pip install transformers accelerate 

チューニング前のモデルで生成してみる

まず、チューニングする前のモデルを使って画像分類してみます。

import transformers
import torch

# モデルとプロセッサの準備
model = transformers.ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    device_map="auto",
    torch_dtype=torch.float16
)
processor = transformers.ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224-in21k",
)

import requests
import PIL

# 画像を参照
url = "https://torch.classcat.com/wp-content/uploads/2022/12/hf-blog-fine-tune-vit_sample400.jpg"
image = PIL.Image.open(requests.get(url, stream=True).raw)

# 推論を実行
inputs = processor(
    images=[image],
    return_tensors="pt"
).to(model.device)

outputs = model(
    **inputs
)

# 結果を出力
idx = outputs.logits.argmax(-1).item()
label = model.config.id2label[idx]

print(label)

こんな結果が表示されます。↓

LABEL_0

ファインチューニングしてみる

ファインチューニングの基本的な流れは以下の通り

①学習データを準備する
②ベースになるモデルを準備する。併せてトークナイザを準備する
③学習を実行する
④学習して作成されたモデルを保存する

必要なライブラリをインストールする

datasetsevaluate ライブラリを使います。

$ pip install datasets
$ pip install evaluate

学習データを用意する

学習データを用意します。Datasets ライブラリを使って Hugging Face のリポジトリから入手することにします。

import datasets

# データセットを用意
datadic = datasets.load_dataset("AI-Lab-Makerere/beans")

# ラベルを用意
labels = datadic['train'].features['labels'].names

読込したデータセットを確認してみます。↓

# データセットを確認
print(datadic)
DatasetDict({
    train: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 1034
    })
    validation: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 133
    })
    test: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 128
    })
})

imagelabels 項目を持っていることが分かります。↑

# ラベルを確認
print(labels)
['angular_leaf_spot', 'bean_rust', 'healthy']

labels 項目は、angular_leaf_spot (角葉スポット)bean_rust (豆さび病)healthy (健康) の選択肢が用意されていることが分かります。↑

hf-blog-fine-tune-vit_leaf-grid.jpg.jpg

豆の葉の画像と状態がセットになっています。↑

モデルを準備する

画像認識するときと同様にモデルを準備します。

# モデルを準備
model = transformers.ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels),
    id2label={
        str(i): c for i, c in enumerate(labels)
    },
    label2id={
        c: str(i) for i, c in enumerate(labels)
    }
)

ラベルを表示するために id2labellabel2id を設定しています。↑

特徴抽出器を用意する

文章生成 AI モデルはトークナイザを用意しましたが、ここでは特徴抽出器(extractor)を用意します。

# 特徴抽出器を用意
extractor = transformers.ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)

チューニングを実行する

データセットを加工する

チューニングに使いたい列の内容をトークン化するなどデータセットを加工します。

# データセットを加工
prepared_dd = datadic.with_transform(
    lambda batch: {
        # トークン化
        **extractor(
            [x for x in batch['image']],
            return_tensors='pt'
        ),
        # ラベル
        'labels': batch['labels']
    }
)

コレータと評価メトリックを用意する

Trainer クラスが要求するコレータ(照合器)と評価メトリックを用意します。

# コレータを定義
def collate(batch):
    return {
        'pixel_values': torch.stack([
            x['pixel_values'] for x in batch
        ]),
        'labels': torch.tensor([
            x['labels'] for x in batch
        ])
    }

import evaluate
import numpy

# 評価メトリックを定義
metric = evaluate.load("accuracy")

def compute_metrics(p):
    return metric.compute(
        predictions=numpy.argmax(p.predictions, axis=1),
        references=p.label_ids
    )

トレイナを準備する

Trainer クラスを使ってチューニングします。

# トレイナの準備
trainer = transformers.Trainer(
    model=model,
    tokenizer=extractor,
    data_collator=collate,
    compute_metrics=compute_metrics,
    args=transformers.TrainingArguments(
        output_dir="./output",
        num_train_epochs=4,
        per_device_train_batch_size=16,
        remove_unused_columns=False,
    ),
    train_dataset=prepared_dd['train'],
    eval_dataset=prepared_dd['validation'],
)

作成されたモデルを保存する

チューニングして作成されたモデルを保存します。

# トレーニングする
trainer.train()

# 保存する
trainer.save_model("./trained_model")

チューニングしたモデルで生成してみる

チューニングしたモデルを使って画像認識してみます。

# モデルとプロセッサの準備
model = transformers.ViTForImageClassification.from_pretrained(
    "./trained_model",
    device_map="auto",
    torch_dtype=torch.float16
)
processor = transformers.ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224-in21k",
)

(以下略)

hf-blog-fine-tune-vit_sample400.jpg

実行結果↓

bean_rust

判定されていることが分かります。↑

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?