はじめに
この記事ではVisionTransformer(ViT)を用いて画像分類を行います。
ViTって何?って方はこちらの記事が参考になるかと思います。
簡単にいうと自然言語処理等で猛威を奮っているTransformerを画像に応用したものです。自分は画像の研究はしていませんでしたが、自然言語処理においてRNNが駆逐されていったのと同じように画像でもCNNが駆逐されていく日も近いのかもしれません。
以前はInceptionV3というものを使って遊びましたが、今回はViTを使って遊んでみたいと思います。
ViTのメリット・デメリット
InceptionV3を用いた画像分類と比較するとメリットは以下の通りだと思います。
- CNNを用いずTransformerのみで計算を行うため、より高速に推論が可能
- いろんなタスクでSoTAはViTがとってるらしい
- 流行りそう
逆にデメリットは以下が挙げられると思います。
- 空間使用量が(おそらく)大きい
- 最終的にはiOS端末に乗せたいと考えているため、動かないと困る・・・
- かなり新しいモデルのため情報がまだ少ない
実践!
今回の実験に用いるタスクは前回と同じ250種類の鳥の画像分類です。
まずViTを簡単に使うことができる vit-keras
をインストールしましょう。
pip install vit-keras
とかでインストールできます。詳しくは https://github.com/faustomorales/vit-keras
imagenetを用いてpre-trainしたものを以下のようにしてロードします。
ライブラリのバグなのか include_top=False
にしていても pretrained_top=False
を指定しないとエラーが出ました。
from vit_keras import vit
vit_model = vit.vit_l32(
image_size=224,
pretrained=True,
include_top=False,
pretrained_top=False,
)
モデルは以下のようにして定義します。
ViTの先端に1つ線形層をかませてsoftmaxを出力させるだけです。
finetune_at = 28
# 出力付近以外をフリーズ
for layer in vit_model.layers[:finetune_at - 1]:
layer.trainable = False
# ノイズの追加
noise = GaussianNoise(0.01, input_shape=(224, 224, 3))
model = models.Sequential()
model.add(noise)
model.add(vit_model)
model.add(layers.Dense(num_classes, activation="softmax"))
また、学習率を動的に下げてより高い精度を狙いに行きます。
warm-upなど、いろいろ試しましたが、以下のシンプルなものが一番いい感じになりました。
# 7 epoch ごとに 0.1 かけする
def scheduler(epoch: int, lr: float) -> float:
if epoch != 0 and epoch % 7 == 0:
return lr * 0.1
else:
return lr
lr_scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
あとは以下のようにして学習させます。
trainのgeneratorやその他のコールバックは前回のものを参考にしてください。
history = model.fit(
train_generator_augmented,
epochs=100,
validation_data=validation_generator,
verbose=1,
shuffle=True,
callbacks=[
EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True),
cp_callback,
tb_callback,
lr_scheduler_callback,
])
ピンクがViTで水色が前回のInceptionV3をファインチューニングしたモデルに同様のスケジューラやノイズを加えたものの結果になります。いずれもvalidationデータに対する結果です。
InceptionV3ベースのものでもなかなか健闘しており、Accuracyは最大0.98を示しています。
しかしViTベースのものが圧倒的に強いです。Accuracyは1epoch終わった時点でかなり高く、最大0.992を示しています。
さらにlossもViTのほうが低く、そして安定しています。
Accuracyが高くなることはもちろん重要なのですが、分類問題であるため、lossが低ければ低いほど、より高い確信度をもって当てに行っているため、より良いモデルができていると考えることができます。
実際、InceptionV3を用いたモデルをiOSに搭載し、カメラ画像から推論をさせると白い鳥ではない画像に対して99.8%のスコアで「シロフクロウ」であると言い張っていました。(すなわち鳥でない画像に簡単に反応してしまうモデルになってしまいました)
シロフクロウの参考画像(Wikipediaより) 確かに真っ白
先にネタバレしてしまいますが、一方のViTベースのモデルは白い画像に対してもシロフクロウであると言わず、基本的に鳥ではない画像に対しては反応しなくなりました。つまりめっちゃいいモデルができてしまいました。
tf-liteへ変換
さて、iOSに乗せることはできるのでしょうか?
先にネタバレしましたが、普通に乗りますし、普通に動きました(iPhone11 Proで実験)
変換は前回と同じ感じで大丈夫です。
まずはモデルのロードと最高のモデルでtestデータに対して推論させてみましょう。
checkpoint_file = "[一番いいモデルのパス]"
vit_model = vit.vit_l32(
image_size=224,
pretrained=True,
include_top=False,
pretrained_top=False
)
model = models.Sequential()
model.add(vit_model)
model.add(layers.Dense(num_classes, activation="softmax"))
model.compile(optimizer = optimizers.Adam(),
loss = 'sparse_categorical_crossentropy',
metrics = ['accuracy'])
model.load_weights(checkpoint_file)
model.trainable = False
loss, acc = model.evaluate_generator(test_generator)
print("Test Acc:", acc)
Test Acc: 0.9944000244140625
前回が0.988程度だったことを考えると凄まじい精度の向上ですね。
たかが0.5%、そう思ってませんか?(CV ケイスケ・ホンダ)
# saved_modelにする
model.save("saved_model/bird_model_vit")
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./saved_model/bird_model_vit/')
converter.experimental_new_converter = True #<--- Tensorflow v2.2.x以降を使用している場合は不要
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('./bird_model_vit.tflite', 'wb') as w:
w.write(tflite_quant_model)
print("Weight Quantization complete! - bird_model_vit.tflite")
あとはこれをゴニョゴニョしてiOSに乗せます。
iOSで実際にどう使うかはこのあたりを参考にしました。
えー、かなりつらい道のり、かつ、難解なコードになってしまったため、ここでは伏せておきます。。。
モデルのサイズは以前のものと比べて約12倍になっているため、アプリ起動時のロードが長くなった気がします。
動作画面
早送りしてるとかじゃないですよ。
このようにリアルタイムで推論ができていることがわかると思います。
体感時間も以前のものより高速です。
実物の鳥でやりたかったのですが、手元にいないので諦めました。
この画像の鳥がデータセットにある可能性も考慮(おそらくKaggleのデータセットは画像検索から持ってきたものであると推測される)してYouTubeで可愛い鳥の動画とかを調べてやってみましたが、そこでもうまくいきました。
おわりに
ViTすげえ。速くて高精度、ホントにCNNも消えるかも・・・?
まさにAttention is All You Needですね!