LoginSignup
14
13

More than 3 years have passed since last update.

[tf2.x] VisionTransformer(ViT)を用いて画像分類(ついでにiOSで動かしてみる)

Last updated at Posted at 2021-02-20

はじめに

この記事ではVisionTransformer(ViT)を用いて画像分類を行います。
ViTって何?って方はこちらの記事が参考になるかと思います。

簡単にいうと自然言語処理等で猛威を奮っているTransformerを画像に応用したものです。自分は画像の研究はしていませんでしたが、自然言語処理においてRNNが駆逐されていったのと同じように画像でもCNNが駆逐されていく日も近いのかもしれません。

以前はInceptionV3というものを使って遊びましたが、今回はViTを使って遊んでみたいと思います。

ViTのメリット・デメリット

InceptionV3を用いた画像分類と比較するとメリットは以下の通りだと思います。

  • CNNを用いずTransformerのみで計算を行うため、より高速に推論が可能
  • いろんなタスクでSoTAはViTがとってるらしい
  • 流行りそう

逆にデメリットは以下が挙げられると思います。

  • 空間使用量が(おそらく)大きい
    • 最終的にはiOS端末に乗せたいと考えているため、動かないと困る・・・
  • かなり新しいモデルのため情報がまだ少ない

実践!

今回の実験に用いるタスクは前回と同じ250種類の鳥の画像分類です。

まずViTを簡単に使うことができる vit-keras をインストールしましょう。
pip install vit-keras とかでインストールできます。詳しくは https://github.com/faustomorales/vit-keras

imagenetを用いてpre-trainしたものを以下のようにしてロードします。

ライブラリのバグなのか include_top=False にしていても pretrained_top=False を指定しないとエラーが出ました。


from vit_keras import vit

vit_model = vit.vit_l32(
    image_size=224,
    pretrained=True,
    include_top=False,
    pretrained_top=False,
)

モデルは以下のようにして定義します。
ViTの先端に1つ線形層をかませてsoftmaxを出力させるだけです。

finetune_at = 28

# 出力付近以外をフリーズ
for layer in vit_model.layers[:finetune_at - 1]:
    layer.trainable = False

# ノイズの追加
noise = GaussianNoise(0.01, input_shape=(224, 224, 3))

model = models.Sequential()
model.add(noise)
model.add(vit_model)
model.add(layers.Dense(num_classes, activation="softmax"))

また、学習率を動的に下げてより高い精度を狙いに行きます。
warm-upなど、いろいろ試しましたが、以下のシンプルなものが一番いい感じになりました。

# 7 epoch ごとに 0.1 かけする
def scheduler(epoch: int, lr: float) -> float:
    if epoch != 0 and epoch % 7 == 0:
        return lr * 0.1
    else:
        return lr
lr_scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

あとは以下のようにして学習させます。
trainのgeneratorやその他のコールバックは前回のものを参考にしてください。

history = model.fit(
          train_generator_augmented,
          epochs=100,
          validation_data=validation_generator,
          verbose=1, 
          shuffle=True,
          callbacks=[
              EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True),
              cp_callback,
              tb_callback,
              lr_scheduler_callback,
          ])

ピンクがViTで水色が前回のInceptionV3をファインチューニングしたモデルに同様のスケジューラやノイズを加えたものの結果になります。いずれもvalidationデータに対する結果です。

スクリーンショット 2021-02-21 3.34.24.png

スクリーンショット 2021-02-21 3.34.39.png

InceptionV3ベースのものでもなかなか健闘しており、Accuracyは最大0.98を示しています。
しかしViTベースのものが圧倒的に強いです。Accuracyは1epoch終わった時点でかなり高く、最大0.992を示しています。

さらにlossもViTのほうが低く、そして安定しています。
Accuracyが高くなることはもちろん重要なのですが、分類問題であるため、lossが低ければ低いほど、より高い確信度をもって当てに行っているため、より良いモデルができていると考えることができます。

実際、InceptionV3を用いたモデルをiOSに搭載し、カメラ画像から推論をさせると白い鳥ではない画像に対して99.8%のスコアで「シロフクロウ」であると言い張っていました。(すなわち鳥でない画像に簡単に反応してしまうモデルになってしまいました)
shiro_hukurou.gif

シロフクロウの参考画像(Wikipediaより) 確かに真っ白
500px-Snowy.owl.overall.arp.750pix.jpg

先にネタバレしてしまいますが、一方のViTベースのモデルは白い画像に対してもシロフクロウであると言わず、基本的に鳥ではない画像に対しては反応しなくなりました。つまりめっちゃいいモデルができてしまいました

tf-liteへ変換

さて、iOSに乗せることはできるのでしょうか?
先にネタバレしましたが、普通に乗りますし、普通に動きました(iPhone11 Proで実験)

変換は前回と同じ感じで大丈夫です。

まずはモデルのロードと最高のモデルでtestデータに対して推論させてみましょう。

checkpoint_file = "[一番いいモデルのパス]"

vit_model = vit.vit_l32(
    image_size=224,
    pretrained=True,
    include_top=False,
    pretrained_top=False
)

model = models.Sequential()
model.add(vit_model)
model.add(layers.Dense(num_classes, activation="softmax"))

model.compile(optimizer = optimizers.Adam(),
               loss = 'sparse_categorical_crossentropy',
               metrics = ['accuracy'])

model.load_weights(checkpoint_file)

model.trainable = False

loss, acc = model.evaluate_generator(test_generator)
print("Test Acc:", acc)

Test Acc: 0.9944000244140625

前回が0.988程度だったことを考えると凄まじい精度の向上ですね。
たかが0.5%、そう思ってませんか?(CV ケイスケ・ホンダ)

# saved_modelにする
model.save("saved_model/bird_model_vit")

import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./saved_model/bird_model_vit/')
converter.experimental_new_converter = True   #<--- Tensorflow v2.2.x以降を使用している場合は不要
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('./bird_model_vit.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - bird_model_vit.tflite")

あとはこれをゴニョゴニョしてiOSに乗せます。
iOSで実際にどう使うかはこのあたりを参考にしました。
えー、かなりつらい道のり、かつ、難解なコードになってしまったため、ここでは伏せておきます。。。

モデルのサイズは以前のものと比べて約12倍になっているため、アプリ起動時のロードが長くなった気がします。

動作画面

tori_zukan.gif
早送りしてるとかじゃないですよ。
このようにリアルタイムで推論ができていることがわかると思います。
体感時間も以前のものより高速です。

実物の鳥でやりたかったのですが、手元にいないので諦めました。
この画像の鳥がデータセットにある可能性も考慮(おそらくKaggleのデータセットは画像検索から持ってきたものであると推測される)してYouTubeで可愛い鳥の動画とかを調べてやってみましたが、そこでもうまくいきました。

おわりに

ViTすげえ。速くて高精度、ホントにCNNも消えるかも・・・?
まさにAttention is All You Needですね!

14
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
13