Writing Keras NNs with branches and merges, quickly

More than 1 year has passed since last update.

Environment

Python 3.5.2
Keras 2.0.4
tensorflow-gpu 1.1.0

I want branches to be easy to write too

With a Sequential model,

from keras.models import Sequential
from keras.layers import InputLayer, Dense, Activation
model = Sequential()
model.add(InputLayer(input_shape=(32,)))
model.add(Dense(32, activation='relu'))
model.add(Dense(16, activation='relu'))
#...

writing a model without branches is easy!
On the other hand, with the Functional API, which is better suited to describing complex structures,

from keras.layers import Input, Dense
input = Input((32,))
x = Dense(32, activation='relu')(input)
x = Dense(16, activation='relu')(x)
#...

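the constructor's parentheses and the call's parentheses end up side by side, data flows right to left within each line,
and every branch needs yet another temporary variable...
I want to write this as quickly as a Sequential model.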

So here's what I did.

Function definition

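def build(input, *nodes):
    x = input
    for node in nodes:
        if callable(node):
            # a layer (or any function): apply it to the current value
            x = node(x)
        elif isinstance(node, list):
            # list = branches: feed the same x to every element, collect the outputs in a list
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):
            # tuple = Sequential-style stack: apply the elements in order
            x = build(x, *node)
        else:
            # anything else (e.g. another Input tensor) replaces the current value
            x = node
    return x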

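Express branches as lists and Sequential-style stacks as tuples, and pass them to this function.
Callables such as layers or lambdas are applied to the current value; any other object (for example, a second Input tensor) replaces the current value.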
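To check how `build` reads the nesting without needing Keras, here is a minimal framework-free sketch: the layers are replaced by plain Python functions (`double` and `inc` are hypothetical stand-ins), and `build` is copied in so the snippet runs on its own. A tuple applies its items in order; a list fans the same input out into several branches and returns a list.

```python
def build(input, *nodes):
    # Same dispatch logic as build() above.
    x = input
    for node in nodes:
        if callable(node):              # layer (or any function): apply it
            x = node(x)
        elif isinstance(node, list):    # list: one branch per element, same input
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):   # tuple: apply the elements in sequence
            x = build(x, *node)
        else:                           # anything else replaces the current value
            x = node
    return x

def double(x):
    return 2 * x

def inc(x):
    return x + 1

# Sequential stacking: inc(double(3)) = 7
seq = build(3, double, inc)
# A tuple is the same stack, usable as a single node
seq_tuple = build(3, (double, inc))
# A list branches: each element receives the same input 3
branches = build(3, [double, inc, (double, inc)])

print(seq, seq_tuple, branches)  # 7 7 [6, 4, 7]
```

Because a tuple inside a list is expanded by the recursive call, each branch can itself be an arbitrarily long stack.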

Examples

(More examples to be added.)

Branch → sum → branch → product

[Figure: example_1.png]

from keras.models import Model
from keras.layers import Input, Dense, Add, Multiply
def example_1():
    input = Input((10,))
    output = build(
        input,
        Dense(10),
        [Dense(11, activation='relu'), Dense(11, activation='relu')],
        Add(),
        [Dense(12, activation='relu'), Dense(12, activation='relu')],
        Multiply(),
        )
    model = Model(input, output)
    return model
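The data flow of example_1 can be traced with plain numbers. After a list node the current value is a list, and the next callable receives that whole list and merges it, which is exactly what `Add()` and `Multiply()` do with their list of inputs. In this sketch `add`, `multiply` and the arithmetic lambdas are hypothetical scalar stand-ins for the Keras layers, and `build` is copied in to keep the snippet self-contained.

```python
import operator
from functools import reduce

def build(input, *nodes):
    # Same dispatch logic as build() above.
    x = input
    for node in nodes:
        if callable(node):
            x = node(x)
        elif isinstance(node, list):
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):
            x = build(x, *node)
        else:
            x = node
    return x

def add(xs):
    # stand-in for Add(): sum of the list of branch outputs
    return sum(xs)

def multiply(xs):
    # stand-in for Multiply(): product of the list of branch outputs
    return reduce(operator.mul, xs, 1)

out = build(
    1,
    lambda x: x + 1,                     # stand-in for Dense(10)   -> 2
    [lambda x: x * 3, lambda x: x * 4],  # two branches             -> [6, 8]
    add,                                 # merge by sum             -> 14
    [lambda x: x - 4, lambda x: x + 1],  # two branches             -> [10, 15]
    multiply,                            # merge by product         -> 150
)
print(out)  # 150
```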

Two inputs, two outputs

Let's write the model from
https://keras.io/getting-started/functional-api-guide/#multi-input-and-multi-output-models
with build.
[Figure: example_multi_input_and_multi_output.png]

from keras.models import Model
from keras.layers import Input, Dense, Embedding, LSTM, Concatenate
def example_multi_input_and_multi_output():
    main_input = Input(shape=(100,), dtype='int32', name='main_input')
    auxiliary_input = Input(shape=(5,), name='aux_input')
    outputs = build(
        main_input,
        Embedding(output_dim=512, input_dim=10000, input_length=100),
        LSTM(32),
        [Dense(1, activation='sigmoid', name='aux_output'),
         ([auxiliary_input, lambda x: x],
          Concatenate(),
          Dense(64, activation='relu'),
          Dense(64, activation='relu'),
          Dense(64, activation='relu'),
          Dense(1, activation='sigmoid', name='main_output')
         )
        ]
    )
    model = Model([main_input, auxiliary_input], outputs)
    return model
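The trick in this example is that `auxiliary_input` is not callable, so `build` simply substitutes it for the current value inside its branch, while `lambda x: x` passes the LSTM output through unchanged. The sketch below reproduces that shape with strings; `"aux"` and the string-building lambdas are hypothetical stand-ins for the tensors and layers, and `concat` stands in for `Concatenate()`. It assumes, as appears to hold for the setup above, that the injected input object is not itself callable; if it were, `build` would try to call it instead.

```python
def build(input, *nodes):
    # Same dispatch logic as build() above.
    x = input
    for node in nodes:
        if callable(node):
            x = node(x)
        elif isinstance(node, list):
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):
            x = build(x, *node)
        else:
            x = node
    return x

def concat(xs):
    # stand-in for Concatenate(): join the branch outputs in list order
    return "+".join(xs)

outputs = build(
    "main",
    lambda x: "lstm(" + x + ")",                 # stand-in for Embedding + LSTM
    [lambda x: "aux_out(" + x + ")",             # first output branch
     (["aux", lambda x: x],                      # "aux" is not callable, so it replaces x
      concat,
      lambda x: "main_out(" + x + ")")],         # second output branch
)
print(outputs)  # ['aux_out(lstm(main))', 'main_out(aux+lstm(main))']
```

The trace also shows the concatenation order: the auxiliary input comes first, followed by the LSTM output. The Keras guide concatenates `[lstm_out, auxiliary_input]`, so swapping the two list elements should reproduce its order if that matters.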

ResNet

[Figure: example_residual_connection.png]

from keras.models import Model
from keras.layers import Input, Conv2D, Activation, Add
def example_residual_connection():
    input = Input(shape=(256, 256, 3))
    output = build(
        input,
        [(Conv2D(3, (3, 3), padding='same'),
          Activation('relu'),
          Conv2D(3, (3, 3), padding='same')),
         lambda x: x],
        Add(),
        Activation('relu')
    )
    model = Model(input, output)
    return model
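Here the list holds two branches: a tuple for the convolutional path and `lambda x: x` for the identity shortcut, which `Add()` then sums. A framework-free sketch with scalars, where the arithmetic lambdas are hypothetical stand-ins for the two `Conv2D` layers and `relu`/`add` stand in for `Activation('relu')` and `Add()`:

```python
def build(input, *nodes):
    # Same dispatch logic as build() above.
    x = input
    for node in nodes:
        if callable(node):
            x = node(x)
        elif isinstance(node, list):
            x = [build(x, branch) for branch in node]
        elif isinstance(node, tuple):
            x = build(x, *node)
        else:
            x = node
    return x

def add(xs):
    # stand-in for Add()
    return sum(xs)

def relu(x):
    # stand-in for Activation('relu') on a scalar
    return max(x, 0)

def residual_block(x):
    return build(
        x,
        [(lambda v: v * 2, relu, lambda v: v - 10),  # main path (stand-ins for Conv2D, relu, Conv2D)
         lambda v: v],                               # identity shortcut
        add,
        relu,
    )

print(residual_block(3))   # relu((6 - 10) + 3) = 0
print(residual_block(8))   # relu((16 - 10) + 8) = 14
# residual_block is itself callable, so blocks stack like layers
print(build(8, residual_block, residual_block))  # 32
```

Since `build` only requires that a node be callable, the same pattern should carry over to Keras: a function that takes a tensor and returns `build(...)` on it can be passed as a node to stack several residual blocks.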