Help us understand the problem. What is going on with this article?

Kerasで分岐を書いてみる その3 Case文を書いてみる

More than 1 year has passed since last update.

Kerasで分岐を書いてみる その3 Case文を書いてみる

Kerasでニューラルネットワークに分岐を組み込むことができます。
たとえば前レイヤーの出力が0か1の場合、0の場合はAレイヤー、1の場合はBレイヤーというようにルールベースでレイヤーを選択させることができます。
Kerasで用意されているKeras.backend.switch()を使うのですが、このswitch()でできる分岐は2分岐です。0か1か、以上か以下か、=か≠かという条件のみが可能です。

2018-03-11_0.png

※ Keras.backend.switchの使い方は以下で投稿しています。
Kerasで分岐を書いてみる
Kerasで分岐を書いてみる その2

しかし2分岐だけでは足りないケースもあると思うので、switch()を使って複数の分岐を書いてみました。

やりたいこと

  • Kerasで任意のレイヤーの出力を条件に、次のレイヤーをルールベースで分岐する。
  • 分岐先の選択肢は2個以上ある。
  • 簡単に書く。
  • 以下のような分岐をできるようにする。

2018-03-11_1.png

方針

そもそもKeras.backend.switch()で2分岐できるので、これを組み合わせれば複数の分岐にできます。
2値分類を組み合わせて多値分類を実装するようなものです。

ここでは例として論理回路をニューラルネットワークで書いてみます。
まずはAND回路とOR回路です。
入力に0 or 1の2値と、3値目にANDフラグ、ORフラグをとります。
入力の3値目は0ならAND回路、1ならOR回路です。
例:[0,1,0]の場合、最初2値の0,1は真理値、3値目は0なのでAND回路
  [1,0,1]の場合、最初2値の1,0は真理値、3値目は1なのでOR回路

import numpy as np

X = np.array([[0,0,0],[0,1,0],[1,0,0],[1,1,0],  #and
              [0,0,1],[0,1,1],[1,0,1],[1,1,1]]) #or
Y = np.array([0,0,0,1,  #and
              0,1,1,1]) #or
Y = keras.utils.to_categorical(Y, 2)

これをKeras.backend.switch()を使って分岐させると以下のようになります。

import keras
from keras.models import Model
from keras.layers import Dense, Input, Lambda
from keras.optimizers import Adam
import keras.backend as K
import numpy as np

X = np.array([[0,0,0],[0,1,0],[1,0,0],[1,1,0],  #and
              [0,0,1],[0,1,1],[1,0,1],[1,1,1]]) #or
Y = np.array([0,0,0,1,  #and
              0,1,1,1]) #or
Y = keras.utils.to_categorical(Y, 2)

inputs = Input(shape=X.shape[1:],name="input")

# and layers and or layers
def andGate(inputs):
    andDense = Dense(16, activation="relu", name="and1")(inputs)
    andDense = Dense(32, activation="relu", name="and2")(andDense) 
    return andDense
def orGate(inputs):
    orDense = Dense(32, activation="sigmoid", name="or1")(inputs)     
    return orDense

# switch cases for "and" gate and "or" gate
x_switch = Lambda(lambda x: K.switch(x[:,2]<1,
                                     andGate(x[:,:2]),
                                     orGate(x[:,:2])),
                  output_shape=(32,))(inputs)

# output layer
outputs = Dense(2, activation="softmax", name="softmax")(x_switch)

LambdaレイヤーでK.switch()をとります。
K.switch()の入力は(条件、条件がTrueのときの処理、条件がFalseのときの処理)になっています。

#        条件   Trueのときの処理   Falseのときの処理
K.switch(x[:,2]<1, andGate(x[:,:2]), orGate(x[:,:2])),

このK.switch()は標準で2分岐しかできませんが、2分岐を重ねていけば2以上の分岐も可能になります。
ここでは「条件がFalseのときの処理」に、さらにK.switch()を加えていくことで、複数分岐を書いていきます。

より多くの分岐を書く

次はAND回路、OR回路に加えて、XOR回路、NOR回路、NAND回路を入出力に追加してみます。
入力の3値目が条件で、以下のような対応になっています。
0: AND
1: OR
2: XOR
3: NOR
4: NAND

# and, or, xor, nor, nand
X = np.array([[0,0,0],[0,1,0],[1,0,0],[1,1,0], #and
              [0,0,1],[0,1,1],[1,0,1],[1,1,1], #or
              [0,0,2],[0,1,2],[1,0,2],[1,1,2], #xor
              [0,0,3],[0,1,3],[1,0,3],[1,1,3], #nor
              [0,0,4],[0,1,4],[1,0,4],[1,1,4]]) #nand
Y = np.array([0,0,0,1, #and
              0,1,1,1, #or
              0,1,1,0, #xor
              1,0,0,0, #nor
              1,1,1,0]) #nand

# and layers and or layers
def andGate(inputs=inputs):
    andDense = Dense(32, activation="relu", name="and1")(inputs) 
    return andDense
def orGate(inputs=inputs):
    orDense = Dense(32, activation="sigmoid", name="or1")(inputs)     
    return orDense
def xorGate(inputs=inputs):
    xorDense = Dense(32, activation="tanh", name="xor1")(inputs)
    return xorDense
def norGate(inputs=inputs):
    norDense = Dense(32, activation="sigmoid", name="nor1")(inputs)
    norDense = Dense(32, activation="relu", name="nor2")(norDense)
    return norDense
def nandGate(inputs=inputs):
    nandDense = Dense(16, activation="relu", name="nand1")(inputs)
    nandDense = Dense(32, activation="relu", name="nand2")(nandDense)
    return nandDense

# stack of K.switches
x_case = Lambda(lambda x: K.switch(K.equal(x[:,2],0),
                                   andGate(x[:,:2]),
                                   K.switch(K.equal(x[:,2],1),
                                            orGate(x[:,:2]),
                                            K.switch(K.equal(x[:,2],2),
                                                     xorGate(x[:,:2]),
                                                     K.switch(K.equal(x[:,2],3),
                                                              norGate(x[:,:2]),
                                                              nandGate(x[:,:2]))))),
                  output_shape=(32,))(inputs)

もうちょいスマートに

しかしこれだと読みにくいし書きにくいので、もっとスマートに書けるようにcase()という関数にしてみました。

def case(case_and_proc, default):
    """
    case_and_proc: pairs of case and process in list type or dict type
    default: default process. i.e. executed whenever the condition doesn't match any case
    """
    print(case_and_proc)
    if type(case_and_proc) == list:
        for cnum in range(len(case_and_proc)):
            if cnum == 0:
                print(case_and_proc[cnum][1])
                print(default)
                x = K.switch(case_and_proc[cnum][0], case_and_proc[cnum][1], default)
            else:
                print(case_and_proc[cnum][1])
                x = K.switch(case_and_proc[cnum][0], case_and_proc[cnum][1], x)
    elif type(case_and_proc) == dict:
        for i, (cond, expression) in enumerate(case_and_proc.items()):
            if i == 0:
                print(case_and_proc[cond])
                print(default)
                x = K.switch(cond, expression, default)
            else:
                print(case_and_proc[cond])
                x = K.switch(cond, expression, x)
    return x

# input layer
inputs = Input(shape=X.shape[1:],name="input")

x_case = Lambda(lambda x: case({K.equal(x[:,2],0): andGate(x[:,:2]),
                                K.equal(x[:,2],1): orGate(x[:,:2]),
                                K.equal(x[:,2],2): xorGate(x[:,:2]),
                                K.equal(x[:,2],3): norGate(x[:,:2])},
                               default=nandGate(x[:,:2])),
               output_shape=(32,))(inputs)

# output layer
outputs = Dense(2, activation="softmax", name="softmax")(x_case)
# model
model = Model(inputs, outputs)
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(),
              metrics=['accuracy'])

case()関数はcase_and_procとdefaultを引数にとります。
case_and_proc: 条件と処理のペアで、ListまたはDict形式でとります。
default: 条件に当てはまらない場合のデフォルト処理です。
上記ではcase_and_procにDict形式で{条件:処理}にAND回路、OR回路、XOR回路、NOR回路を羅列し、最後にデフォルト処理としてNAND回路を書いています。

List形式で条件と処理を書くことも可能です。

x_case = Lambda(lambda x: case([(K.equal(x[:,2],0), andGate(x[:,:2])),
                                (K.equal(x[:,2],1), orGate(x[:,:2])),
                                (K.equal(x[:,2],2), xorGate(x[:,:2])),
                                (K.equal(x[:,2],3), norGate(x[:,:2]))],
                               default=nandGate(x[:,:2])),
               output_shape=(32,))(inputs)

List内はSetで(条件と処理)のペアを書きます。

プログラム全文は以下です。
https://github.com/shibuiwilliam/kerasswitch/blob/master/kerascase.ipynb

参考

なお、Tensorflowにはすでにtf.case()が用意されていて、複数分岐を書くことができます。
tf.case

KerasとTensorflowの2分岐は以下です。
Keras.backend.switch()
tf.cond

cvusk
#Python #Golang #C++ #Bash #Linux #MachineLearning #DeepLearning #Keras #Tensorflow #Docker #Kubernetes #AWS #Azure #GCP #SAS #MENSA #Unity #C# #Kotlin #Android #PyTorch #DL4J #ARCore
https://github.com/shibuiwilliam
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away