More than 5 years have passed since last update.

(Python) Building an XOR function with Chainer


I tried out Chainer.

I reproduced the XOR function with a small neural network
(2 input units, 3 hidden units, 1 output unit).
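To see why a network of this shape can represent XOR at all, here is a minimal pure-Python sketch of the same 2-3-1 forward pass (sigmoid hidden layer, linear output). The weights `W1`, `b1`, `W2`, `b2` are hand-picked for illustration and are an assumption; they are not the values Chainer learns in the script below.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hand-picked (hypothetical) weights, NOT the values Chainer learns.
# Each row of W1 feeds one hidden unit: roughly OR, roughly AND, and an unused unit.
W1 = [[20.0, 20.0], [20.0, 20.0], [0.0, 0.0]]
b1 = [-10.0, -30.0, 0.0]
W2 = [1.0, -1.0, 0.0]  # output = OR - AND; the third unit has weight 0
b2 = 0.0

def forward(x):
    # Hidden layer: affine map followed by a sigmoid.
    h = [sigmoid(sum(w * xi for w, xi in zip(row, x)) + b)
         for row, b in zip(W1, b1)]
    # Output layer: plain affine map with no activation.
    return sum(w * hi for w, hi in zip(W2, h)) + b2

for x in ([0, 0], [0, 1], [1, 0], [1, 1]):
    print(x, "->", "%.3f" % forward(x))
```

Two hidden units would already be enough for this construction; the third is simply left unused here, whereas training is free to use all three.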

xor_test.py
# coding: UTF-8
import numpy as np

from chainer import Variable
from chainer import optimizers
from chainer import Chain
import chainer.functions as F
import chainer.links as L

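class MyChain(Chain):
    def __init__(self):
        # Two fully connected layers: 2 inputs -> 3 hidden units -> 1 output.
        super(MyChain, self).__init__(
            l1 = L.Linear(2,3),
            l2 = L.Linear(3,1),
        )
    def __call__(self, x, y):
        # Training loss: mean squared error between prediction and target.
        return F.mean_squared_error(self.fwd(x),y)

    def fwd(self,x):
        # Sigmoid on the hidden layer; the output layer is left linear.
        h1 = F.sigmoid(self.l1(x))
        h2 = self.l2(h1)
        return h2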

model = MyChain()

optimizer = optimizers.Adam()
optimizer.setup(model)

loss_val = 100
ep = 0
nep = 10000

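# XOR truth table: the four input pairs and their target outputs.
# The data never changes, so it is built once, outside the loop.
x = Variable(np.array([[0, 0], [1, 1], [0, 1], [1, 0]],dtype=np.float32))
y = Variable(np.array([[0],    [0],    [1],    [1]   ],dtype=np.float32))

while loss_val > 1e-5: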

    model.cleargrads()
    loss = model(x,y)
    loss.backward()
    optimizer.update()

    # Refresh the stopping criterion and report progress every 1000 epochs.
    if ep % 1000 == 0:
        loss_val = loss.data
        print('Now learning... ->   ', "%.5f" % loss.data)

    # Hard cap on the number of epochs.
    if ep >= nep:
        break

    ep += 1


# From here on: test the trained model.
print('\n')

# Feed each input pair through the trained network and print the raw output.
for i, pair in enumerate([[1, 0], [0, 0], [1, 1], [0, 1]], start=1):
    xt = Variable(np.array([pair], dtype=np.float32))
    yl = model.fwd(xt)
    print('test:%d ' % i, xt[0], ' --> ', "%.3f" % yl.data[0, 0])

The output of a run is shown below.

Now learning... ->    0.89796
Now learning... ->    0.24670
Now learning... ->    0.22884
Now learning... ->    0.15288
Now learning... ->    0.02969
Now learning... ->    0.00052
Now learning... ->    0.00000


test:1  variable([ 1.  0.])  -->  1.000
test:2  variable([ 0.  0.])  -->  0.000
test:3  variable([ 1.  1.])  -->  0.001
test:4  variable([ 0.  1.])  -->  0.999

Process finished with exit code 0
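
The network outputs real numbers rather than exact 0/1 values, so one way to check the result is to threshold each output and compare it with the XOR truth table. The sketch below does that for the four values reported in the run above; the helper names `to_bit` and `matches_xor` and the 0.5 threshold are assumptions made for illustration, not part of the original script.

```python
# Outputs copied from the run above, keyed by input pair.
reported = {
    (1, 0): 1.000,
    (0, 0): 0.000,
    (1, 1): 0.001,
    (0, 1): 0.999,
}

def to_bit(value, threshold=0.5):
    """Map a real-valued network output to 0 or 1 (the threshold is an assumption)."""
    return 1 if value >= threshold else 0

def matches_xor(outputs):
    # a ^ b is Python's bitwise XOR, which on 0/1 inputs equals logical XOR.
    return all(to_bit(v) == (a ^ b) for (a, b), v in outputs.items())

print(matches_xor(reported))
```

The same check could be applied to `yl.data` directly after each call to `model.fwd`.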