#はじめに
本格的にKerasを始めようと思い、まずはXORの予測モデルを構築するコードを書いてみた。将来的にGraph Convolutional Networksなど複雑なモデルも作りたいと思っているため、Functional APIから始めてみる。
#ソース
こんな感じ。
BatchNormalizationをつけないと、局所解に陥りやすかったのでつけている。最終層はlinear層にしている事例が多かったが、0, 1のクラス分類問題であり、納得がいかなかったので、sigmoid関数とした。中間層2層のユニット数は、それぞれ8としてみた。
sample.py
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Input, Dropout, BatchNormalization
import numpy as np
def main():
x_input = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y_input = np.array([[0], [1], [1], [0]])
input_tensor = Input(shape=(x_input.shape[1]))
# x = Dense(units=32, activation="tanh", kernel_initializer='random_normal')(input_tensor)
x = Dense(units=8, activation="relu", kernel_initializer='random_normal', use_bias=True)(input_tensor)
x = BatchNormalization()(x)
x = Dropout(0.1)(x)
x = Dense(units=8, activation="relu", kernel_initializer='random_normal', use_bias=True)(x)
kernel_initializer='random_normal', use_bias=False)(x)
output_layer = Dense(units=1, activation='sigmoid', kernel_initializer='random_normal', use_bias=False)(x)
model = Model(input_tensor, output_layer)
model.compile(loss='mse', optimizer='sgd', metrics=['accuracy'])
model.summary()
# 学習
model.fit(x_input, y_input, nb_epoch=2000, batch_size=2, verbose=2)
# 予測
print(model.predict(np.array([[0, 0]])))
print(model.predict(np.array([[1, 0]])))
print(model.predict(np.array([[0, 1]])))
print(model.predict(np.array([[1, 1]])))
if __name__ == "__main__":
main()
モデル概要
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 2)] 0
_________________________________________________________________
dense (Dense) (None, 8) 24
_________________________________________________________________
batch_normalization (BatchNo (None, 8) 32
_________________________________________________________________
dropout (Dropout) (None, 8) 0
_________________________________________________________________
dense_1 (Dense) (None, 8) 72
_________________________________________________________________
dense_2 (Dense) (None, 1) 8
=================================================================
Total params: 136
Trainable params: 120
Non-trainable params: 16
#予測結果
[[0.09092495]]
[[0.9356866]]
[[0.90092343]]
[[0.08152929]]
#おわりに
今回手始めにやってみたが、パラメータチューニング、可視化など色々試してみたい。