Python
Chainer

(Python) Chainerで、XOR関数を作ってみる

(Python)

Chainer使ってみた。

ニューラルネットワークで、XOR関数を再現する。
(入力層2ユニット、隠れ層3ユニット、出力層1ユニット)

xor_test.py
# coding: UTF-8
import numpy as np

from chainer import Variable
from chainer import optimizers
from chainer import Chain
import chainer.functions as F
import chainer.links as L

class MyChain(Chain):
    """Small network: 2 inputs -> 3 sigmoid hidden units -> 1 linear output."""

    def __init__(self):
        # Hand both fully connected layers to Chain.__init__ under the names
        # l1 and l2; fwd() reads them back as attributes of the same names.
        layers = dict(
            l1=L.Linear(2, 3),
            l2=L.Linear(3, 1),
        )
        super(MyChain, self).__init__(**layers)

    def __call__(self, x, y):
        """Return the mean squared error between fwd(x) and the target y."""
        prediction = self.fwd(x)
        return F.mean_squared_error(prediction, y)

    def fwd(self, x):
        """Compute the network output for x (no activation on the last layer)."""
        hidden = F.sigmoid(self.l1(x))
        return self.l2(hidden)

model = MyChain()

optimizer = optimizers.Adam()
optimizer.setup(model)

# Start above the convergence threshold so the loop body runs at least once.
loss_val = 100
ep = 0
nep = 10000

# The four XOR input/target pairs never change, so allocate the arrays once
# instead of rebuilding them on every iteration.
x_data = np.array([[0, 0], [1, 1], [0, 1], [1, 0]], dtype=np.float32)
y_data = np.array([[0],    [0],    [1],    [1]], dtype=np.float32)

# Train until the sampled loss drops to 1e-5 or the epoch budget is spent.
while loss_val > 1e-5:

    # Wrap the shared arrays in fresh Variables each step, as the original
    # did, so no Variable object is reused across backward passes.
    x = Variable(x_data)
    y = Variable(y_data)

    model.cleargrads()
    loss = model(x, y)
    loss.backward()
    optimizer.update()

    # The loss is sampled (and the stop condition refreshed) only every
    # 1000 epochs, which is also when progress is reported.
    if ep % 1000 == 0:
        loss_val = float(loss.data)
        # Single pre-formatted argument: prints the same text under both
        # the Python 2 print statement and the Python 3 print function.
        print('Now learning... ->    %.5f' % loss_val)

    if ep >= nep:
        break

    ep += 1


# From here on: after training, test the network on each XOR input.
print('\n')

# Same inputs, in the same order, as the original test:1 .. test:4 stanzas.
test_inputs = [[1, 0], [0, 0], [1, 1], [0, 1]]
for test_no, pair in enumerate(test_inputs, 1):
    xt = Variable(np.array([pair], dtype=np.float32))
    yl = model.fwd(xt)
    # Index down to the scalar so %-formatting does not depend on implicit
    # conversion of a size-1 array to float. The single pre-formatted
    # argument prints identically under Python 2 and Python 3.
    print('test:%d  %s  -->  %.3f' % (test_no, xt[0], yl.data[0, 0]))

実行結果は以下。

Now learning... ->    0.89796
Now learning... ->    0.24670
Now learning... ->    0.22884
Now learning... ->    0.15288
Now learning... ->    0.02969
Now learning... ->    0.00052
Now learning... ->    0.00000


test:1  variable([ 1.  0.])  -->  1.000
test:2  variable([ 0.  0.])  -->  0.000
test:3  variable([ 1.  1.])  -->  0.001
test:4  variable([ 0.  1.])  -->  0.999

Process finished with exit code 0