Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
1
Help us understand the problem. What is going on with this article?

More than 3 years have passed since last update.

@cvusk

Kerasで分岐を書いてみる その2

Kerasで分岐2

前回(※)Keras.backendのswitchを紹介しましたが、もうちょい触ってみました。
以下記事に書いたコードにちょっとバリエーションを追加したものになりますので、まずは以下をご参照ください。
https://qiita.com/cvusk/items/6955ae3bc802c28c3b66

※ アドベントカレンダー的には日程が前後してしまいましたが・・・。

なお、コード全文は以下にあります。
https://github.com/shibuiwilliam/kerasswitch

switchの中に直接レイヤ定義を書く

switch文の中にレイヤ定義を直接書くことも可能です。

ここで書いているモデルは、AndゲートとOrゲートを同居させたモデルで、switchでAndとOrを判定しています。
入力の形式は(8,3)です。
入力値の第3項は、Andの場合は0、Orの場合は1に設定し、switchはこの第3項を見て、第1,2項のみを次のレイヤの入力にします。


# and data and or data
X = np.array([[0,0,0],[0,1,0],[1,0,0],[1,1,0], #and
              [0,0,1],[0,1,1],[1,0,1],[1,1,1]]) #or
Y = np.array([0,0,0,1, #and
              0,1,1,1]) #or
Y = keras.utils.to_categorical(Y, 2)

# input layer
inputs = Input(shape=X.shape[1:],name="input")

# switch cases for and and or
x_switch = Lambda(lambda x: K.switch(x[:,2]<1,
                                     Dense(32, activation="relu", name="and1")(x[:,:2]), # !
                                     Dense(32, activation="relu", name="or2") # !!
                                     (Dense(16, activation="relu", name="or1")(x[:,:2]))), # !!
                  output_shape=(32,))(inputs)

# output layer
outputs = Dense(2, activation="softmax", name="softmax")(x_switch)

model = Model(inputs, outputs)

model.compile(loss='categorical_crossentropy',
              optimizer=Adam(),
              metrics=['accuracy'])
model.fit(X, Y,
          batch_size=4,
          epochs=20000,
          shuffle=True)

この場合の注意点は、expression部分に複数レイヤを定義する場合、冗長かつ読みづらくなることです。
1レイヤ(# !部分)だけであれば良いのですが、2レイヤ以上(# !!部分)だと、レイヤを変数で取ることができないので、Lambdaの出力から先に書いていくことになります。
つまり、Lambdaの出力に対して入力されるレイヤを()に入れていく方式になり、通常のレイヤ記述の順番と逆になります。


# !!部分の通常の記述
a = Dense(16, activation="relu", name="or1")(x[:,:2])
a = Dense(32, activation="relu", name="or2")(a)

複数の入力を取る

上記では入力の第3項にAndかOrの判定フラグを入れていましたが、このフラブを別の入力テンソルにして、複数入力をとるモデルにすることも可能です。


# and data and or data
X2 = np.array([[0,0],[0,1],[1,0],[1,1], # and
              [0,0],[0,1],[1,0],[1,1]]) # or
X2_flag = np.array([[0],[0],[0],[0], # and
                   [1],[1],[1],[1]]) # or
Y2 = np.array([0,0,0,1, #and
               0,1,1,1]) #or
Y2 = keras.utils.to_categorical(Y2, 2)

# input layer
inputs = Input(shape=X2.shape[1:], name="input")
flags = Input(shape=X2_flag.shape[1:], name="flags")

indense = Dense(4, activation="relu", name="dense1")(inputs)

# and layers and or layers
def andGate(inputs):
    andDense = Dense(16, activation="relu", name="and1")(inputs)
    andDense = Dense(32, activation="relu", name="and2")(andDense) 
    return andDense
def orGate(inputs):
    orDense = Dense(32, activation="sigmoid", name="or1")(inputs)     
    return orDense

# switch cases for and and or
x_switch = Lambda(lambda x: K.switch(x[1][:,0]<1,
                                     andGate(x[0]),
                                     orGate(x[0])),
                  output_shape=(32,))([indense, flags])

# output layer
outputs = Dense(2, activation="softmax", name="softmax")(x_switch)

model = Model([inputs, flags], outputs)

model.compile(loss='categorical_crossentropy',
              optimizer=Adam(),
              metrics=['accuracy'])

model.fit([X2, X2_flag], Y2,
          batch_size=4,
          epochs=20000,
          shuffle=True)

入力はinputsとflagsにして、inputsをdense1を介したあとにswitchするという方式です。
switchではflagsをconditionとしつつ、andGateとorGateにはdense1のみが入力されていく、というモデルになっています。

1
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
1
Help us understand the problem. What is going on with this article?