
Quickly writing NNs with branches and merges in Keras

Posted at 2017-05-07

#Environment
Python 3.5.2
Keras 2.0.4
tensorflow-gpu 1.1.0

#I want to write branches easily, too
With the Sequential model,

from keras.models import Sequential
from keras.layers import InputLayer, Dense, Activation
model = Sequential()
model.add(InputLayer(input_shape=(32,)))
model.add(Dense(32, activation='relu'))
model.add(Dense(16, activation='relu'))
#...

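you can easily write models without branches!
On the other hand, with the Functional API, which is better suited to describing complex structures,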

from keras.layers import Input, Dense
_input = Input((32,))
x = Dense(32, activation='relu')(_input)
x = Dense(16, activation='relu')(x)
#...

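the `__init__` parentheses sit right next to the `__call__` parentheses (`Dense(32, ...)(x)`), each line reads from right to left, and every branch needs its own temporary variable...
I want to write it as quickly as with the Sequential model.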

So, here is what I came up with.

#Function definition

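def build(_in, *nodes):
    """Apply nodes to _in one after another, Sequential-style."""
    x = _in
    for node in nodes:
        if callable(node):
            # A layer (or any function, e.g. lambda x: x): apply it to x
            x = node(x)
        elif isinstance(node, list):
            # A list is a branch: feed x to every element, collect the outputs in a list
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):
            # A tuple is a Sequential-style stack: apply its elements in order
            x = build(x, *node)
        else:
            # Anything else (e.g. an existing Input tensor) replaces x as-is
            x = node
    return x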

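Express branches as lists and Sequential-style stacks as tuples,
and pass them to this function. Any other non-callable node, such as an existing tensor, simply replaces the current value.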
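To make the list/tuple convention concrete without Keras, here is a minimal sketch that runs `build` (copied unchanged) on plain numbers. The helpers `double`, `inc` and `add_all` are hypothetical stand-ins for layers, with `add_all` playing the role of a merge layer such as `Add()` that takes the list a branch produces.

```python
def build(_in, *nodes):
    # Same helper as above, repeated so this block runs on its own.
    x = _in
    for node in nodes:
        if callable(node):
            x = node(x)
        elif isinstance(node, list):
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):
            x = build(x, *node)
        else:
            x = node
    return x


# Hypothetical stand-ins for single-input layers.
def double(v):
    return 2 * v


def inc(v):
    return v + 1


# Hypothetical stand-in for a merge layer such as Add(): takes a list.
def add_all(vs):
    return sum(vs)


# Branch into double and (inc -> inc), then merge: [6, 5] -> 11
print(build(3, [double, (inc, inc)], add_all))
# An identity branch keeps the input next to the other branch: [3, 6]
print(build(3, [lambda v: v, double]))
# A non-callable node replaces the current value: [10, 6]
print(build(3, [10, double]))
```

Because `callable` is checked first, layers and lambdas are applied, while a non-callable value is substituted as-is; the two-input example relies on this to inject `auxiliary_input`.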

#Examples
More to come...
##Branch → Add → Branch → Multiply
(Figure: model diagram, example_1.png)

from keras.models import Model
from keras.layers import Input, Dense, Add, Multiply
def example_1():
    _input = Input((10,))
    _output = build(
        _input,
        Dense(10),
        [Dense(11, activation='relu'), Dense(11, activation='relu')],
        Add(),
        [Dense(12, activation='relu'), Dense(12, activation='relu')],
        Multiply(),
        )
    model = Model(_input, _output)
    return model
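To see what `example_1` wires together without running Keras, here is a sketch that swaps each layer for a hypothetical `Sym` object (not part of Keras) which only records the names of the inputs it was called with. Strings stand in for tensors, and names such as `dense_11a` are made up for illustration.

```python
def build(_in, *nodes):
    # Same helper as in the article, repeated so this block runs on its own.
    x = _in
    for node in nodes:
        if callable(node):
            x = node(x)
        elif isinstance(node, list):
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):
            x = build(x, *node)
        else:
            x = node
    return x


calls = []  # (layer name, input names), in call order


class Sym:
    """Hypothetical stand-in for a Keras layer that only records its inputs."""

    def __init__(self, name):
        self.name = name

    def __call__(self, x):
        inputs = x if isinstance(x, list) else [x]
        calls.append((self.name, list(inputs)))
        return self.name  # the "output tensor" is just the layer's name


out = build(
    "input",
    Sym("dense_10"),
    [Sym("dense_11a"), Sym("dense_11b")],
    Sym("add"),
    [Sym("dense_12a"), Sym("dense_12b")],
    Sym("multiply"),
)
for name, inputs in calls:
    print(name, "<-", inputs)
```

Each recorded pair corresponds to one `layer(inputs)` call you would otherwise write by hand in the Functional API, with a temporary variable for each branch output.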

##Two inputs, two outputs
Let's write the model from the Keras Functional API guide:
https://keras.io/getting-started/functional-api-guide/#multi-input-and-multi-output-models
(Figure: model diagram, example_multi_input_and_multi_output.png)

from keras.models import Model
from keras.layers import Input, Dense, Embedding, LSTM, Concatenate
def example_multi_input_and_multi_output():
    main_input = Input(shape=(100,), dtype='int32', name='main_input')
    auxiliary_input = Input(shape=(5,), name='aux_input')
    outputs = build(
        main_input,
        Embedding(output_dim=512, input_dim=10000, input_length=100),
        LSTM(32),
        [Dense(1, activation='sigmoid', name='aux_output'),
         ([auxiliary_input, lambda x: x],
          Concatenate(),
          Dense(64, activation='relu'),
          Dense(64, activation='relu'),
          Dense(64, activation='relu'),
          Dense(1, activation='sigmoid', name='main_output')
         )
        ]
    )
    model = Model([main_input, auxiliary_input], outputs)
    return model
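The nested list and tuple here are the trickiest part, so here is the same structure traced with a hypothetical `Sym` recorder (not Keras; strings stand in for the two `Input` tensors, and the `dense_*` names are made up). It shows that the non-callable `"aux_input"` replaces the current value while `lambda x: x` passes the LSTM output through, so the concatenation receives both. It also shows that `build` returns the outputs in list order, `[aux_output, main_output]`, which is the reverse of the Keras guide's `outputs=[main_output, auxiliary_output]`; this matters when passing targets or loss weights as lists.

```python
def build(_in, *nodes):
    # Same helper as in the article, repeated so this block runs on its own.
    x = _in
    for node in nodes:
        if callable(node):
            x = node(x)
        elif isinstance(node, list):
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):
            x = build(x, *node)
        else:
            x = node
    return x


calls = []  # (layer name, input names), in call order


class Sym:
    """Hypothetical stand-in for a Keras layer that only records its inputs."""

    def __init__(self, name):
        self.name = name

    def __call__(self, x):
        inputs = x if isinstance(x, list) else [x]
        calls.append((self.name, list(inputs)))
        return self.name


outputs = build(
    "main_input",
    Sym("embedding"),
    Sym("lstm"),
    [Sym("aux_output"),
     (["aux_input", lambda x: x],  # string = non-callable, like a tensor
      Sym("concatenate"),
      Sym("dense_1"),
      Sym("dense_2"),
      Sym("dense_3"),
      Sym("main_output"))],
)
print(outputs)
for name, inputs in calls:
    print(name, "<-", inputs)
```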

##ResNet
(Figure: model diagram, example_residual_connection.png)

from keras.models import Model
from keras.layers import Input, Conv2D, Activation, Add
def example_residual_connection():
    _input = Input(shape=(256, 256, 3))
    _output = build(
        _input,
        [(Conv2D(3, (3, 3), padding='same'),
          Activation('relu'),
          Conv2D(3, (3, 3), padding='same')),
         lambda x: x],  # identity shortcut
        Add(),
        Activation('relu')
    )
    model = Model(_input, _output)
    return model
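In the Keras version, `Add()` works here because `Conv2D(3, (3, 3), padding='same')` keeps the `(256, 256, 3)` shape of the input, so both branches match. The sketch below traces the same structure with a hypothetical `Sym` recorder (not Keras; strings instead of tensors, made-up layer names) to show that `Add` receives the conv stack's output together with the untouched input.

```python
def build(_in, *nodes):
    # Same helper as in the article, repeated so this block runs on its own.
    x = _in
    for node in nodes:
        if callable(node):
            x = node(x)
        elif isinstance(node, list):
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):
            x = build(x, *node)
        else:
            x = node
    return x


calls = []  # (layer name, input names), in call order


class Sym:
    """Hypothetical stand-in for a Keras layer that only records its inputs."""

    def __init__(self, name):
        self.name = name

    def __call__(self, x):
        inputs = x if isinstance(x, list) else [x]
        calls.append((self.name, list(inputs)))
        return self.name


out = build(
    "input",
    [(Sym("conv2d_1"), Sym("relu_1"), Sym("conv2d_2")),
     lambda x: x],  # identity shortcut
    Sym("add"),
    Sym("relu_2"),
)
for name, inputs in calls:
    print(name, "<-", inputs)
```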