#はじめに
転移学習でCIFAR-10を対象とする記事はネット上に結構あるが、最終的な精度は低いままで学習を終了させてしまう記事が多いようだ。この記事では学習終了時の正解率に拘って、転移学習(+ファインチューニング)で高性能なモデルの作成方法を記事にする。
TensorFlow/Keras環境での記事だが、他のフレームワークでもある程度応用できるはずだ。
ちなみに、普通にCIFAR10のデータだけで学習させた場合は認識率97%までいけばかなりいい方で、99%は転移学習を使わないと達成は難しいのではないかと思う。このサイトの集計によると、記事作成時の最高記録はEfficientNetL2(SAM)で99.7%となっている。
環境
Google Colab TPU
TensorFlow : 2.4.0
tf.keras: 2.3.0
#モデルの作成
##ベースとなるモデルを選ぶ
tf.keras.applicationsに訓練済みの重みを使用できるモデルがいくつか提供されているので、それを使う。
どれでもいいのだが、モデルによって学習のさせ方(学習率など)が変わってきたりするので、適宜それに合わせて調整する必要がある。ResNetやEfficentNetは、学習させやすいようだ。
EfficientNetはモデルの大きさに関する選択肢が多いこともあり、この記事ではEfficientNetを使う。論文ではCIFAR10の転移学習の結果も示されており、EfficientNetB0では98.1%、EfficientNetB7では98.9%なので、目標の99%は頑張ればできそうな数値だ。
##入力画像の大きさをモデルに合わせる
転移学習では当然学習済みのモデルを使うわけだが、その学習は大抵ImageNetで実施されており、その画像サイズは224x224である。モデルの構造も学習の重みもこのサイズを前提にしているが、CIFAR-10は32x32でサイズが大きく違うので、そのまま画像を入力してもうまく機能しない。
したがって、入力画像をリサイズして元のモデルに合わせる処理を入れる。全く同じサイズにする必要もないができるだけ合わせた方が、やはり良い性能が出るようだ。リサイズ処理はモデル内で実行するように実装できる。
EfficientNetは各モデルで入力サイズが違うので、事前に調べてそれに合わせる。関連記事
##転移学習用にフラグを設定する
転移学習では学習させるレイヤーを選択して一部のみ学習可能にする操作が必要になる。
TensorFlowではModel(またはLayer)に対し"trainable=False"としておくと、学習で重みが更新されなくなる。この作業はフリーズと呼ばれる。
また、tensorflowの転移学習のTutorialにも書いてあるが、学習済みモデルを追加する際に、"training=False"として組み込むことが推奨されている。これはBatchNormalization関連の処理に影響するようだ。これは上述の"trainable"とは別の設定になる。
##トップレイヤーを追加する
学習済みモデルの出力付近のレイヤー(トップレイヤー)は不要なので取り除き、そこに新たにCIFAR10用のレイヤーを追加する。ここはよくあるGlobalAveragePoolingとDenseの組み合わせでいいのだが、間にDropoutを追加すると性能が上がる場合がある。この記事ではrate=0.5として挿入している。
##モデル実装例
これまでの内容に沿って実装していくと概ね以下のようになる。
num_classes = 10
input_shape = (32,32,3)
base_input_shape = (224,224,3)
model_class = tf.keras.applications.EfficientNetB0
x = inputs = tf.keras.layers.Input(shape=input_shape)
#Resize
x = tf.keras.layers.Lambda(lambda image: tf.image.resize(image, base_input_shape[0:2]), output_shape=base_input_shape)(x)
#学習済みモデルの組み込み
base_model = model_class(include_top=False, input_shape=base_input_shape,
weights='imagenet')
base_model.trainable = False
x = base_model(x, training=False) # trainingをFalseにする
#トップレイヤーの追加
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.5)(x)
x = tf.keras.layers.Dense(num_classes)(x)
outputs = tf.keras.layers.Activation('softmax')(x)
model = tf.keras.Model(inputs, outputs)
#トレーニング
##適切な最適化アルゴリズムを選ぶ
どれでもそれなりに学習できるのだが、ここは無難にSGD(momentum=0.9)を使う。各種CNNモデルが提案された論文を読むと大抵SGDを使っており、Adamなどは普通使われない。
転移学習+ファインチューニングでもAdamよりもSGDを使った方が最終的な精度は良いようだ。ただしAdamの方が学習自体は速いので、学習時間を重視する場合はAdamなどを使っても良いだろう。
##入力画像の前処理(正規化)をする
各モデルの学習時には入力画像に対して前処理(正規化)が行われているので、転移学習の際にもそれに合わせておく必要がある。
tf.kerasではこの前処理に関して、各モデルでpreprocess_inputという名前の関数を公開することになっているので、それを使う。
これはモデル内に入れてもいいのだが、この記事ではdatasetでの前処理として実装。
##データ拡張を行う
転移学習では大抵元のデータセットに比べて訓練データの量が少なくなるので、過学習になりやすくなる。したがって、精度をあげるためにはデータ拡張は必須となる。
ここでは、よくある左右フリップと上下左右のShiftに加えて、Cutoutのサイズを0.4として2回と、Saturation/Contrastをランダムに変更する処理を入れた。
やりすぎると学習が進まなくなったりするので、状況に応じてCutoutのサイズや回数を調節するなど、最適な設定を探す事になる。
##最初はトップのみ学習
CIFAR10用にトップを付け替えたが、ここは初期状態のままなので、最初はこの部分だけ学習させて学習済みのレイヤーと馴染ませる。これをやっておくと次のFine Tuningで学習率を上げてもモデルが崩壊しづらくなる。
ここは厳密にやる必要はないので5エポックで終わりとした。
狭義の転移学習はここで終わりだが、流石にこのままでは高い正解率は達成できないのでファインチューニングを行う。
##ファインチューニング
トップレイヤーがある程度馴染んだら、学習済みレイヤーも訓練対象として再学習を行う。ここは出力に近い部分だけ部分的にフリーズを解除する、としている記事が多いが、全部解除しても構わない。モデルに対して"trainable=true"とすると、全てのフリーズが解除される。
筆者が試した限りでは全部フリーズ解除して全ての重みを訓練対象とした方が性能がよかった。ただし、訓練にかかる時間は増える。
全て訓練対象とするとモデルが崩壊する可能性も増えるのだが、上述のトップレイヤーを事前に馴染ませる作業をしておくと、可能性が減るようだ。また徐々に学習率をあげるようにWarmupを適用しておくことも、崩壊を防ぐ効果があるようだ。ファインチューニングでの学習率はかなり低く抑えることが一般的なようだが、低すぎると学習が進まないので、ある程度高くして学習時間を短くする。
本記事の実装では、コサインカーブで学習率を上昇(Warmup)させて、一定期間維持した後にコサインカーブで学習率を減少(Cooldown)させて学習終了とさせる。
学習率やバッチサイズに応じて最適なエポック数は変わってくるが、ここでは基本的に下記のような設定で学習させた。
項目 | 値 |
---|---|
バッチサイズ | 200 |
Warmup | 10 epochs |
Flat | 5 epochs |
Cooldown | 20 epochs |
最小学習率 | 0.001 |
最大学習率 | 0.025 |
というわけで、合計5+10+5+20エポックで学習終了とすると、EfficentNetB5で99%まで到達する。モデルが大きくなるとエポック数を増やしたほうが良いが、EffficientNetB5の場合はこのエポック数でTPUでは2時間半ほどかかる。
以下、実際に訓練させた際のLossとAccuracyのグラフ。
ここに掲載した設定ではギリギリで99%だが、もう少し学習時間を長くすれば確実にいける。EfficientNetB5での筆者の最高記録は99.12%で、この場合はエポック数75。
ちなみにEfficientNetB0での最高記録は98.31%だった。
#参考
Transfer learning and fine-tuning
データのお気持ちを考えながらData Augmentationする