Help us understand the problem. What is going on with this article?

[Keras/TensorFlow] 転移学習(Fine-tuning)

More than 3 years have passed since last update.

目的

ゼロからKerasとTensorFlow(TF)を自由自在に動かせるようになる。
そのための、End to Endの作業ログ(備忘録)を残す。
※環境はMacだが、他のOSでの汎用性を保つように意識。
※アジャイルで執筆しており、精度を逐次高めていく予定。

目次

概要

このページを読んでできるようになること

  • VGG16のFine-tuningによる17種類の花の分類 で紹介されている fine tuning のサンプルプログラムを動かす。
  • 上記の学習はCPUだと2日間近くかかるため、事前に用意してある学習済みのパラメータを読み込み、学習結果を確認する。

転移学習(Fine-tuning)とは

Chainerでファインチューニングするときの個人的ベストプラクティス より抜粋

ニューラルネットを学習するために、別の問題、別のデータセットで学習済みのモデルのパラメータをコピーして、それを新しいニューラルネットのパラメータの初期値として使うことをファインチューニングといいます。
典型的なケースとして、一般物体認識のデータセットであるImageNetで学習したネットワークを物体検出やSemantic Segmentationといった別の問題に使用するといった例があります。

一般的にDeep Learningでは大量の学習データが必要とされていますが、あらかじめ(大量のデータで)学習したモデルを初期値として使いファインチューニングすることで、実際に解きたい問題に関するデータの量が不十分でも十分な性能を達成できる場合があります。また、学習にかかる時間を短縮する効果もあります。

ソースコード

https://github.com/deer-dslab/keras-example

動作確認環境

macOS Sierra
Python 3.4
Anaconda3
TensorFlow 1.0.1

Mac で動作確認していますが、Windows でも 事前準備の brew コマンド以外は同様だと思います。

前提条件

事前準備

サイズの大きいバイナリファイルを扱うための Git LFS 導入 しておく

$ brew install git-lfs

ディレクトリ構成

keras-example/
 ├ fine-tuning.py
 └ dataset/
   ├ jpg/
   └ labels.txt/

ダウンロード

ソースコード

リポジトリ keras-example を GitLFS を使用して clone する

$ git lfs clone https://github.com/deer-dslab/keras-example.git

データセット

  1. 17 Category Flower Dataset ページの Downloads から Dataset images をクリックして 17flowers.tgz をダウンロードする
  2. 17flowers.tgz を解凍して出てきた jpg/ ディレクトリを dataset/ 下に配置する

前処理

  1. setup.py を実行して、画像データを学習用とテスト用に分ける
$ cd path-to-project/keras-example/fine-tuning
$ python setup.py

転移学習の実行(fine-tuning)

fine-tuning.py を実行する

$ cd path-to-project/keras-example/fine-tuning
$ python fine-tuning.py

そして (CPU環境では) 2日弱待ちます。。
ちなみに、実行中のキャンセルは例によって ctrl + C です

コード解説

fine-tuning.py
    vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)
    # vgg16.summary()

    # FC層を構築
    # Flattenへの入力指定はバッチ数を除く
    top_model = Sequential()
    top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
    top_model.add(Dense(256, activation='relu'))
    top_model.add(Dropout(0.5))
    top_model.add(Dense(nb_classes, activation='softmax'))

    # 学習済みのFC層の重みをロード
    # top_model.load_weights(os.path.join(result_dir, 'bottleneck_fc_model.h5'))

    # VGG16とFCを接続
    model = Model(input=vgg16.input, output=top_model(vgg16.output))
  • include_top はVGG16のトップにある1000クラス分類するフル結合層(FC)を含むか含まないかを指定する (※ KerasでVGG16を使う より抜粋)。 今回は VGG16 の 1000 クラス分類ではなく、出力層を 17 クラス分類にするため include_top=False とし、新たに定義した出力層 top_model を使用している。

学習結果を試す

2日待てない場合は、学習後のパラメータファイル finetuning-sample.h5keras-example/fine-tuning/dataset/results/ に置いてあります。

これを読み込んで動作確認してみます。

$ cd path-to-project/keras-example/fine-tuning
$ python predict.py dataset/test_images/Tulip/image_0003.jpg 

('Tulip', 0.99381989)
('Daffodil', 0.0060764253)
('Sunflower', 5.0716473e-05)
('Windflower', 3.6588597e-05)
('Cowslip', 1.2691652e-05)

99% の確率でチューリップという結果がでました。

コード解説

学習時と同様にモデルを定義した後、パラメータ finetuning-sample.h5 をロードして上書きしている。

predict.py
# VGG16とFCを接続
model = Model(input=vgg16.input, output=fc(vgg16.output))

# 学習済みの重みをロード
model.load_weights(os.path.join(result_dir, 'finetuning-sample.h5'))

参考資料

目次

http://qiita.com/agumon/items/91f897b7f260f6aeca95

参考サイト

その他

VGG16とは

tshimba
このサイトの掲載内容は私自身の見解であり、所属する組織とは関係ありません
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away