9
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

KerasのfunctionalAPIの使い方

Last updated at Posted at 2018-10-31

#KerasのfunctionalAPIの使い方
##はじめに
自分が初心者なので,初心者向けです。
おかしいところや間違っているところなどがあれば,指摘していただけたら嬉しいです……!
こちらの記事の方が参考になるかもしれません。
Kerasはfunctional APIもきちんと理解しよう

またはKerasのチュートリアルも参考になります。
https://keras.io/ja/getting-started/functional-api-guide/

##functionalAPIを使うことの利点
NN.png

こんなネットワークが簡単に作れます。

inputs_1 = Input(shape=(20,), dtype='float32')
inputs_2 = Input(shape=(20,), dtype='float32')
inputs_3 = Input(shape=(20,), dtype='float32')

hidden1_1 = Dense(10, activation='relu')(inputs_1)
hidden1_2 = Dense(10, activation='relu')(inputs_2)
hidden1_3 = Dense(10, activation='relu')(inputs_3)

hidden1_m = Multiply()([hidden1_1, hidden1_2, hidden1_3])

hidden2 = Dense(10, activation='relu')(hidden1_m)

hidden2 = Dense(20, activation='relu')(hidden1_m)

predictions = Dense(7, activation='softmax')(hidden2)

model = Model(inputs=inputs, outputs=predictions)

model.compile(optimizer=RMSprop(), loss='categorical_crossentropy', metrics=['accuracy'])

model.fit([x_train[0], x_train[1], x_train[2]], t_train, epochs=10000, verbose=1, batch_size=100)

こんな感じで実装できます。
もう少し詳しく見ていきます。

##複数の入力層を作る

inputs_1 = Input(shape=(20,), dtype='float32')
inputs_2 = Input(shape=(20,), dtype='float32')
inputs_3 = Input(shape=(20,), dtype='float32')

こんな風にして入力層は増やせます。

##どこの層に結合させるか
中間層から出力層はこのようにして,どこの層と結合させるかを指定できます。

hidden1_1 = Dense(10, activation='relu')(inputs_1)

右の(inputs_1)が結合する層です。

##層の合体
層を結合させるには,このように結合させた層をリストで渡します。

hidden1_m = Multiply()([hidden1_1, hidden1_2, hidden1_3])

もちろん,次のように書いてもOKです。

hiiden1 = [hidden1_1, hidden1_2, hidden1_3]
hidden1_m = Multiply()(hidden1)

##複数の入力のモデル作成

model = Model(inputs=[inputs_1, inputs_2, inputs_3], outputs=predictions)

こういう感じで,リストで渡します。
もちろん,こちらもこれでOK。

inputs = [inputs_1, inputs_2, inputs_3]
model = Model(inputs=inputs, outputs=predictions)

##複数の入力の学習
入力する学習データのリストを与えます。

model.fit([x_train[0], x_train[1], x_train[2]], t_train, epochs=10000, verbose=1, batch_size=700)

##結論
こういう分岐するネットワークもfunctionalAPIなら簡単にかけます!

9
4
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?