LoginSignup
13
14

More than 3 years have passed since last update.

Tensorflowのみで転移学習(No Keras)

Last updated at Posted at 2018-11-16

初めに

Tensorflowで転移学習を利用しようと思い検索をかけたところ,Kerasを使った記事ばかり出てきた.KerasはTensorflowバックエンドなので間違いではないが,個人的にKerasだとTensorflowのソースコードとの互換性やらデータの受け渡しやらで拡張性が薄く,個人的に都合が悪いのでできるだけTensorflowのみで行いたい.そこで資料がなくかなり悩んだので記録しておく.今回の記事は以下のリンクのものを参考にさせていただいた.

Use Keras Pretrained Models With Tensorflow

今回作成したソースコードの全体像はこちら参照

TensorflowからKerasAPIの利用

結論から言うとTensorflowのモジュールのみで学習済みモデルを読み込むことはできない.そのため学習済みモデルを利用するにはKerasのAPIをTensorflowから読み込むことになる.Tensorflowにはkerasのアプリケーションを呼ぶクラスtf.keras.applicationsがあるので,こちらを利用して学習済みモデルを呼び出す.どうやらこちらにあるネットワークは全て利用できそう.

利用できそうなネットワークは以下の通り(2018/11/17現在).

  • densenet
  • imagenet utils
  • inception resnet v2
  • inception v3
  • mobilenet
  • nasnet
  • resnet50
  • vgg16
  • vgg19
  • xception

例えばこの中のVGG19のモデルをTensorflow上で利用する場合は以下のように呼び出す.

vgg19 = tf.keras.applications.VGG19(include_top=True,
                                    weights='imagenet',
                                    input_tensor=None,
                                    input_shape=None,
                                    pooling=None,
                                    classes=1000)

それぞれのオプションの意味をこちらを参考にしてまとめた.

  • include_top
    入力層を含めるかどうか,imagenetの重みをダウンロードした場合は,入力が[224,224,3]に固定される.転移学習などで独自の画像を学習したい時は画像サイズに合わせてTrue or Falseを選択.
  • weights
    imagenet or None or モデルのパスを選択して入力する.imagenetの学習済みモデルを利用したい場合はimagenetを指定するとPCにダウンロードされる.
  • input_shape
    include_top = Falseの時のみ必要.入力画像の(x,y,チャネル数)weightsにimagenetを指定した場合,チャネルが3に固定されるなど,読み込むモデルに応じて制限がある場合がある.
  • input_tensor
    ネットワークに入力するデータ.
  • pooling
    include_topがFalseの時のみ指定可能.最後の特徴抽出の層のmax_poolingの手法を以下の3つから選択できる.
    • None
      通常のmaxpooling,出力は4次元のTensor
    • avg
      Global average pooling,出力は2次元のTensor
    • max
      Global max pooling,出力は2次元のTensor
  • classes
    include_topがTrueの時のみ指定可能.読み込むウェイトファイルのクラス分類数.ImageNetの時は1000.

これにより呼び出した学習済みモデルの出力層を利用することで転移学習行える。,

転移学習の例

今回は,VGG19のimagenetを学習したモデルを用いて転移学習を行う.imagenetを学習した重みを用いて特徴抽出を行い,FC層のみ自前で作成し学習を行う.VGG19の事前学習済みモデルをオプションに従って以下のように修正した.

./model.py
vgg19 = tf.keras.applications.VGG19(include_top=False,
                                    weights='imagenet',
                                    input_shape=(self.input_size[0], self.input_size[1], self.channel),
                                    input_tensor=input,
                                    pooling=None)

Tensorflowの学習においてパラメータを更新たくない場合,レイヤーのtrainable = Falseに設定することで学習中にパラメータを更新しなくなる.転移学習を行う場合はtrainable = Trueにすれば良い.今回はVGG19を転移学習の特徴抽出器として用いるためtrainable = Falseを適応する.

./model.py
for vgg_layer in vgg19.layers:
  vgg_layer.trainable = False

vgg19_out = tf.identity(vgg19.layers[-1].output, name='output')

ここで得たvgg19_outが抽出された特徴量となる.この特徴量を利用し,全結合層のみ学習を行うモデルを作成する.最終的なVGG19の転移学習を行うモデル構築の部分のソースコードは以下の通り.

./model.py
    def vgg19(self , input , reuse = False , freeze = True,name =''):
        '''
        VGG19
        '''
        print('[SETUPMODEL]\tVGG19')
        with tf.variable_scope("VGG19" + name , reuse=reuse) as scope:
            if reuse: scope.reuse_variables()
            with tf.variable_scope("VGG19"):
                vgg19     = tf.keras.applications.VGG19(include_top=False,
                                                        weights='imagenet',
                                                        input_shape=(self.input_size[0], self.input_size[1], self.channel),
                                                        input_tensor=input
                                                        )
                if freeze:# 下位層のみ学習する場合
                    for vgg_layer in vgg19.layers:
                        vgg_layer.trainable = False

                vgg19_out = tf.identity(vgg19.layers[-1].output, name='output')

            '''
            自作でFC層のみ学習する.
            '''
            with tf.variable_scope("FC{0}".format(1)):
                flat       = self.flatten(vgg19_out,self.nBatch)
                fc1 = self.fc( flat , 4096 , "1",BatchNorm = False)
                fc2 = self.fc( fc1 ,  4096 , "2",BatchNorm = False)
                fc3 = self.fc( fc2 ,  1000 , "3",BatchNorm = False)
                output = self.fc( fc3 , self.labelSize , "out",BatchNorm = False)

        return output

実行

ソースコードの実行は以下のコマンドで行う.

python fit.py --img_path={実行したいフォルダのパス} \
              --train_num={学習に用いる画像の枚数} \
              --gray={グレースケールにするかどうか(True or False)} \
              --test_num={テスト用の画像の枚数} \
              --learning_rate={学習率} \
              --batch_size={minibatchサイズ} \ 
              --max_epoch={最大学習数}

終わりに

正直資料がなさすぎてこちらの方の記事がなかったらわからなかった.
公式tf.keras.applicationsのドキュメントもかなり薄いので,Tensorflow2.0で削除されるのかもしれない.

13
14
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
14