LoginSignup
3

More than 5 years have passed since last update.

Kerasでpre-trainingする方法&multi-inputするときのvalidation設定方法

Last updated at Posted at 2018-03-15

Pre-trainingした重みをFine-tuningする方法

KerasでPre-trainingした重みを使ってFine-tuningする方法がいくつかありそうで調べてみました。例えば、この記事。

Kerasで学ぶ転移学習
https://elix-tech.github.io/ja/2016/06/22/transfer-learning-ja.html

この方法でもいいのですが、プログラムが長くなるので、もっとすっきり書けないかなと思ってやってみました。今回はKeras公式ブログにある文書分類のCNNモデルでやってみます。
https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html

まずは事前学習したモデルファイルを読み込みます。

pretrain_best_model = load_model("pretrain_CNN.hdf5")
weights = []
for num, layer in enumerate(pretrain_best_model.layers):
    print(num, layer)
    weights.append(layer.get_weights())

そして、以下のように、レイヤーを定義するときに引数で事前学習した重みをセットします。


sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32' , name='sequence_input')
embedded_sequences = embedding_layer(sequence_input)
x = Conv1D(128, 5, activation='relu', weights=weights[2])(embedded_sequences)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu', weights=weights[4])(x)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu', weights=weights[6])(x)
x = MaxPooling1D(35)(x)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
preds = Dense(len(labels_index), activation='softmax')(x)
model = Model(sequence_input, preds)
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['acc'])

あとは、普通にfitすればFine-tuningできます。


model.fit(x_train, y_train, validation_data=(x_val, y_val),
          epochs=2, batch_size=128)

Multi inputでvalidationデータをセットする

あと、ネットワークに複数のデータを入れたときにvalidationデータをどうやって渡すのか、Kerasのreferenceを読んでもあまりちゃんと例が載ってなかったので試行錯誤してみました。その結果、以下でいけました。モデルは上と同じくKeras公式ブログのCNNと、もう1つFull Connectの別の入力をマージするというものです。

# CNNモデルはさっきと同じ
from keras.layers import concatenate

sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32' , name='sequence_input_ft')
embedded_sequences_ft = embedding_layer(sequence_input)
x = Conv1D(128, 5, activation='relu', weights=weights[2])(embedded_sequences)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu', weights=weights[4])(x)
x = MaxPooling1D(5)(x)
x = Conv1D(128, 5, activation='relu', weights=weights[6])(x)
x = MaxPooling1D(35)(x)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)

#もう1つ別のネットワークを入力にする。こっちはMLP。
mlp_input = Input(shape=(100,), dtype="float32", name='mlp_input')
x_mlp = Dense(32, activation='relu')(mlp_input)
x_mlp = Dense(32, activation='relu')(x_mlp)

# ここで2つのネットワークをマージ
x_m = concatenate([x, x_mlp], axis=-1)
#ここでは最終的に2値分類する
preds = Dense(2, activation='softmax')(x_m)

sgd = optimizers.SGD()
model = Model(inputs=[sequence_input, mlp_input], outputs=preds)
model.compile(loss='binary_crossentropy',
              optimizer=sgd,
              metrics=['acc'])

ここまでできたらfitで学習するのですが、そのときにinputと同じようにvalidation_data=({'sequence_input': x_val, 'mlp_input': mlp_val}, y_val)と設定します。全体としては以下のような感じです。


history = model.fit({'sequence_input': x_train, 'mlp_input': mlp_train}, y_train,
                                batch_size=32,
                                epochs=10, 
                                validation_data=({'sequence_input': x_val, 'mlp_input': mlp_val}, y_val),
                                callbacks=callbacks)

Kerasは便利ですね!

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3