LoginSignup
0
1

More than 5 years have passed since last update.

Python で ニューラルネット(XOR回路を作ってみる)

Last updated at Posted at 2018-01-08

Python で ニューラルネット(XOR回路を作ってみる)


XOR(排他的論理和)を、ニューラルネットに学習させる例

<XOR>
True と True で False  (1 XOR 1 -> 0)
True と False で True   (1 XOR 0 -> 1)
False と True で True   (0 XOR 1 -> 1)
False と False で False  (0 XOR 0 -> 0)

※グローバル変数の多用については、目をつぶってください。
※バックプロパゲーション(Backpropagation、誤差逆伝播法)を使った基本的な例です。

NNetXOR.py
import random
import math

# Layer sizes. The input and hidden activation lists below carry one extra
# slot each, which holds the constant 1 that feeds the bias weights.
iINPUT = 2
iHIDDEN = 4
iOUTPUT = 1

# Training settings.
iPR = 1000        # print the error every iPR epochs
iMAX_T = 300000   # upper bound on the number of training epochs
dETA = 2.5        # learning rate applied to the accumulated gradient
dEPS = 0.00001    # training stops once the epoch error drops below this
dALPHA = 0.92     # momentum factor (w_modify also scales the gradient step by it)
dBETA = 0.35      # gain (steepness) of the sigmoid
dW0 = 0.9         # initial weights are drawn from [-dW0, dW0)

# Working state shared by the functions of this script.
xi = [0] * (iINPUT + 1)     # current input pattern, including the bias slot
v = [0] * (iHIDDEN + 1)     # hidden activations, including the bias slot
o = [0] * iOUTPUT           # output activations
zeta = [0] * iOUTPUT        # teacher signal for the current pattern

#------------
def arand():
    """Return a random initial weight in the half-open range [-dW0, dW0)."""
    # random.random() lies in [0, 1): stretch it to a width of 2*dW0, then
    # shift it down by dW0 so the range is centred on zero.
    return random.random() * 2 * dW0 - dW0
#------------

# Randomly initialised weights. Every row has one extra column for the bias
# input: w1 is iHIDDEN x (iINPUT+1), w2 is iOUTPUT x (iHIDDEN+1).
w1 = [[arand() for _col in range(iINPUT + 1)] for _row in range(iHIDDEN)]
w2 = [[arand() for _col in range(iHIDDEN + 1)] for _row in range(iOUTPUT)]

# Weight changes accumulated during the current epoch (same shapes as w1/w2).
d_w1 = [[0] * (iINPUT + 1) for _row in range(iHIDDEN)]
d_w2 = [[0] * (iHIDDEN + 1) for _row in range(iOUTPUT)]

# Weight changes applied in the previous epoch; w_modify() reuses them as the
# momentum term.
pre_dw1 = [[0] * (iINPUT + 1) for _row in range(iHIDDEN)]
pre_dw2 = [[0] * (iHIDDEN + 1) for _row in range(iOUTPUT)]

# Training inputs. A constant 1 is appended to every row as the bias input.
data = [[0, 0], [0, 1], [1, 0], [1, 1], [1, 0], [1, 1], [0, 1], [0, 0]]
for sample in data:
    sample.append(1)

# Teacher signals (the expected answers), one row per row of `data`.
t_data = [[0], [1], [1], [0], [1], [0], [1], [0]]

# Trial inputs used to try out the trained network; they get the same bias 1.
d_data = [[1, 0], [1, 1], [0, 1], [0, 0]]
for sample in d_data:
    sample.append(1)

##-----------------------------

def dw_init():
    """Prepare the weight-change accumulators for a new epoch.

    The lists currently bound to d_w1/d_w2 become pre_dw1/pre_dw2, and
    d_w1/d_w2 are rebound to freshly allocated all-zero matrices of the same
    shapes as w1/w2.
    """
    global d_w1, d_w2, pre_dw1, pre_dw2
    pre_dw1, pre_dw2 = d_w1, d_w2
    d_w1 = [[0] * (iINPUT + 1) for _ in range(iHIDDEN)]
    d_w2 = [[0] * (iHIDDEN + 1) for _ in range(iOUTPUT)]

def sigmoid(u, beta=None):
    """Return the logistic function 1 / (1 + exp(-beta * u)).

    Args:
        u: Net input of a unit.
        beta: Gain (steepness) of the curve. When omitted, the module-level
            dBETA is used, so existing one-argument calls behave as before.

    Returns:
        A float in the closed range [0.0, 1.0].
    """
    if beta is None:
        beta = dBETA
    try:
        return 1.0 / (1.0 + math.exp(-beta * u))
    except OverflowError:
        # math.exp() raises instead of returning inf when -beta*u is too
        # large for a float; the function's limit in that direction is 0.
        return 0.0

def xi_set(p):
    """Select training pattern ``p`` as the current sample.

    Rebinds the global ``xi`` to row ``p`` of ``data`` and the global ``zeta``
    to row ``p`` of ``t_data``. The rows are shared with those tables, not
    copied.
    """
    global xi, zeta
    xi = data[p]
    zeta = t_data[p]

def forward():
    """Run one forward pass for the input currently held in ``xi``.

    Writes the hidden activations into ``v`` (with a constant 1.0 in the last
    slot, which feeds the output layer's bias weights) and the output
    activations into ``o``. Both lists are updated in place.
    """
    input_count = iINPUT + 1
    hidden_count = iHIDDEN + 1

    # Hidden layer: weighted sum of the inputs, squashed by the sigmoid.
    for unit in range(iHIDDEN):
        weights = w1[unit]
        net = 0
        for k in range(input_count):
            net += xi[k] * weights[k]
        v[unit] = sigmoid(net)
    v[iHIDDEN] = 1.0

    # Output layer: same computation over the hidden activations.
    for out in range(iOUTPUT):
        weights = w2[out]
        net = 0
        for unit in range(hidden_count):
            net += v[unit] * weights[unit]
        o[out] = sigmoid(net)

def backward():
    """Add the current sample's weight-change contribution to d_w1 and d_w2.

    Reads the activations ``v`` and ``o``, the input ``xi`` and the teacher
    signal ``zeta`` as globals, so forward() has to run for the same sample
    first. ``d_w1`` and ``d_w2`` are updated in place; the weights themselves
    are not touched here.
    """
    # Output deltas: sigmoid slope (dBETA * o * (1 - o)) times the error.
    delta2 = [
        dBETA * o[i] * (1 - o[i]) * (zeta[i] - o[i])
        for i in range(iOUTPUT)
    ]

    # Hidden deltas: the output deltas propagated back through w2. The bias
    # slot of the hidden layer has no incoming weights, so it needs no delta.
    delta1 = [0] * iHIDDEN
    for j in range(iHIDDEN):
        back_sum = 0
        for j2 in range(iOUTPUT):
            back_sum = back_sum + w2[j2][j] * delta2[j2]
        # Computed once per hidden unit, after the sum over all outputs is
        # complete, rather than on every partial sum.
        delta1[j] = dBETA * v[j] * (1 - v[j]) * back_sum

    # Accumulate delta * (activation on the weight's input side).
    for j2 in range(iOUTPUT):
        row = d_w2[j2]
        for j in range(iHIDDEN + 1):
            row[j] = row[j] + delta2[j2] * v[j]
    for j in range(iHIDDEN):
        row = d_w1[j]
        for k in range(iINPUT + 1):
            row[k] = row[k] + delta1[j] * xi[k]

def w_modify():
    """Apply the accumulated weight changes of one epoch, with momentum.

    For each weight, the accumulated change is turned into the actual step
    ``dALPHA * dETA * accumulated + dALPHA * previous_step``. The step is
    stored back into d_w1/d_w2 and then added to w1/w2. The output layer is
    processed before the hidden layer.
    """
    gain = dALPHA * dETA
    layers = (
        (w2, d_w2, pre_dw2, iOUTPUT, iHIDDEN + 1),
        (w1, d_w1, pre_dw1, iHIDDEN, iINPUT + 1),
    )
    for weights, steps, prev_steps, row_count, col_count in layers:
        for r in range(row_count):
            w_row, step_row, prev_row = weights[r], steps[r], prev_steps[r]
            for c in range(col_count):
                step_row[c] = gain * step_row[c] + dALPHA * prev_row[c]
                w_row[c] = w_row[c] + step_row[c]

def calc_error():
    """Return the squared error of ``o`` against ``zeta``, summed over outputs."""
    total = 0.0
    for out in range(iOUTPUT):
        diff = zeta[out] - o[out]
        total = total + diff * diff
    return total

def back_propagation_main():
    """Train the network until the epoch error falls below dEPS.

    Each epoch resets the accumulators, runs forward() and backward() on all
    iPATTERNz training patterns, then applies one weight update. The epoch
    error is the sum of calc_error() over the patterns, measured with the
    weights as they were before that epoch's update. Progress is printed
    every iPR epochs. Training gives up after iMAX_T epochs.
    """
    for epoch in range(iMAX_T):
        dw_init()
        epoch_error = 0.0
        for pattern in range(iPATTERNz):
            xi_set(pattern)
            forward()
            backward()
            epoch_error = epoch_error + calc_error()
        w_modify()
        if epoch % iPR == 0:
            print(epoch, " / ", iMAX_T, " / ", "{0:.7f}".format(epoch_error))
        if epoch_error < dEPS:
            return

def tryTest():
    """Print the network's first output for the first iPATTERNo rows of d_data.

    Rebinds the global ``xi`` to each trial row in turn, because forward()
    reads its input from there.
    """
    global xi
    print("-------------------")
    for index in range(iPATTERNo):
        xi = d_data[index]
        forward()
        first, second = xi[0], xi[1]
        print(first, "-", second, "  -> {0:.4f}".format(o[0]))
##-----------------------------
# Take the pattern counts from the tables themselves, so they cannot drift out
# of sync when samples are added or removed.
iPATTERNz = len(data)    # training patterns processed in every epoch
iPATTERNo = len(d_data)  # trial patterns printed by tryTest()

print("---- Start ----")
back_propagation_main()
tryTest()
print("---- End ----")

結果はこちら

---- Start ----
0  /  300000  /  2.1139707
1000  /  300000  /  0.0014984
2000  /  300000  /  0.0006751
略
99000  /  300000  /  0.0000102
100000  /  300000  /  0.0000101
-------------------
1 - 0   -> 0.9989
1 - 1   -> 0.0011
0 - 1   -> 0.9989
0 - 0   -> 0.0012
---- End ----
0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1