Python で ニューラルネット(XOR回路を作ってみる)
→
XOR(排他的論理和)を、ニューラルネットに学習させる例
<XOR>
True と True で False (1 ^ 1 -> 0)
True と False で True (1 ^ 0 -> 1)
False と True で True (0 ^ 1 -> 1)
False と False で False (0 ^ 0 -> 0)
※global 変数を多用している点については、目をつぶってください…
※バックプロパゲーション(Backpropagation、誤差逆伝播法)を使ったベーシックな例です
NNetXOR.py
import random
import math
# ---- Network shape ----
iINPUT = 2       # input units (one extra bias input is added on top of these)
iHIDDEN = 4      # hidden units (one extra bias unit is added on top of these)
iOUTPUT = 1      # output units
# ---- Training schedule ----
iPR = 1000       # print the epoch error every iPR epochs
iMAX_T = 300000  # maximum number of training epochs
dEPS = 0.00001   # training stops once the summed squared error is below this
# ---- Learning parameters ----
dETA = 2.5       # learning rate applied to the accumulated weight changes
dALPHA = 0.92    # momentum-style factor (w_modify also multiplies the dETA term by it)
dBETA = 0.35     # gain of the sigmoid activation
dW0 = 0.9        # initial weights are drawn from [-dW0, dW0)
# ---- Activation buffers (rebound/overwritten during training) ----
xi = [0] * (iINPUT + 1)   # current input pattern, including the bias input
v = [0] * (iHIDDEN + 1)   # hidden activations, including the bias unit
o = [0] * iOUTPUT         # output activations
zeta = [0] * iOUTPUT      # current teacher signal
#------------
def arand():
    """Return one random initial weight between -dW0 and dW0.

    random.uniform(a, b) evaluates a + (b - a) * random(), which for
    a = -dW0, b = dW0 yields exactly the same float as
    random() * 2 * dW0 - dW0 and draws a single random() value.
    """
    return random.uniform(-dW0, dW0)
#------------
# Weights: w1[j][k] connects input k (k == iINPUT is the bias input) to
# hidden unit j; w2[m][j] connects hidden unit j (j == iHIDDEN is the
# hidden bias unit) to output unit m. Each entry comes from arand().
w1 = [[arand() for _ in range(iINPUT + 1)] for _ in range(iHIDDEN)]
w2 = [[arand() for _ in range(iHIDDEN + 1)] for _ in range(iOUTPUT)]
# d_w*: weight changes accumulated during the current epoch.
# pre_dw*: the changes applied in the previous epoch (momentum term).
d_w1 = [[0] * (iINPUT + 1) for _ in range(iHIDDEN)]
d_w2 = [[0] * (iHIDDEN + 1) for _ in range(iOUTPUT)]
pre_dw1 = [[0] * (iINPUT + 1) for _ in range(iHIDDEN)]
pre_dw2 = [[0] * (iHIDDEN + 1) for _ in range(iOUTPUT)]
# Training inputs. A constant 1 is appended to every row as the bias input;
# building the rows this way (instead of looping over a hard-coded count)
# keeps the bias in sync with however many rows the list holds.
data = [row + [1] for row in
        [[0, 0], [0, 1], [1, 0], [1, 1], [1, 0], [1, 1], [0, 1], [0, 0]]]
# Teacher signals (expected outputs), aligned row-by-row with ``data``.
t_data = [[0], [1], [1], [0], [1], [0], [1], [0]]
# Inputs used to try the network after training, with the same bias input.
d_data = [row + [1] for row in [[1, 0], [1, 1], [0, 1], [0, 0]]]
##-----------------------------
def dw_init():
    """Begin a new epoch.

    The updates accumulated (and applied) in the previous epoch become the
    momentum term pre_dw1/pre_dw2, and fresh zero-filled accumulators are
    created for d_w1/d_w2.
    """
    global d_w1, d_w2, pre_dw1, pre_dw2
    pre_dw1, pre_dw2 = d_w1, d_w2
    # New list objects, so the saved pre_dw* are not overwritten below.
    d_w1 = [[0] * (iINPUT + 1) for _ in range(iHIDDEN)]
    d_w2 = [[0] * (iHIDDEN + 1) for _ in range(iOUTPUT)]
def sigmoid(u, beta=None):
    """Logistic activation with gain ``beta``: 1 / (1 + exp(-beta * u)).

    Args:
        u: Net input (weighted sum) of a unit.
        beta: Sigmoid gain. Defaults to the module constant dBETA, looked up
            at call time.

    Returns:
        The activation, a float in [0, 1].
    """
    if beta is None:
        beta = dBETA
    try:
        return 1.0 / (1.0 + math.exp(-beta * u))
    except OverflowError:
        # math.exp raises once -beta*u is too large (about > 709). The true
        # value there is below the smallest float, so return the limit
        # instead of crashing training when weights grow large.
        return 0.0
def xi_set(p):
    """Make training pattern ``p`` current.

    Binds the global input vector ``xi`` to data[p] and the global teacher
    signal ``zeta`` to t_data[p] (both by reference, not copied).
    """
    global xi, zeta
    xi = data[p]
    zeta = t_data[p]
def forward():
    """Propagate the current input ``xi`` through the network.

    Writes the hidden activations into ``v`` (with v[iHIDDEN] fixed at 1.0
    as the hidden bias unit) and the output activations into ``o``.
    """
    global xi, w1, v, w2, o
    # Input -> hidden: weighted sum over all inputs, including xi's bias entry.
    for unit in range(iHIDDEN):
        weights = w1[unit]
        total = 0
        for idx in range(iINPUT + 1):
            total += xi[idx] * weights[idx]
        v[unit] = sigmoid(total)
    # Constant bias unit that feeds the output layer.
    v[iHIDDEN] = 1.0
    # Hidden -> output, summing over hidden units plus the bias unit.
    for out in range(iOUTPUT):
        weights = w2[out]
        total = 0
        for idx in range(iHIDDEN + 1):
            total += v[idx] * weights[idx]
        o[out] = sigmoid(total)
def backward():
    """Back-propagate the error of the current pattern.

    Uses the activations left by forward() and the teacher signal ``zeta``
    to compute per-unit error terms, then adds this pattern's weight changes
    (error term times the unit's input) into the accumulators d_w2 and d_w1.
    The factor dBETA * y * (1 - y) is the derivative of sigmoid(dBETA * u).
    """
    global o, zeta, v, xi, d_w1, d_w2, w2
    # Output-layer error terms.
    delta2 = [dBETA * o[m] * (1 - o[m]) * (zeta[m] - o[m])
              for m in range(iOUTPUT)]
    # Hidden-layer error terms, propagated back through w2.
    # The bias unit (index iHIDDEN) has no incoming weights, so it is skipped.
    delta1 = [0] * (iHIDDEN + 1)
    for h in range(iHIDDEN):
        back = 0
        for m in range(iOUTPUT):
            back = back + w2[m][h] * delta2[m]
        delta1[h] = dBETA * v[h] * (1 - v[h]) * back
    # Accumulate this pattern's contribution (batch update; w2 is unchanged here).
    for m in range(iOUTPUT):
        acc = d_w2[m]
        for h in range(iHIDDEN + 1):
            acc[h] += delta2[m] * v[h]
    for h in range(iHIDDEN):
        acc = d_w1[h]
        for k in range(iINPUT + 1):
            acc[k] += delta1[h] * xi[k]
def w_modify():
    """Apply the epoch's accumulated weight changes.

    For every weight: step = dALPHA * dETA * accumulated + dALPHA * previous
    step, stored back into d_w*, then weight += step.
    NOTE(review): dALPHA scales the gradient term as well, unlike textbook
    momentum (dETA * grad + dALPHA * prev). Kept as-is because the training
    results depend on it.
    """
    for out in range(iOUTPUT):
        step_row, prev_row, weight_row = d_w2[out], pre_dw2[out], w2[out]
        for idx in range(iHIDDEN + 1):
            step_row[idx] = dALPHA * dETA * step_row[idx] + dALPHA * prev_row[idx]
            weight_row[idx] += step_row[idx]
    for unit in range(iHIDDEN):
        step_row, prev_row, weight_row = d_w1[unit], pre_dw1[unit], w1[unit]
        for idx in range(iINPUT + 1):
            step_row[idx] = dALPHA * dETA * step_row[idx] + dALPHA * prev_row[idx]
            weight_row[idx] += step_row[idx]
def calc_error():
    """Return the squared error between ``zeta`` and ``o``, summed over all outputs."""
    total = 0.0
    for idx in range(iOUTPUT):
        diff = zeta[idx] - o[idx]
        total += diff * diff
    return total
def back_propagation_main():
    """Train the network with full-batch backpropagation.

    Each epoch resets the accumulators, runs forward/backward over all
    iPATTERNz training patterns, and applies one weight update. The summed
    squared error is printed every iPR epochs. Training stops early once
    that error falls below dEPS, or after iMAX_T epochs.
    """
    for epoch in range(iMAX_T):
        dw_init()
        total_error = 0.0
        for pattern in range(iPATTERNz):
            xi_set(pattern)
            forward()
            backward()
            total_error += calc_error()
        # The reported error was measured before this epoch's update is applied.
        w_modify()
        if epoch % iPR == 0:
            print(epoch, " / ", iMAX_T, " / ", "{0:.7f}".format(total_error))
        if total_error < dEPS:
            break
def tryTest():
    """Run the trained network on the first iPATTERNo rows of ``d_data``.

    Prints each input pair followed by the network's output to 4 decimals.
    """
    global xi, o
    print("-------------------")
    for idx in range(iPATTERNo):
        # forward() reads the global xi, so bind it before calling.
        xi = d_data[idx]
        forward()
        first, second = xi[0], xi[1]
        print(first, "-", second, " -> {0:.4f}".format(o[0]))
##-----------------------------
# Number of training / test patterns, derived from the data lists so the
# counts cannot drift out of sync with the data.
iPATTERNz = len(data)
iPATTERNo = len(d_data)
print("---- Start ----")
back_propagation_main()
tryTest()
print("---- End ----")
結果はこちら
---- Start ----
0 / 300000 / 2.1139707
1000 / 300000 / 0.0014984
2000 / 300000 / 0.0006751
略
99000 / 300000 / 0.0000102
100000 / 300000 / 0.0000101
-------------------
1 - 0 -> 0.9989
1 - 1 -> 0.0011
0 - 1 -> 0.9989
0 - 0 -> 0.0012
---- End ----