Kerasでpre-trainingする方法&multi-inputするときのvalidation設定方法

Pre-trainingした重みをFine-tuningする方法

KerasでPre-trainingした重みを使ってFine-tuningする方法がいくつかありそうで調べてみました。例えば、この記事。

Kerasで学ぶ転移学習
https://elix-tech.github.io/ja/2016/06/22/transfer-learning-ja.html

この方法でもいいのですが、プログラムが長くなるので、もっとすっきり書けないかなと思ってやってみました。今回はKeras公式ブログにある文書分類のCNNモデルでやってみます。
https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html

まずは事前学習したモデルファイルを読み込みます。

pretrain_best_model = load_model("pretrain_CNN.hdf5")
weights = []
for num, layer in enumerate(pretrain_best_model.layers):
    print(num, layer)
    weights.append(layer.get_weights())

そして、以下のように、レイヤーを定義するときに引数で事前学習した重みをセットします。

sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32' , name='sequence_input')
embedded_sequences = embedding_layer(sequence_input)
x = Conv1D(128, 5, activation='relu', weights=weights[2])(embedded_sequences)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu', weights=weights[4])(x)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu', weights=weights[6])(x)
x = MaxPooling1D(35)(x)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
preds = Dense(len(labels_index), activation='softmax')(x)
model = Model(sequence_input, preds)
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['acc'])

あとは、普通にfitすればFine-tuningできます。

model.fit(x_train, y_train, validation_data=(x_val, y_val),
          epochs=2, batch_size=128)

Multi inputでvalidationデータをセットする

あと、ネットワークに複数のデータを入れたときにvalidationデータをどうやって渡すのか、Kerasのreferenceを読んでもあまりちゃんと例が載ってなかったので試行錯誤してみました。その結果、以下でいけました。モデルは上と同じくKeras公式ブログのCNNと、もう1つFull Connectの別の入力をマージするというものです。

# CNNモデルはさっきと同じ
from keras.layers import concatenate

sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32' , name='sequence_input_ft')
embedded_sequences_ft = embedding_layer(sequence_input)
x = Conv1D(128, 5, activation='relu', weights=weights[2])(embedded_sequences)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu', weights=weights[4])(x)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu', weights=weights[6])(x)
x = MaxPooling1D(35)(x)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)

#もう1つ別のネットワークを入力にする。こっちはMLP。
mlp_input = Input(shape=(100,), dtype="float32", name='mlp_input')
x_mlp = Dense(32, activation='relu')(mlp_input)
x_mlp = Dense(32, activation='relu')(x_mlp)

# ここで2つのネットワークをマージ
x_m = concatenate([x, x_mlp], axis=-1)
#ここでは最終的に2値分類する
preds = Dense(2, activation='softmax')(x_m)

sgd = optimizers.SGD()
model = Model(inputs=[sequence_input, mlp_input], outputs=preds)
model.compile(loss='binary_crossentropy',
              optimizer=sgd,
              metrics=['acc'])

ここまでできたらfitで学習するのですが、そのときにinputと同じようにvalidation_data=({'sequence_input': x_val, 'mlp_input': mlp_val}, y_val)と設定します。全体としては以下のような感じです。

history = model.fit({'sequence_input': x_train, 'mlp_input': mlp_train}, y_train,
                                batch_size=32,
                                epochs=10, 
                                validation_data=({'sequence_input': x_val, 'mlp_input': mlp_val}, y_val),
                                callbacks=callbacks)

Kerasは便利ですね!

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.