22
25

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

google colabでVision TransformerのFine Tuningをやってみた

Last updated at Posted at 2021-01-13

はじめに

Deep Learningの画像認識の分野でVision Transformer(ViT)という、今注目を浴びているモデルがあります。

今回google colabでgoogle-researchによるVision Transformerの実装のfine tuningを行ってみたので、その内容を備忘録を兼ねてまとめてみたのが本記事になります。

Vision Transformerとは

29 Sep 2020に発表された「An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale」という論文にて提唱された画像認識のモデルです。

このモデルの最大特徴としては従来の画像認識で利用されてきた畳み込みネットワークを用いず、その代わりに自然言語処理の分野で目覚ましい成果を上げているAttentionと呼ばれるネットワークを用いた画像分類のモデルです。

また、このモデルは計算コストが低いうえに精度がよい(SoTAを上回る性能を約1/15の計算コストで得られる)と言われています。

詳細なモデルの解説に関しては「画像認識の大革命。AI界で話題爆発中の「Vision Transformer」を解説!」という記事が分かりやすかったです。

google researchの実装

このViTですが、googleによって実装が公開されています(https://github.com/google-research/vision_transformer)。

そして、READMEによると以下のようなことが述べられていました。

また、google colabのチュートリアルではCifer10を用いたfine-tuneやモデルの推論処理の実行方法などが説明されていました。

 さて、この実装を用いたfine-tuneですが、READMEやgoogle colabのチュートリアルではTensorFlow Datasetsにもともと入っているデータセット(Cifer10など)に対してのfine-tuneの方法は詳しく書いてあるのですが自前で用意したデータセットに対してのfine-tuneの方法はわかりやすい実現方法の記載がありませんでした。そこで、私が行った自前で用意したデータセットに対してのfine-tuneを行う方法についてこれ以降解説していきます。

利用したデータセット(Signateの[【練習問題】画像ラベリング(20種類)])

今回利用したデータセットは、Signateの「【練習問題】画像ラベリング(20種類)」というコンペ(練習問題)のデータセットです。

今回利用するデータは以下の3つです。

  • train.zip
    学習用データ。中身はpng画像(例:train_00000.png)

  • train_master.tsv
    学習用画像データとラベルIDの対応表。ラベルは0~19

    file_name	label_id
    train_00000.png	11
    train_00001.png	15
    ・・・省略・・・
    
  • test.zip
    評価用画像データ。中身はpng画像(例:test_00000.png)

コード実施前の事前準備

これから解説するコードの実行前に以下の準備を事前に行ってください。
[google colabで行う準備]

  • train_master.tsvのアップロード
  • ランタイムのタイプを「GPU」に設定する。

[google driveで行う準備]

  • train.zipとtest.zipをgoogle
    driveに格納する。

コードの解説

ようやく、コードの解説に入ります。また、本コードはgoogle colabのチュートリアルを参考に自前で用意したデータセットに対してのfine-tuneできるよう修正をしたものになります。

リポジトリのクローンと事前学習モデルのダウンロード。

以下をコードでリポジトリのクローンと事前学習モデルのダウンロードを行います。

(注:このコードをcolab上でなくローカル環境で実施しようとした場合gsutilが使えない可能性があります。その場合はstorage bucketからダウンロードするとよいと思います)

# リポジトリをクローンして、最新の変更をプルします。
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
!cd vision_transformer && git pull

!pip install -qr vision_transformer/vit_jax/requirements.txt


# 事前に訓練されたモデルをダウンロードします。
model = 'ViT-B_16'
![ -e "$model".npz ] || gsutil cp gs://vit_models/imagenet21k/"$model".npz .

必要なモジュールのインポート

インポートするモジュールを見ると、今回の実装ではjaxというDeep Learningのフレームワークが使われていることがわかります。

また、あとで見ていきますがtensorflowはデータの前処理に用いていきます。

import flax
import jax
from matplotlib import pyplot as plt
import numpy as np
import tqdm
import sys
if './vision_transformer' not in sys.path:
  sys.path.append('./vision_transformer')
from vit_jax import checkpoint
from vit_jax import hyper
from vit_jax import input_pipeline
from vit_jax import logging
from vit_jax import models
from vit_jax import momentum_clip
from vit_jax import train
import pathlib
import tensorflow as tf
import csv
import pandas as pd

# ログの設定
logger = logging.setup_logger('./logs')

google driveのマウントとデータのunzip

今回zipファイルのサイズが大きく、zipファイルをcolabに直接アップロードしようとするとアップロード中にランタイムの切断が起こるなどしてうまくアップロードできないことがあるのでgoogle driveにzipをアップロードしてそれをマウントします。また、以下の操作でzipファイルはカレントディレクトリの直下に展開されます。

from google.colab import drive
drive.mount('/content/drive')

! unzip <train.zipのパス>
! unzip <test.zipのパス>

(注:このコードをcolab上でなくローカル環境で実施する場合zipファイルを実行するpythonファイルと同じディレクトリに解凍することで、以降の同様のコードで実行できると思います。)

tensorflow datasetsに訓練用画像をロードさせる

訓練用画像をtensorflow datasetsに読み込ませてパイプラインを構築してデータを吐き出すジェネレータを作成します。

また、ここで画像のリサイズや正規化などの前処理を行います。参考

data_root = pathlib.Path('./train')
all_image_paths = list(sorted(data_root.glob('*')))
all_image_paths = [str(path) for path in all_image_paths]
image_count = len(all_image_paths)
AUTOTUNE = tf.data.experimental.AUTOTUNE

def preprocess_image(image):
  image = tf.image.decode_png(image, channels=3)
  # 画像のリサイズ
  image = tf.image.resize(image, [384, 384])
  # 画像の正規化
  image /= 255.0  # normalize to [0,1] range
  return image

def load_and_preprocess_image(path):
  image = tf.io.read_file(path)
  return preprocess_image(image)

path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

# ロードしたデータの確認
#plt.figure(figsize=(8,8))
#for n,image in enumerate(image_ds.take(4)):
#  plt.subplot(2,2,n+1)
#  plt.imshow(image)
#  plt.grid(False)
#  plt.xticks([])
#  plt.yticks([])
#  plt.show()

all_image_labels = []
with open('train_master.tsv', encoding='utf-8', newline='') as f:
    for cols in csv.reader(f, delimiter='\t'):
        all_image_labels.append(cols[1])

all_image_labels = all_image_labels[1:]
all_image_labels_i = [int(s) for s in all_image_labels]

n_labels = len(np.unique(all_image_labels_i))  # 分類クラスの数 = 20
all_image_labels_o = np.eye(n_labels)[all_image_labels_i]           # one hot表現に変換

label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels_o, tf.int64))
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

ミニバッチの作成

バッチサイズを指定しミニバッチを作成します。

# GPUだと256にする(減らす)
BATCH_SIZE = 256
ds = image_label_ds.batch(BATCH_SIZE)
# `prefetch`を使うことで、モデルの訓練中にバックグラウンドでデータセットがバッチを取得できます。
ds = ds.prefetch(buffer_size=AUTOTUNE)
batch = next(iter(ds.as_numpy_iterator()))

(注:ここで生成されるbatchテンソルの形状がチュートリアルのものと異なることに注意してください)

モデル定義の読み込みとランダムパラメータの初期化

モデル定義の読み込みとランダムパラメータの初期化を行います。

VisionTransformer = models.KNOWN_MODELS[model].partial(num_classes=20)
_, params = VisionTransformer.init_by_shape(
    jax.random.PRNGKey(0),
    # 初期化のためにバッチの "num_local_devices "次元を破棄する。
    [(batch[0].shape, 'float32')])

(注:batchテンソルの形状がことなるため配列の要素の指定の仕方が異なることに注意して下さい)

事前学習したパラメーターの読み込み

事前学習したパラメーターを読み込み変数に格納します。

# これには、実際の事前学習済みモデルの結果をロードするだけでなく、最終レイヤーを変更したり、位置的な埋め込みの# サイズを変更するなど、パラメータを少し変更することも含まれます。
params = checkpoint.load_pretrained(
    pretrained_path=f'{model}.npz',
    init_params=params,
    model_config=models.CONFIGS[model],
    logger=logger,
)

# ここから下は何をやっているかよく分からなかった。おそらく、host memoryにあるデータをTPUで分散処理するため
# にレプリカを作成しているような雰囲気
params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['cls']).__name__, params['cls'].shape)
print('params_repl.cls:', type(params_repl['cls']).__name__, params_repl['cls'].shape)

vit_apply_repl = jax.pmap(VisionTransformer.call)

ハイパーパラメーターの設定

total_steps = 100
warmup_steps = 5
decay_type = 'cosine'
grad_norm_clip = 1
accum_steps = 64
base_lr = 0.03

Optimizerの設定など

fine-tuneのインスタンスの生成やOptimizerの設定などを行う。

# 詳しい内容はtarin.pyのmake_update_fnメソッドの中を見ること
update_fn_repl = train.make_update_fn(VisionTransformer.call, accum_steps)
# 今回はmomentum optimizerを使用する。
opt = momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)

lr_fn = hyper.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
update_rngs = jax.random.split(jax.random.PRNGKey(0), jax.local_device_count())

fine-tuneを実行しパラメーターを更新する。

for step, batch, lr_repl in zip(
    tqdm.notebook.trange(1, total_steps + 1),
    ds.as_numpy_iterator(),
    lr_iter
):
  batch_dict = {
      'image':np.array([batch[0]]),
      'label':np.array([batch[1]])
  }
  opt_repl, loss_repl, update_rngs = update_fn_repl(
      opt_repl, lr_repl, batch_dict, update_rngs)

ここで、update_fn_replメソッドの第三引数はキーとして'image'、'label'を値としてnumpy配列を持つようなdictを要求するので、バッチをその形に修正します。

また、訓練後のパラメーターはopt_replに格納されるようです。

テストデータのロード

訓練データ同様にテスト画像をtensorflow datasetsに読み込ませてパイプラインを構築してデータを吐き出すジェネレータを作成します。

​```
vit_apply_repl = jax.pmap(VisionTransformer.call)
data_test = pathlib.Path('./test')
test_image_paths = list(sorted(data_test.glob('*')))
test_image_paths = [str(path) for path in test_image_paths]
image_count = len(test_image_paths)
path_ds_test = tf.data.Dataset.from_tensor_slices(test_image_paths)
image_ds_test = path_ds_test.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

#訓練データの確認
#plt.figure(figsize=(8,8))
#for n,image in enumerate(image_ds_test.take(4)):
#  plt.subplot(2,2,n+1)
#  plt.imshow(image)
#  plt.grid(False)
#  plt.xticks([])
#  plt.yticks([])
#  plt.show()

BATCH_SIZE = 100

ds_test = image_ds_test.batch(BATCH_SIZE)
# `prefetch`を使うことで、モデルの訓練中にバックグラウンドでデータセットがバッチを取得できます。
ds_test = ds_test.prefetch(buffer_size=AUTOTUNE)

テストデータの画像分類

テストデータに対して画像分類を行います。処理はチュートリアルのget_accuracyを参考にし、vit_apply_replの引数に関して要求する形に合うようデータを修正します。

また、vit_apply_replは返り値として各ラベルの評価画像に対するスコアを返します。

今回は、各ラベルの評価画像に対する確率を出したいためvit_apply_replのスコアをsoftmax関数に入れて確率に変換します。

steps = image_count // BATCH_SIZE
for i, batch in zip(tqdm.notebook.trange(steps), ds_test.as_numpy_iterator()):
    if i == 0:
        predicted = vit_apply_repl(opt_repl.target, batch[np.newaxis, :])
        predicted = flax.nn.softmax(predicted)
        predict_prob = predicted[0]
    else:
        predicted = vit_apply_repl(opt_repl.target, batch[np.newaxis, :])
        predicted = flax.nn.softmax(predicted)
        predict_prob = np.append(predict_prob, predicted[0], axis=0)

結果をCSVに格納する

評価画像に対する各ラベルの確率をCSVに書き込み保存します。

submit_file = pd.DataFrame(predict_prob)
submit_file.index = [path.replace("test/", "") for path in test_image_paths]
submit_file.to_csv("result_file.csv", header=False)

以上でsignateの訓練データのでfine-tuneを実行し、それを使ってテストデータの画像を画像分類をすることができました。

結果

無題.png
リーダーボードによるとベンチマークのスコアが" 1.2764459"なのでそれよりもよい結果が得られました。

そして、このスコアでの順位が上位30位程度だったのでViTのすごさを実感しました。

また、今回は事前学習のモデルとして「ViT-B_16」を使用しましたが、他のモデルを使用することで更なる精度向上の可能性があるため今度試してみようと思いました。

最後に

話題になっていたViTを試してみたい思い挑戦してみたのが今回の試みですが,とりあえず動かすだけであれば,半日くらいで実装できたので,これから使ってみたい人に,この記事が少しでも役に立ったらいいなと思っています。次の取り組みとしてはfine-tuneしたモデルをnpzに保存する方法などを試していきたいと考えています。

最後に、今回のsigeateデータでViTをfine-tuneするコードの全体を記載しこの記事を終わろうと思います。

最後まで記事を読んでいただきありがとうございました.

# -*- coding: utf-8 -*-
# リポジトリをクローンして、最新の変更をプルします。
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
!cd vision_transformer && git pull

!pip install -qr vision_transformer/vit_jax/requirements.txt


# 事前に訓練されたモデルをダウンロードします。
model = 'ViT-B_16'
![ -e "$model".npz ] || gsutil cp gs://vit_models/imagenet21k/"$model".npz .

import flax
import jax
from matplotlib import pyplot as plt
import numpy as np
import tqdm
import sys
if './vision_transformer' not in sys.path:
  sys.path.append('./vision_transformer')
from vit_jax import checkpoint
from vit_jax import hyper
from vit_jax import input_pipeline
from vit_jax import logging
from vit_jax import models
from vit_jax import momentum_clip
from vit_jax import train
import pathlib
import tensorflow as tf
import csv
import pandas as pd
# %load_ext autoreload
# %autoreload 2
# ログの設定
logger = logging.setup_logger('./logs')

from google.colab import drive
drive.mount('/content/drive')

! unzip <train.zipのパス>
! unzip <test.zipのパス>

data_root = pathlib.Path('./train')

all_image_paths = list(sorted(data_root.glob('*')))
all_image_paths = [str(path) for path in all_image_paths]

image_count = len(all_image_paths)
AUTOTUNE = tf.data.experimental.AUTOTUNE

def preprocess_image(image):
  image = tf.image.decode_png(image, channels=3)
# 画像のリサイズ
  image = tf.image.resize(image, [384, 384])
# 画像の正規化
  image /= 255.0  # normalize to [0,1] range

  return image

def load_and_preprocess_image(path):
  image = tf.io.read_file(path)
  return preprocess_image(image)

path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

# ロードしたデータの確認
#plt.figure(figsize=(8,8))
#for n,image in enumerate(image_ds.take(4)):
#  plt.subplot(2,2,n+1)
#  plt.imshow(image)
#  plt.grid(False)
#  plt.xticks([])
#  plt.yticks([])
#  plt.show()

all_image_labels = []
with open('train_master.tsv', encoding='utf-8', newline='') as f:
    for cols in csv.reader(f, delimiter='\t'):
        all_image_labels.append(cols[1])

all_image_labels = all_image_labels[1:]
all_image_labels_i = [int(s) for s in all_image_labels]

n_labels = len(np.unique(all_image_labels_i))  # 分類クラスの数 = 20
all_image_labels_o = np.eye(n_labels)[all_image_labels_i]           # one hot表現に変換

label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels_o, tf.int64))
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

# GPUだと256にする(減らす)
BATCH_SIZE = 256
ds = image_label_ds.batch(BATCH_SIZE)
# `prefetch`を使うことで、モデルの訓練中にバックグラウンドでデータセットがバッチを取得できます。
ds = ds.prefetch(buffer_size=AUTOTUNE)


# Same as above, but with train images.
# Do you spot a difference?
# Check out input_pipeline.get_data() in the editor at your right to see how the
# images are preprocessed differently.
batch = next(iter(ds.as_numpy_iterator()))


# Load model definition & initialize random parameters.
VisionTransformer = models.KNOWN_MODELS[model].partial(num_classes=20)
_, params = VisionTransformer.init_by_shape(
    jax.random.PRNGKey(0),
# 初期化のためにバッチの "num_local_devices "次元を破棄する。
    [(batch[0].shape, 'float32')])

# これには、実際の事前学習済みモデルの結果をロードするだけでなく、最終レイヤーを変更したり、位置的な埋め込みの
# サイズを変更するなど、パラメータを少し変更することも含まれます。
params = checkpoint.load_pretrained(
    pretrained_path=f'{model}.npz',
    init_params=params,
    model_config=models.CONFIGS[model],
    logger=logger,
)

# ここから下は何をやっているかよく分からなかった。おそらく、host memoryにあるデータをTPUで分散処理するため
# にレプリカを作成しているような雰囲気
params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['cls']).__name__, params['cls'].shape)
print('params_repl.cls:', type(params_repl['cls']).__name__, params_repl['cls'].shape)

vit_apply_repl = jax.pmap(VisionTransformer.call)
total_steps = 100
warmup_steps = 5
decay_type = 'cosine'
grad_norm_clip = 1
accum_steps = 64
base_lr = 0.03

# 詳しい内容はtarin.pyのmake_update_fnメソッドの中を見ること
update_fn_repl = train.make_update_fn(VisionTransformer.call, accum_steps)
# 今回はmomentum optimizerを使用する。
opt = momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)

lr_fn = hyper.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
update_rngs = jax.random.split(jax.random.PRNGKey(0), jax.local_device_count())

for step, batch, lr_repl in zip(
    tqdm.notebook.trange(1, total_steps + 1),
    ds.as_numpy_iterator(),
    lr_iter
):
  batch_dict = {
      'image':np.array([batch[0]]),
      'label':np.array([batch[1]])
  }
  opt_repl, loss_repl, update_rngs = update_fn_repl(
      opt_repl, lr_repl, batch_dict, update_rngs)

vit_apply_repl = jax.pmap(VisionTransformer.call)
data_test = pathlib.Path('./test')

test_image_paths = list(sorted(data_test.glob('*')))
test_image_paths = [str(path) for path in test_image_paths]

image_count = len(test_image_paths)
path_ds_test = tf.data.Dataset.from_tensor_slices(test_image_paths)
image_ds_test = path_ds_test.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

import matplotlib.pyplot as plt
#訓練データの確認
#plt.figure(figsize=(8,8))
#for n,image in enumerate(image_ds_test.take(4)):
#  plt.subplot(2,2,n+1)
#  plt.imshow(image)
#  plt.grid(False)
#  plt.xticks([])
#  plt.yticks([])
#  plt.show()

BATCH_SIZE = 100

# シャッフルバッファのサイズをデータセットとおなじに設定することで、データが完全にシャッフルされる
# ようにできます。
ds_test = image_ds_test.batch(BATCH_SIZE)
# `prefetch`を使うことで、モデルの訓練中にバックグラウンドでデータセットがバッチを取得できます。
ds_test = ds_test.prefetch(buffer_size=AUTOTUNE)


steps = image_count // BATCH_SIZE
for i, batch in zip(tqdm.notebook.trange(steps), ds_test.as_numpy_iterator()):
    if i == 0:
        predicted = vit_apply_repl(opt_repl.target, batch[np.newaxis, :])
        predicted = flax.nn.softmax(predicted)
        predict_prob = predicted[0]
    else:
        predicted = vit_apply_repl(opt_repl.target, batch[np.newaxis, :])
        predicted = flax.nn.softmax(predicted)
        predict_prob = np.append(predict_prob, predicted[0], axis=0)

submit_file = pd.DataFrame(predict_prob)

submit_file.index = [path.replace("test/", "") for path in test_image_paths]

submit_file.to_csv("result_file.csv", header=False)


22
25
4

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
22
25

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?