LoginSignup
0
0

More than 3 years have passed since last update.

Keras Functional APIでXORを実装してみる

Last updated at Posted at 2020-10-15

はじめに

本格的にKerasを始めようと思い、まずはXORの予測モデルを構築するコードを書いてみた。将来的にGraph Convolutional Networksなど複雑なモデルも作りたいと思っているため、Functional APIから始めてみる。

ソース

こんな感じ。
BatchNormalizationをつけないと、局所解に陥りやすかったのでつけている。最終層はlinear層にしている事例が多かったが、0, 1のクラス分類問題であり、納得がいかなかったので、sigmoid関数とした。中間層2層のユニット数は、それぞれ8としてみた。

sample.py
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Input, Dropout, BatchNormalization

import numpy as np


def main():
    x_input = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    y_input = np.array([[0], [1], [1], [0]])

    input_tensor = Input(shape=(x_input.shape[1]))
    # x = Dense(units=32, activation="tanh", kernel_initializer='random_normal')(input_tensor)
    x = Dense(units=8, activation="relu", kernel_initializer='random_normal', use_bias=True)(input_tensor)
    x = BatchNormalization()(x)
    x = Dropout(0.1)(x)
    x = Dense(units=8, activation="relu", kernel_initializer='random_normal', use_bias=True)(x)
kernel_initializer='random_normal', use_bias=False)(x)
    output_layer = Dense(units=1, activation='sigmoid', kernel_initializer='random_normal', use_bias=False)(x)
    model = Model(input_tensor, output_layer)

    model.compile(loss='mse',  optimizer='sgd', metrics=['accuracy'])
    model.summary()

    # 学習
    model.fit(x_input, y_input, nb_epoch=2000, batch_size=2, verbose=2)

    # 予測
    print(model.predict(np.array([[0, 0]])))
    print(model.predict(np.array([[1, 0]])))
    print(model.predict(np.array([[0, 1]])))
    print(model.predict(np.array([[1, 1]])))


if __name__ == "__main__":
    main()

モデル概要

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 2)]               0
_________________________________________________________________
dense (Dense)                (None, 8)                 24
_________________________________________________________________
batch_normalization (BatchNo (None, 8)                 32
_________________________________________________________________
dropout (Dropout)            (None, 8)                 0
_________________________________________________________________
dense_1 (Dense)              (None, 8)                 72
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 8
=================================================================
Total params: 136
Trainable params: 120
Non-trainable params: 16

予測結果

[[0.09092495]]
[[0.9356866]]
[[0.90092343]]
[[0.08152929]]

おわりに

今回手始めにやってみたが、パラメータチューニング、可視化など色々試してみたい。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0