LoginSignup
44
55

More than 5 years have passed since last update.

CNN 転移学習とファインチューニング

Posted at

転移学習

・学習済みのモデルを新しいタスクに使うこと
・学習済みモデルの重みをそのまま使う

学習済みモデルのインスタンス化

include_top=False
・分類器の部分は使わない
・畳み込みの部分だけ使う

→畳み込みの部分は画像の一般的・普遍的な特徴をとらえているが、分類器は以前学習したタスクに特化しているため


from keras.applications import ResNet50

conv_base = ResNet50(weights='imagenet',
                     include_top=False,
                     input_shape=(32, 32, 3))

モデル


from keras import models
from keras import layers

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

model.summary()

凍結

新しく付け足した分類器の重みの更新をしている間に、学習済みモデルの重みも一緒に更新されないように凍結する


# 凍結前の更新可能な重みの数
print('before', len(model.trainable_weights))

before 216

# ResNet50の重みを凍結して更新されないようにする
conv_base.trainable = False

# 凍結後の更新可能な重みの数
print('after', len(model.trainable_weights))

after 4

データと学習

from keras.datasets import cifar10
(X_train, y_train),(X_test, y_test) = cifar10.load_data()

from keras.utils import np_utils

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

X_train /= 255.0
X_test /= 255.0

# ワンホットベクトル
n_classes = 10
Y_train = np_utils.to_categorical(y_train, n_classes)
Y_test = np_utils.to_categorical(y_test, n_classes)

from keras import optimizers

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(lr=1e-3),
              metrics=['acc'])

history = model.fit(X_train, Y_train,
                    batch_size=50, epochs=20, validation_split=0.1)

ファインチューニング

・学習済みのモデルの最後の一部分の重みを再学習させて、
 新しいタスクについても適合できるようにすること

<最後の一部分に限る理由>

・前半部分は画像の特徴の一般的なことを捉えているため再学習させる必要がない

・最後の方の層は、より具体的な特徴を捉えるようになっている
 そのため、最後だけを再学習させることで新しいタスクに適合させることができる

・以前のタスクしか知らないモデルと新しいタスクしか知らない分類器のつなぎ目を滑らかにするというイメージもできる

#最後のresblock(res5a_branch2aから最後まで)のみ再学習させる
conv_base.trainable = True

set_trainable = False
for layer in conv_base.layers:
    if layer.name == 'res5a_branch2a':
        set_trainable = True
    if set_trainable:
        layer.trainable = True
    else:
        layer.trainable = False
# 更新される重みの数
print('after', len(model.trainable_weights))

after 44

コンパイルと学習

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(lr=1e-3),
              metrics=['acc'])

history = model.fit(X_train, Y_train,
                    batch_size=50, epochs=3,
                    verbose=1, validation_split=0.1)
44
55
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
44
55