LoginSignup
9
8

More than 3 years have passed since last update.

TensorFlow 2.2.0で未だにサポートされないEfficientNetをフラゲして、Food101にチャレンジする。

Posted at

EfficientNetとは?

2019年に登場と同時に各方面でState Of the Artの達成に利用された、画像認識の最強モデルの一つがEfficientNetです。Omiitaさんの記事、 2019年最強の画像認識モデルEfficientNet解説 をご覧ください。

以下が論文でも掲載されていた精度とパラメータ数の図です。頭一つ抜けていることがわかります。

EfficientNetをTensorFlowで試すには?

そんなEfficientNetですが、TensorFlowでResNetなどの各種著名モデルを提供しているtf.keras.applicationsには、まだ入っていないようです。2020年6月現在、TensorFlowの最新バージョンは2.2.0です。EfficientNetは既にTensorFlowのソースコード上には取り込まれ、存在しています。ですが、APIドキュメントでも言及がなく、APIもエクスポートされていません。まだ公式的には利用できない状態となっているようです。そこで本稿では、フライングゲットということで、Keras ApplicationsからEfficientNetを取り込んで動かしてみましょう。

また、せっかくソースコードを触る機会なので、活性化関数もEfficientNetで利用されているswishからMishへ取り替えて、自作のライブラリtftkにEfficientNetを取り込んでみます。

Keras ApplicationsからTensorFlowだけで動くように取り込んでみたソースコードはこちらです。

Food101データセットで試す。

それではEffiientNetB3(改)を使ってFood101データセットで学習してみます。Food101のデータセットや学習精度ベンチマークについては、こちらのkoshian2さんがInception v3で試した実例、普通の画像データセットに飽きたら、Food-101はいかが? を参考にさせて頂くのがよさそうです。

まずは学習の前に目標とする精度を決めておきます。先ほどのリンクを読むと、どうやら特別なニューラルネットワークを組めば 90%を超えるモデルができ、Inception V3で相当頑張ると86.97%くらいに届くようですね。さすがにスペシャルな一品ほどの精度はでないと思うので、こちらの86.97%を目指すことにしましょう。では、さっそく試していきましょう。今回は少ないデータ数でも精度が出やすい、ImageNetデータを学習済みの重みを活用した転移学習を行ってみます。それでは、EfficientNetでは、どうなるでしょうか。最強モデルを使って学習をすすめてみましょう。

まずは実装をしましょう。今回の学習の実装は次のようにできます。

まずはライブラリのインストール。

> !pip install tftk -U

そして、以下のコードを実行します。(これだけです)

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
import tftk
from tftk.image.dataset import Food101
from tftk.image.dataset import ImageDatasetUtil
from tftk.image.model import KerasEfficientNetB3

from tftk.train.image import ImageTrain
from tftk.image.augument import ImageAugument
from tftk.callback import CallbackBuilder
from tftk.optimizer import OptimizerBuilder

tftk.ENABLE_MIXED_PRECISION()
BATCH_SIZE = 40
CLASS_NUM = 101
IMAGE_SIZE = 224
CHANNELS = 3
EPOCHS = 150
SHUFFLE_SIZE = 1000

train, train_len = Food101.get_train_dataset()
validation, validation_len = Food101.get_validation_dataset()

train = train.map(ImageDatasetUtil.resize_and_max_square_crop(IMAGE_SIZE,IMAGE_SIZE),num_parallel_calls=tf.data.experimental.AUTOTUNE).map(ImageAugument.randaugment_map(1,4))
train = train.map(ImageDatasetUtil.image_reguralization(),num_parallel_calls=tf.data.experimental.AUTOTUNE).map(ImageDatasetUtil.one_hot(CLASS_NUM),num_parallel_calls=tf.data.experimental.AUTOTUNE).apply(ImageAugument.mixup_apply(200,0.1))
validation = validation.map(ImageDatasetUtil.resize_and_max_square_crop(IMAGE_SIZE,IMAGE_SIZE),num_parallel_calls=tf.data.experimental.AUTOTUNE).map(ImageDatasetUtil.image_reguralization(),num_parallel_calls=tf.data.experimental.AUTOTUNE).map(ImageDatasetUtil.one_hot(CLASS_NUM),num_parallel_calls=tf.data.experimental.AUTOTUNE)

optimizer = OptimizerBuilder.get_optimizer()
model = KerasEfficientNetB3.get_model(input_shape=(IMAGE_SIZE,IMAGE_SIZE,CHANNELS),classes=CLASS_NUM,weights="imagenet")
callbacks = CallbackBuilder.get_callbacks(tensorboard=True, consine_annealing=False, reduce_lr_on_plateau=True,reduce_patience=6,reduce_factor=0.25,early_stopping_patience=10)
ImageTrain.train_image_classification(train_data=train,train_size=train_len,batch_size=BATCH_SIZE,validation_data=validation,validation_size=validation_len,shuffle_size=SHUFFLE_SIZE,model=model,callbacks=callbacks,optimizer=optimizer,loss="categorical_crossentropy",max_epoch=EPOCHS)

上述のコードでは以下の2行でデータ拡張を行っています。RandAugument(拡張数1,大きさ4)、Mixupを使ってみます。

train = train.map(ImageDatasetUtil.resize_and_max_square_crop(IMAGE_SIZE,IMAGE_SIZE),num_parallel_calls=tf.data.experimental.AUTOTUNE).map(ImageAugument.randaugment_map(1,4))
train.map(ImageDatasetUtil.image_reguralization(),num_parallel_calls=tf.data.experimental.AUTOTUNE).map(ImageDatasetUtil.one_hot(CLASS_NUM),num_parallel_calls=tf.data.experimental.AUTOTUNE).apply(ImageAugument.mixup_apply(200,0.1))

モデルはKeras Applicationsから持ってきたモデルですので、事前学習された重みがあります。ImageNetを使って学習した重みを使って転移学習をさせてみます。

model = KerasEfficientNetB0.get_model(input_shape=(IMAGE_SIZE,IMAGE_SIZE,CHANNELS),classes=CLASS_NUM,weights="imagenet")

実行を開始すると、10分ほどFood101ダウンロードの時間がかかります。その後、学習が開始されます。以下が学習のログです。

Epoch 70/150
1893/1893 [==============================] - 563s 297ms/step - loss: 0.5034 - acc: 0.9558 - val_loss: 0.5993 - val_acc: 0.8561 - lr: 0.0025

そして、検証データでのval_accは、85.61%に到達しており、それほど工夫していないのですが、目標の86.97%までは一気にあと一息といったところまで到達することができましたね。流石に最強と言われるモデルです。なるほど、強いです。

9
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
8