LoginSignup
1
2

More than 5 years have passed since last update.

Pythonでニューラルネット(AutoEncoderで「要約/次元圧縮」してみる)

Last updated at Posted at 2018-01-09

 前回記事
で作った、ニューラルネットで、AutoEncoderを試してみる。

ここで言うAutoEncoderは、入力と出力が同じ値となる、ニューラルネット。

8個の入力 → 3個の隠れ層 → 8個の出力

で、00000001、00000010、00000100、・・・10000000を学習させる。

以下結果。
左のブロックが入力、中が隠れ層、右が出力。
(学習の結果入力と出力が一致している)

[入力]                     [隠れ層]              [出力]
0  0  0  0  0  0  0  1  - 0.00  0.99  0.01   - 0.00  0.00  0.00  0.00  0.00  0.00  0.00  1.00   + 
0  0  0  0  0  0  1  0  - 0.99  0.97  0.00   - 0.00  0.00  0.00  0.00  0.00  0.00  1.00  0.00   + 
0  0  0  0  0  1  0  0  - 0.00  0.91  1.00   - 0.00  0.00  0.00  0.00  0.00  1.00  0.00  0.00   + 
0  0  0  0  1  0  0  0  - 0.00  0.02  0.00   - 0.00  0.00  0.00  0.00  1.00  0.00  0.00  0.00   + 
0  0  0  1  0  0  0  0  - 0.99  0.00  0.98   - 0.00  0.00  0.00  1.00  0.00  0.00  0.00  0.00   + 
0  0  1  0  0  0  0  0  - 0.90  0.01  0.00   - 0.00  0.00  1.00  0.00  0.00  0.00  0.00  0.00   + 
0  1  0  0  0  0  0  0  - 0.00  0.00  0.86   - 0.00  1.00  0.00  0.00  0.00  0.00  0.00  0.00   + 
1  0  0  0  0  0  0  0  - 1.00  1.00  1.00   - 1.00  0.00  0.00  0.00  0.00  0.00  0.00  0.00   + 

★ここで重要なのが、隠れ層。
8種の入力に対し、以下のような値を持っている。
これはあたかも、2進法の0〜7(2の3乗(=8))の表現を獲得しているよう。
→ AutoEncoderによる学習後、隠れ層を取り出すことで、
「要約/次元圧縮」が可能になることの一例。

010
110
010
000
101
100
001
111

※以下ソース (python3)

autoencoder.py
import random
import math

iINPUT = 8
iHIDDEN = 3
iOUTPUT = 8

iPR = 5000
iMAX_T = 500000
dETA = 2.5
dEPS = 0.000005
dALPHA = 0.92
dBETA = 0.35
dW0 = 0.9

xi = [0 for i in range(iINPUT +1)] 
v =  [0 for i in range(iHIDDEN+1)] 
o =  [0 for i in range(iOUTPUT)]
zeta =  [0 for i in range(iOUTPUT)] 

#------------
def arand():
    r = random.random()
    r = r * 2 *dW0 - dW0
    return r
#------------

w1 = [[arand() for i in range(iINPUT+1)]  for i in range(iHIDDEN)]
w2 = [[arand() for i in range(iHIDDEN+1)] for i in range(iOUTPUT)]

d_w1 = [[0 for i in range(iINPUT+1)]  for i in range(iHIDDEN)]
d_w2 = [[0 for i in range(iHIDDEN+1)] for i in range(iOUTPUT)]

pre_dw1 = [[0 for i in range(iINPUT+1)]  for i in range(iHIDDEN)]
pre_dw2 = [[0 for i in range(iHIDDEN+1)] for i in range(iOUTPUT)]

iPATTERNz =8

#入力する値              
data =  [[0,0,0,0,0,0,0,1],\
         [0,0,0,0,0,0,1,0],\
         [0,0,0,0,0,1,0,0],\
         [0,0,0,0,1,0,0,0],\
         [0,0,0,1,0,0,0,0],\
         [0,0,1,0,0,0,0,0],\
         [0,1,0,0,0,0,0,0],\
         [1,0,0,0,0,0,0,0]]
for i in range(iPATTERNz):
    data[i].append(1)

#教師信号(所謂、模範解答)
t_data = data

iPATTERNo =8

#お試し用の問題
d_data =  [[0,0,0,0,0,0,0,1],\
           [0,0,0,0,0,0,1,0],\
           [0,0,0,0,0,1,0,0],\
           [0,0,0,0,1,0,0,0],\
           [0,0,0,1,0,0,0,0],\
           [0,0,1,0,0,0,0,0],\
           [0,1,0,0,0,0,0,0],\
           [1,0,0,0,0,0,0,0]]
for i in range(iPATTERNo):
    d_data[i].append(1)

##-----------------------------

def dw_init():
    global d_w1
    global d_w2
    global pre_dw1
    global pre_dw2
    pre_dw1 = d_w1
    pre_dw2 = d_w2
    d_w1 = [[0 for i in range(iINPUT+1)]  for i in range(iHIDDEN)]
    d_w2 = [[0 for i in range(iHIDDEN+1)] for i in range(iOUTPUT)]

def sigmoid(u):
    return  1.0 / (1.0 +  math.exp(-dBETA * u));

def xi_set(p):
    global xi
    global zeta
    xi = data[p]
    zeta = t_data[p]

def forward():
    global xi
    global w1
    global v
    global w2
    global o
    for j in range(iHIDDEN):
        sm = 0
        for k in range(iINPUT+1):
            sm += xi[k] * w1[j][k]
        v[j] = sigmoid(sm)
    v[iHIDDEN] = 1.0
    for j2 in range(iOUTPUT):
        sm = 0
        for j in range(iHIDDEN+1):
            sm += v[j] * w2[j2][j]
        o[j2] = sigmoid(sm)

def backward():
    global o
    global zeta
    global v
    global xi
    global d_w1
    global d_w2
    global w2

    delta2 = [0 for i in range(iOUTPUT)]
    delta1 = [0 for i in range(iHIDDEN+1)]
    for i in range(iOUTPUT):
        delta2[i] = dBETA * o[i] * (1-o[i]) * (zeta[i]-o[i])
    for j in range(iHIDDEN):
        sm = 0
        for j2 in range(iOUTPUT):
            sm = sm + w2[j2][j] * delta2[j2]
            delta1[j] = dBETA * v[j] * (1-v[j]) * sm
    for j2 in range(iOUTPUT):
        for j in range(iHIDDEN+1):
            d_w2[j2][j] = d_w2[j2][j] + delta2[j2] * v[j]
    for j in range(iHIDDEN):
        for k in range(iINPUT+1):
            d_w1[j][k] = d_w1[j][k] + delta1[j] * xi[k]

def w_modify():
    for j2 in range(iOUTPUT):
        for j in range(iHIDDEN+1):
            d_w2[j2][j] = dALPHA * dETA * d_w2[j2][j] + dALPHA * pre_dw2[j2][j]
            w2[j2][j] = w2[j2][j] + d_w2[j2][j]
    for j in range(iHIDDEN):
        for k in range(iINPUT+1):
            d_w1[j][k] = dALPHA * dETA * d_w1[j][k] + dALPHA * pre_dw1[j][k]
            w1[j][k] = w1[j][k] + d_w1[j][k]

def calc_error():
    e = 0.0
    for i in range(iOUTPUT):
        e = e +(zeta[i]-o[i])*(zeta[i]-o[i])
    return e

def back_propagation_main():
    global zeta
    e = 0.0
    esum = 0.0
    for t in range(iMAX_T):
        dw_init()
        esum = 0.0
        for p in range(iPATTERNz):
            xi_set(p)
            forward()
            backward()
            esum = esum + calc_error()
        w_modify()
        e = esum #/ (iOUTPUT * iPATTERNz)
        if t % iPR ==0:
            print(t, " / ", iMAX_T, " / ", "{0:.7f}".format(e))
        if(e<dEPS):
            break

def tryTest():
    global xi
    global o
    print("-------------------")
    for p in range(iPATTERNo):
        xi = d_data[p]
        forward()
        xi.pop()

        for i in range(iINPUT):
            print(xi[i]," ", end="")
        print("- ", end="")
        for i in range(iHIDDEN):
            print("{0:.2f}".format(v[i])," ", end="")
        print(" - ", end="")
        for i in range(iOUTPUT):
            print("{0:.2f}".format(o[i])," ", end="")
        print(" + ")

##-----------------------------


print ("---- Start ----")
back_propagation_main()
tryTest()
print("---- End ----")
1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2