Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
53
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

CNN 転移学習とファインチューニング

転移学習

・学習済みのモデルを新しいタスクに使うこと
・学習済みモデルの重みをそのまま使う

学習済みモデルのインスタンス化

include_top=False
・分類器の部分は使わない
・畳み込みの部分だけ使う

→畳み込みの部分は画像の一般的・普遍的な特徴をとらえているが、分類器は以前学習したタスクに特化しているため


from keras.applications import ResNet50

conv_base = ResNet50(weights='imagenet',
                     include_top=False,
                     input_shape=(32, 32, 3))

モデル


from keras import models
from keras import layers

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

model.summary()

凍結

新しく付け足した分類器の重みの更新をしている間に、学習済みモデルの重みも一緒に更新されないように凍結する


# 凍結前の更新可能な重みの数
print('before', len(model.trainable_weights))

before 216

# ResNet50の重みを凍結して更新されないようにする
conv_base.trainable = False

# 凍結後の更新可能な重みの数
print('after', len(model.trainable_weights))

after 4

データと学習

from keras.datasets import cifar10
(X_train, y_train),(X_test, y_test) = cifar10.load_data()

from keras.utils import np_utils

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

X_train /= 255.0
X_test /= 255.0

# ワンホットベクトル
n_classes = 10
Y_train = np_utils.to_categorical(y_train, n_classes)
Y_test = np_utils.to_categorical(y_test, n_classes)

from keras import optimizers

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(lr=1e-3),
              metrics=['acc'])

history = model.fit(X_train, Y_train,
                    batch_size=50, epochs=20, validation_split=0.1)

ファインチューニング

・学習済みのモデルの最後の一部分の重みを再学習させて、
 新しいタスクについても適合できるようにすること

<最後の一部分に限る理由>

・前半部分は画像の特徴の一般的なことを捉えているため再学習させる必要がない

・最後の方の層は、より具体的な特徴を捉えるようになっている
 そのため、最後だけを再学習させることで新しいタスクに適合させることができる

・以前のタスクしか知らないモデルと新しいタスクしか知らない分類器のつなぎ目を滑らかにするというイメージもできる

#最後のresblock(res5a_branch2aから最後まで)のみ再学習させる
conv_base.trainable = True

set_trainable = False
for layer in conv_base.layers:
    if layer.name == 'res5a_branch2a':
        set_trainable = True
    if set_trainable:
        layer.trainable = True
    else:
        layer.trainable = False
# 更新される重みの数
print('after', len(model.trainable_weights))

after 44

コンパイルと学習

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(lr=1e-3),
              metrics=['acc'])

history = model.fit(X_train, Y_train,
                    batch_size=50, epochs=3,
                    verbose=1, validation_split=0.1)
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
53
Help us understand the problem. What are the problem?