LoginSignup
376
349

More than 5 years have passed since last update.

Kerasでちょっと難しいModelやTrainingを実装するときのTips

Last updated at Posted at 2017-04-14

はじめに

※ Keras2 を対象にしています。

Kerasのコードはシンプルでモジュール性が高いのでシンプルに記述可能で、理解しやすく使いやすいです。
ただし、標準で用意されている以外のLayerや学習をさせようとすると、あまりサンプルがなくてどう書いていいかわからなくなることが多いです。

最近いくつか変わったModelを書いた時に学んだTipsを備忘録も兼ねて共有します。

目次

  • Functional APIを使おう
  • Weightを共有したい場合は Container を使うと便利
  • 「LayerのOutput」と「生のTensor」は似て非なるもの
  • Lambdaを使った簡易変換は便利
  • カスタムなLoss FunctionはSample別にLossを返す
  • LayerじゃないところからLoss関数に式を追加したい場合
  • 学習時にパラメータを更新しつつLossに反映した場合

Tips

Functional APIを使おう

Kerasには2通りのModelの書き方があります。
Sequential ModelFunctional API Model です。

Sequential Modelは

model = Sequential()
model.add(Dense(32, input_dim=784))
model.add(Activation('relu'))

という書き方で、最初これを見て「Kerasすげーわかりやすい!」と思った方も多いのではないでしょうか。

それとは別に

inputs = Input(shape=(784,))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)
model = Model(input=inputs, output=predictions)

という書き方があります。これは LayerInstance(InputTensor) -> OutputTensor というリズムで書いていく書き方です。

Dense(64, activation='relu')(x) というのが Pythonっぽい言語に慣れていない人には違和感があるかもしれませんが、
Dense(64, activation='relu') の部分で Dense Class の Instance を作って、それにたいして DenseInstance(x) しているだけです。

dense = Dense(64, activation='relu')
x = dense(x)

と意味は同じです。

流れとしては、 入力となるLayer 出力となるLayerを決めて、それを Model Classに渡してあげるということです。
入力となるLayerが実データなら Input classを使って指定しておきます(Placeholderのようなもの)。

ここで意識しておきたいことは、 LayerInstance毎にWeightを保持している ということです。
つまり 同じLayerInstanceを使うとWeightを共有している ということになります。
意図して共有する場合はもちろん、意図しない共有にも気をつけましょう。

この書き方なら、同じOutputTensorを別のLayerに入力することが簡単にできます。
記述量は大して変わらないですし、だんだん慣れてきたらこちらのFunctional API で書く練習をして、今後の難しいModelへ備えておくことをおすすめします。

複数LayerのWeightを共有したい場合は Container を使うと便利

異なる入力Layer と 異なる出力Layer を持つが、中身のNetworkとWeightは共有したい、という場合があります。
その場合は、 Container クラスで まとめておくと 取り回しがよくなります。
ContainerLayer のサブクラスなので、Layerと同様に 同じContainerInstanceを使うとWeightを共有している ことになります。

例えば、

inputs = Input(shape=(784,))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)
shared_layers = Container(inputs, predictions, name="shared_layers")

というような shared_layers はあたかも一つのLayerのように扱えます。

Container 自体は基本的には独自のWeightを持たず、あくまで他のLayerを束ねる役割を果たします。

逆にWeightを共有したくない場合は、Containerを共有せずに、個別にLayerInstanceを連ねるようにするしないといけません。

「LayerのOutput」と「生のTensor」は似て非なるもの

独自の計算やTensor変換を書いているとしばしば、

TypeError: ('Not a Keras tensor:', Elemwise{add,no_inplace}.0)

というエラーを目にします。

これは大抵、LayerInstance の 入力に「他のLayerのOutput」ではなく「生のTensor」を入れてしまう場合に起こります。
例えば、

from keras import backend as K

inputs = Input((10, ))
x = K.relu(inputs * 2 + 1)
x = Dense(64, activation='relu')(x)

などとするとそういうことが起こります。
よくはわかりませんが、LayerのOutputは KerasTensorという内部的にShapeを持ったObjectで、K.hogehoge などの計算結果とは異なるもののようです。

そのような場合は、下記のLambdaを使ってあげると上手くいきます(強引に _keras_shape を埋めようとしないほうが無難です ^^;)。

Lambdaを使った簡易変換は便利

例えば、 10要素のVectorを前半の5個、後半の5個に分けたいとします。
前述の通り

inputs = Input((10, ))
x0_4 = inputs[:5]
x5_9 = inputs[5:]
d1 = Dense(10)(x0_4)
d2 = Dense(10)(x5_9)

などとするとエラーになります。

そこで

inputs = Input((10, ))
x0_4 = Lambda(lambda x: x[:, :5], output_shape=(5, ))(inputs)
x5_9 = Lambda(lambda x: x[:, 5:], output_shape=lambda input_shape: (None, int(input_shape[1]/2), ))(inputs)
d1 = Dense(10)(x0_4)
d2 = Dense(10)(x5_9)

というように Lambda classでWrapしてあげるとうまくいきます。
ここでは少しポイントがあります。

Lambdaの内部ではSample次元を含めたTensor計算式を書く必要がある

Kerasでは一貫して最初の次元がSample次元(batch_sizeの次元)になっています。
LambdaなどのLayerを実装するときには、内部ではそのSample次元を含めた計算式を書きます。
なので lambda x: x[:5] ではなく lambda x: x[:, :5] と書く必要があります。

入力のShapeと出力のShapeが異なるときはoutput_shapeを指定する

output_shape は入出力のShapeが同じ場合は省略できますが、異なる場合は必ず指定します。
output_shapeの引数はTupleやFunctionが指定可能ですが、TupleのときはSample次元を含めないFunctionのときはSample次元を含める ようにします。
Functionのときは、基本的にSample次元はNoneにしておけばOKです。
また Functionで指定する場合の引数であるinput_shape ですが、これにはSample次元が含まれているので注意が必要です。

  • OK: Lambda(lambda x: x[:, :5], output_shape=(5, ))(inputs)
  • NG: Lambda(lambda x: x[:, :5], output_shape=(None, 5))(inputs) # inputsが1次元のときは実はOKだけど、2次元になるとNGなので、含めないとおぼえておくと良いです。
  • NG: Lambda(lambda x: x[:, 5:], output_shape=lambda input_shape: (int(input_shape[1]/2), ))(inputs)
  • OK: Lambda(lambda x: x[:, 5:], output_shape=lambda input_shape: (None, int(input_shape[1]/2)))(inputs)

カスタムなLoss FunctionはSample別にLossを返す

ModelのcompileメソッドでLoss関数を指定することができ、自分でカスタムしたLoss関数も指定可能です。
Functionの形状ですが、 y_true, y_pred の2つを引数にとって、Sample数の数だけ数値を返すようにします。
例えば以下のようになります。

def generator_loss(y_true, y_pred):  # y_true's shape=(batch_size, row, col, ch)
    return K.mean(K.abs(y_pred - y_true), axis=[1, 2, 3])

[追記:20170802]

このKerasでLSGAN書くで、

ウェイトやマスクをかけるために用意されている関数がそうなっているだけで、逆にそれらを使わないなら今回のようにサンプルをまたいで計算しちゃっても良いんじゃないか

という指摘があり、確かにその通りだと思います。
従って、他で使う予定がなく、sample_weight などを使う必要がなければ、1つのLoss値を返しても問題ないでしょう。

LayerじゃないところからLoss関数に式を追加したい場合

Layerから渡したい場合は、素直にLayer#add_loss を呼び出せば良いですが、Layerではないところから渡すのは少しむずかしいです(というか正しいやり方がわかりません)。

Loss Function以外のLossの計算式は、Model Instanceのcompileが実行されるタイミングで、Model#losses により各Layerから収集されます(regularizerなどから)。
つまり、なんとかしてここに渡せば良いわけです。
例えば、 ContainerModel を継承して、 #losses を Overriedeするとかするとなんとか渡せます。
VATModel を作ったときはこの方法で渡しました。

学習時にパラメータを更新しつつLossに反映した場合

学習時のLoss計算を以前のLoss結果を反映した値にしたい場合があります。
例えば、下記のようなBEGANの計算の DiscriminatorLoss のような場合です。
https://github.com/mokemokechicken/keras_BEGAN/blob/master/src/began/training.py#L104

学習時のパラメータ更新は、
Model Instanceのcompileが実行されるタイミングで、Model#updatesによって渡されるパラメータ情報が使われます。
Loss FunctionからそのModel#updatesにデータを渡す手段が通常では存在しないので(たぶん)、少し工夫します。

  • Loss Functionではなく Loss Class(Instance) にして、パラメータを保持させる
    • __name__ という属性が必須のようなので注意する
  • パラメータは K.variable で保持しておく
  • __call__()の部分に通常のLoss Functionを定義しておく
  • __call__()の中で、K.update() を使って「パラメータ更新用Object」を生成しておき、それを 配列self.updates に追加する
  • その updates を Modelが拾いに行くようにする

と考えて、以下のようにすると一応可能です。

class DiscriminatorLoss:
    __name__ = 'discriminator_loss'

    def __init__(self, lambda_k=0.001, gamma=0.5):
        self.lambda_k = lambda_k
        self.gamma = gamma
        self.k_var = K.variable(0, dtype=K.floatx(), name="discriminator_k")
        self.m_global_var = K.variable(0, dtype=K.floatx(), name="m_global")
        self.loss_real_x_var = K.variable(0, name="loss_real_x")  # for observation
        self.loss_gen_x_var = K.variable(0, name="loss_gen_x")    # for observation
        self.updates = []

    def __call__(self, y_true, y_pred):  # y_true, y_pred shape: (BS, row, col, ch * 2)
        data_true, generator_true = y_true[:, :, :, 0:3], y_true[:, :, :, 3:6]
        data_pred, generator_pred = y_pred[:, :, :, 0:3], y_pred[:, :, :, 3:6]
        loss_data = K.mean(K.abs(data_true - data_pred), axis=[1, 2, 3])
        loss_generator = K.mean(K.abs(generator_true - generator_pred), axis=[1, 2, 3])
        ret = loss_data - self.k_var * loss_generator

        # for updating values in each epoch, use `updates` mechanism
        # DiscriminatorModel collects Loss Function's updates attributes
        mean_loss_data = K.mean(loss_data)
        mean_loss_gen = K.mean(loss_generator)

        # update K
        new_k = self.k_var + self.lambda_k * (self.gamma * mean_loss_data - mean_loss_gen)
        new_k = K.clip(new_k, 0, 1)
        self.updates.append(K.update(self.k_var, new_k))

        # calculate M-Global
        m_global = mean_loss_data + K.abs(self.gamma * mean_loss_data - mean_loss_gen)
        self.updates.append(K.update(self.m_global_var, m_global))

        # let loss_real_x mean_loss_data
        self.updates.append(K.update(self.loss_real_x_var, mean_loss_data))

        # let loss_gen_x mean_loss_gen
        self.updates.append(K.update(self.loss_gen_x_var, mean_loss_gen))

        return ret


class DiscriminatorModel(Model):
    """Model which collects updates from loss_func.updates"""

    @property
    def updates(self):
        updates = super().updates
        if hasattr(self, 'loss_functions'):
            for loss_func in self.loss_functions:
                if hasattr(loss_func, 'updates'):
                    updates += loss_func.updates
        return updates


discriminator = DiscriminatorModel(all_input, all_output, name="discriminator")
discriminator.compile(optimizer=Adam(), loss=DiscriminatorLoss())

さいごに

たぶん、最後の2つはしばらくしたら不要になるようなTipsかもしれないですね。
Kerasはソースコードも追いやすいので、意外となんとかなります。
あと、 Keras2になってエラーメッセージがよりわかりやすくなったので、Debugするときには助かっています。

376
349
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
376
349