2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Kerasのcallback関数で中間層で出力される特徴表現を保存する

Posted at

Kerasのcallback関数で中間層で出力される特徴表現を保存する。

画像の抽象表現の獲得を必要とする機械学習の場合、評価の良いモデルを作るためには中間層でどのような処理が行われているかが肝になってくる。
ここでは、1epochごとに中間層での処理内容を出力するプログラムについて説明する。

Kerasのcallback関数の使い方

epoch毎中間層の処理を出力するするための関数を作成する
https://github.com/keras-team/keras/blob/master/keras/callbacks.py
ここにkeras標準のcallback関数があるので、これらをまねて記述する。

Callbackを継承するクラスを記述る

from tensorflow.keras import callbacks

class VisualizeFullLayers(callbacks.Callback):
    def __init__(self, arg1, arg2, ....):
        self.arg1 = arg1
        self.arg2 = arg2
		 :
		 :

ここではcallbackの一部の機能を使用して、全体的には以下のような構成にする

from tensorflow.keras import callbacks

class VisualizeFullLayers(callbacks.Callback):
    def __init__(self, arg1, arg2, ....):
        self.arg1 = arg1
        self.arg2 = arg2
		 :
		 :
    
	# エポック完了ごとに行う処理
    def on_epoch_end(self, epoch, logs=None):
        pass

	# 学習終了後に行う処理
    def on_train_end(self, epoch, logs=None):
	    pass

各レイヤーごとに出力される特徴量

以下のコードは指定したepoch毎に行う中間層を出力する処理

 # 中間層保存用辞書
full_layers_dict = {}

# 入力データを除いたレイヤー毎に保存の処理を行う
input_layers = ["input_layer0", "input_layer1", .....]
for layer in self.model.layers:
    if layer.name in input_layers:
        continue
	layer_func = K.function(inputs=[self.model.inputs, K.learning_phase()],
                            outputs=self.model.get_layer(layer.name).output)
    # 適当なテスト画像を入力として加える
    layer_features = layer_func([x_test_dict])
    full_layers_dict[layer.name] = layer_features

pickleのバイナリ形式で保存

with open(save_name +\
    "/full_middle_layer_epoch{0:0=3}_valloss{1:.2e}.pickle".format(
        epoch, logs["val_loss"]), mode='wb') as f:
    pickle.dump(full_layers_dict, f)

全体のコード

from tensorflow.keras import callbacks
import tensorflow.keras.backend as K 
import os
import pickle

class VisualizeFullLayers(callbacks.Callback):
    def __init__(self, x_test, save_dir, period):
        self.x_test = x_test
        self.save_dir = save_dir
        self.period = period
    
    def on_epoch_end(self, epoch, logs=None):
        save_name = self.save_dir + ...
        os.makedirs(save_name, exist_ok=True)
        if epoch%self.period==0:
            full_layers_dict = {}
            input_layers = ["input_layer0", "input_layer1", .....]
            for layer in self.model.layers:
                if layer.name in input_layers:
                    continue
                layer_func = K.function(inputs=[self.model.inputs, K.learning_phase()],
                                        outputs=self.model.get_layer(layer.name).output)
                layer_features = layer_func([x_test])
                full_layers_dict[layer.name] = layer_features

            # ついでに入力に使用した画像データも保存。これにより入力から出力までの流れを追える。
            full_layers_dict.update(x_test)
            with open(save_name +\
                "/full_middle_layer_epoch{0:0=3}_valloss{1:.2e}.pickle".format(
                    epoch, logs["val_loss"]), mode='wb') as f:
                pickle.dump(full_layers_dict, f)
            print("\nsave complete")
2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?