
More than 5 years have passed since last update.

Extracting the intermediate (encoder) layer output from a trained Autoencoder model


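Give a `name` to the layer whose output you want to use later. After the model has been trained, saved, and reloaded, you can look that layer up by name with `get_layer()` and build a sub-model that ends there.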
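Naming the layer up front is the easy path. If you already have a trained model whose bottleneck layer was left unnamed, a rough fallback is to inspect `model.layers` and pick the layer out yourself. The sketch below assumes a recent `keras` package and a small hypothetical model with `channels_last` 16x16x3 inputs; it is left untrained only to keep the example short, and the steps are the same for a loaded trained model. Keras assigns default names such as `max_pooling2d`, but the exact strings depend on the version and on how many layers were created in the session, so this sketch selects the layer by type instead.

```python
import numpy as np
import keras
from keras import layers

# Hypothetical small autoencoder in which the bottleneck layer was NOT named
inp = keras.Input(shape=(16, 16, 3))
x = layers.Conv2D(8, (3, 3), activation="relu", padding="same")(inp)
x = layers.MaxPooling2D((2, 2), padding="same")(x)  # 16x16 -> 8x8, the code we want
x = layers.UpSampling2D((2, 2))(x)                   # 8x8 -> 16x16
out = layers.Conv2D(3, (3, 3), activation="sigmoid", padding="same")(x)
model = keras.Model(inputs=inp, outputs=out)

# Inspect the auto-generated layer names (exact strings vary by Keras version/session)
for i, layer in enumerate(model.layers):
    print(i, layer.name, type(layer).__name__)

# Select the bottleneck by layer type rather than by a hard-coded auto-generated name
bottleneck = next(lyr for lyr in model.layers if isinstance(lyr, layers.MaxPooling2D))
encoder = keras.Model(inputs=model.input, outputs=bottleneck.output)

codes = encoder.predict(np.zeros((2, 16, 16, 3), dtype="float32"), verbose=0)
print(codes.shape)
```

Selecting by type only works when there is a single layer of that type; with several pooling layers, index into `model.layers` after checking the printed list.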

Sample code

from keras.utils import plot_model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model,load_model

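# Variable definitions
x_train = ...  # (omitted) training images, shape (N, 3, 256, 256)
x_test = ...  # (omitted) test images, same shape
epoch = 50
batch_size = 100
model = define_model()  # define_model() is listed below; define it before this line runs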

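# Train the model (an autoencoder reconstructs its input, so y = x)
model.fit(x=x_train,
          y=x_train,
          epochs=epoch,
          batch_size=batch_size,
          shuffle=True
)

# Save the model
model.save('aaa.h5')

# Load the model
model = load_model('aaa.h5')

# Sub-model that outputs the intermediate (encoder) layer
encoder = Model(inputs=model.input, outputs=model.get_layer('encoder_layer').output)

# Evaluation: encode the test data
result = encoder.predict(x_test)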

# Definition of the CNN autoencoder
def define_model():
    input_img = Input(shape=(3, 256, 256))

    # 3x256x256 -> 8x128x128
    encoded = Conv2D(8, (3, 3), activation='relu', padding='same', data_format='channels_first')(input_img)
    encoded = MaxPooling2D(pool_size=(2, 2), padding='same', data_format='channels_first')(encoded)

    # 8x128x128 -> 8x64x64
    # (omitted)

    # 8x64x64 -> 8x32x32
    # (omitted)

    # 8x32x32 -> 8x16x16
    encoded = Conv2D(8, (3, 3), activation='relu', padding='same', data_format='channels_first')(encoded)
    # ★ We want to output this layer later, so give it a name ★
    encoded = MaxPooling2D(pool_size=(2, 2), padding='same', data_format='channels_first', name='encoder_layer')(encoded)

    # 8x16x16 -> 8x32x32
    # (omitted: the decoder starts from `encoded` and stores its result in `decoded`)

    # 8x32x32 -> 8x64x64
    # (omitted)

    # 8x64x64 -> 8x128x128
    # (omitted)

    # 8x128x128 -> 8x256x256
    decoded = Conv2D(8, (3, 3), activation='relu', padding='same', data_format='channels_first')(decoded)
    decoded = UpSampling2D(size=(2, 2), data_format='channels_first')(decoded)

    # 8x256x256 -> 3x256x256 (3 filters, so the output matches the 3-channel input)
    decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same', data_format='channels_first')(decoded)

    # Define and compile the model
    model = Model(inputs=input_img, outputs=decoded)
    model.compile(optimizer='adam', loss='binary_crossentropy')

    return model
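To check the whole workflow end to end, here is a scaled-down sketch. It is not the author's model: the 16x16x3 input size, the `channels_last` layout, the random data, the single training epoch, and the `.keras` file format are all assumptions chosen so that it runs in seconds on a CPU (on CPU, TensorFlow's pooling ops may reject `channels_first`). The `name='encoder_layer'` and the `Model(inputs=..., outputs=get_layer(...).output)` step are the same as in the sample code above.

```python
import os
import tempfile

import numpy as np
import keras
from keras import layers


def define_small_model():
    # Hypothetical miniature of define_model(): 16x16x3 input, channels_last
    inp = keras.Input(shape=(16, 16, 3))
    x = layers.Conv2D(8, (3, 3), activation="relu", padding="same")(inp)
    x = layers.MaxPooling2D((2, 2), padding="same")(x)                         # 16x16 -> 8x8
    x = layers.Conv2D(8, (3, 3), activation="relu", padding="same")(x)
    x = layers.MaxPooling2D((2, 2), padding="same", name="encoder_layer")(x)   # 8x8 -> 4x4
    x = layers.Conv2D(8, (3, 3), activation="relu", padding="same")(x)
    x = layers.UpSampling2D((2, 2))(x)                                          # 4x4 -> 8x8
    x = layers.Conv2D(8, (3, 3), activation="relu", padding="same")(x)
    x = layers.UpSampling2D((2, 2))(x)                                          # 8x8 -> 16x16
    out = layers.Conv2D(3, (3, 3), activation="sigmoid", padding="same")(x)    # back to 3 channels
    model = keras.Model(inputs=inp, outputs=out)
    model.compile(optimizer="adam", loss="binary_crossentropy")
    return model


# Random stand-in data in [0, 1), since the real images are omitted in the article
rng = np.random.default_rng(0)
x_train = rng.random((32, 16, 16, 3)).astype("float32")
x_test = rng.random((4, 16, 16, 3)).astype("float32")

# Train, save, and reload, mirroring the article's flow
model = define_small_model()
model.fit(x_train, x_train, epochs=1, batch_size=8, shuffle=True, verbose=0)
path = os.path.join(tempfile.mkdtemp(), "autoencoder.keras")
model.save(path)
loaded = keras.models.load_model(path)

# Build the encoder sub-model from the reloaded model's named layer
encoder = keras.Model(inputs=loaded.input, outputs=loaded.get_layer("encoder_layer").output)
codes = encoder.predict(x_test, verbose=0)
print(codes.shape)  # (4, 4, 4, 8)
```

Because the encoder sub-model is built from the reloaded model's own layers, it uses the trained weights directly; no retraining or weight copying is needed.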

References:
https://keras.io/ja/getting-started/faq/#keras-model
http://nobunaga.hatenablog.jp/entry/2017/10/18/000827
