LoginSignup
7
13

More than 3 years have passed since last update.

ディープラーニングを実装から学ぶ(10-1)RNNの実装(RNN,LSTM,GRU)

Last updated at Posted at 2020-04-22

今更ですが、RNNについてです。
RNNもCNNと同様に約2年前に実装していましたが、なかなか書けませんでした。少し時間ができたので、書きます。
RNNですが、例によってMNISTを使って確かめます。

時系列データ

RNNは、例えば、株価の推移、商品の売り上げなど時刻ごとに変化するデータの予測に用いられます。
次元としては、以下のような2次元データです。

(t, d) t:時系列長、d:説明変数

tは、月次データの12か月分であれば12、日々データで1週間分であれば7になります。

ここでは、MNISTの画像を時系列データとみなします。
MNISTの画像の例です。
MNIST.png

以下のように、上部のピクセルから順に1時刻、2時刻となり最後が28時刻です。
(28,28)のデータになります。
MNIST_time.png
スキャナで上から順番に読み込んでいくイメージです。

RNN(リカレントニューラルネットワーク)

時系列データは、前時刻までのデータが次のデータに影響を与えます。
RNNは、以下の構造をしています。
$ u_1,u_2,u_3,・・・,u_n $は、各時刻のデータを表します。
MNISTでは、$ n=28 $で、画像上段のピクセルから順番に最下段のピクセルの値を表します。
前時刻までのデータ情報を$ h_1,h_2,h_3,・・・,h_{n-1} $に保持していきます。
最後の$ h_n $が最終的な出力です。
$ h $のサイズは、全結合のノード数と同様にハイパーパラメータとして決定します。

RNN.png

シンプルRNN

RNNは、リカレントニューラルネットワークの仕組みを表す広義のRNNと、上図のRNNの構造を示す狭義のRNNがあります。ここでは、後者を「シンプルRNN」と呼びます。
シンプルRNNの内部構造を以下に示します。
smpleRNN.png
$ W_x $ 、 $ W_h $は、重さ、$ b $はバイアスです。$ \mathrm{tanh} $は、活性化関数です。
式で表すと以下のようになります。

$ h_t = \mathrm{tanh}(h_{t-1} \cdot W_h+u_t \cdot W_x + b) $

前時刻からの入力がない場合を考えましょう。
これは、図の書き方が違いますが、活性化関数を$ \mathrm{tanh} $とした場合の全結合と同じです。
smpleRNN2.png

$ h_t = \mathrm{tanh}(u_t \cdot W_x + b) $

これに、前時刻からの情報$ h_{t_1} $と重さを掛けた値を加えることで、前時刻までの情報を反映させます。
smpleRNN.png

$ h_t = \mathrm{tanh}(h_{t-1} \cdot W_h+u_t \cdot W_x + b) $

順伝播

シンプルRNNの順伝播のプログラムは、以下になります。

h_t = tanh(np.dot(h_t_1, Wh) + np.dot(u_t, Wx) + b)

これを時系列長分、連続で実行していきます。

RNN.png

RNNの関数を定義します。
時系列長($ t $)分ループします。その際、$ h_t $を次の$ h_{t-1} $とします。各時刻の$ h $を逆伝播時に利用するため保持しておきます。
活性化関数は、基本的に$ \mathrm{tanh} $を用いますが変更できるように、パラメータ化しました。
参考までに、各変数の次元をコメントで示します。
n - バッチサイズ
t - 時系列長(MNISTでは縦方向のピクセル数)
d - 説明変数の数(MNISTでは、横方向のピクセル数)
h - 出力次元数(ノード数)

def RNN(u, h_0, Wx, Wh, b, propagation_func=tanh):
    # u - (n, t, d)
    un, ut, ud = u.shape
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # Wx - (d, h)
    # Wh - (h, h)
    # b  - h
    # 初期設定
    h_t_1 = h_0
    hs = np.zeros((hn, ut, hh))
    for t in range(ut):
        # RNN
        u_t = u[:,t,:]
        h_t = propagation_func(np.dot(h_t_1, Wh) + np.dot(u_t, Wx) + b)
        h_t_1 = h_t
        # hの値保持
        hs[:,t,:] = h_t
    # 後設定
    h_n = h_t
    return h_n, hs

逆伝播

逆伝播を考えます。
先ほどの図に、勾配の計算を赤字で示します。
smpleRNN3.png

矢印の向きは、順伝播から逆にしています。逆伝播は、出力側から考えるのでした。
$ h_t $から$ u_t, h_{t-1}, W_x, W_h, b $のそれぞれまで、逆向きの矢印をたどり、赤字を値を掛けていきます。まずは、赤字の求め方です。

それぞれの勾配は、$ dh_t $から順番に掛け算をしていきます。

  • $ u_t $方向
    $ (dh_t \times \mathrm{tanh}' \times 1) \cdot W_x^T $

  • $ h_{t-1} $方向
    $ (dh_t \times \mathrm{tanh}' \times 1) \cdot W_t^T $

  • $ W_x $方向
    $ u_t^T \cdot (dh_t \times \mathrm{tanh}' \times 1) $

  • $ W_h $方向
    $ h_{t-1}^T \cdot (dh_t \times \mathrm{tanh}' \times 1) $

  • $ b $方向
    $ dh_t \times \mathrm{tanh}' \times 1 $

プログラムです。
まず、$ \mathrm{tanh} $の勾配をtanh_backで計算しています。本来であれば2つ目のパラメータは、$ \mathrm{tanh} $の計算前の値を設定するのですが、tanh_backでは利用していないためここではNoneを指定しています。本来であれば、勾配関数仕様を見直すべきですが、影響が大きいためここではこのままとさせてください。
あとは、$ \mathrm{tanh} $の勾配$ dp $を利用し、それぞれの勾配を計算します。
$ b $は、バッチサイズ分足し合わせます。

        # tanhの勾配
        dp = tanh_back(dh_t, None, h_t)
        # 各勾配
        db  = np.sum(dp, axis=0)
        dWh = np.dot(h_t_1.T, dp)
        dWx = np.dot(u_t.T, dp)
        dh_t_1 = np.dot(dp, Wh.T)
        du_t = np.dot(dp, Wx.T)

時系列全体の勾配関数を定義します。
時系列長($ t $)分ループします。
各重み、バイアスについては、各時刻ごとの勾配がすべて作用します。そのため、時系列分加えていきます。
勾配計算に必要な、$ u_t $、$ h_t $、$ h_{t-1} $を設定し順番に勾配を計算していきます。
$ dh_{t-1} $を次の$ dh_t $とします。

def RNN_back(dh_n, u, hs, h_0, Wx, Wh, b, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    # hs - (n, t, h)
    # h_0 - (n, h)
    # Wx - (d, h)
    # Wh - (h, h)
    # b  - h
    # 初期設定
    dh_t  = dh_n
    du  = np.zeros_like(u)
    dWx = np.zeros_like(Wx)
    dWh = np.zeros_like(Wh)
    db  = np.zeros_like(b)
    for t in reversed(range(ut)):
        # RNN勾配計算
        u_t = u[:,t,:]
        h_t = hs[:,t,:]
        if t == 0:
            h_t_1 = h_0
        else:
            h_t_1 = hs[:,t-1,:]
        # tanhの勾配
        dp = propagation_back_func(dh_t, None, h_t)
        # 各勾配
        # h =>  (n, h)
        db  += np.sum(dp, axis=0)
        # (h, h) => (h, n) @ (n, h)
        dWh += np.dot(h_t_1.T, dp)
        # (d, h) => (d, n) @ (n, h)
        dWx += np.dot(u_t.T, dp)
        # (n, h) => (n, h) @ (h, h)
        dh_t_1 = np.dot(dp, Wh.T)
        # (n, d) => (n, h) @ (h, d)
        du_t = np.dot(dp, Wx.T)
        dh_t = dh_t_1
        # uの勾配の値保持
        du[:,t,:] = du_t
    # 後設定
    dh_0 = dh_t_1
    return du, dWx, dWh, db, dh_0

多階層

RNNは、全結合やCNNと同様に、複数層積み上げることで精度が向上します。
2層の例です。
RNN2.png
1層目の各時刻の出力は、2層目の入力となります。次時刻に渡す$ h_t $と、次の層の入力として渡す$ hs[t] $は同じものです。
smpleRNN5.png

順伝播のプログラムでは、$ hs $として、各時刻の$ h_t $をまとめて出力していましたので、次の層に$ hs $を渡すだけでプログラムの変更はありません。

逆伝播は、$ hs $側からも勾配が戻されます。$ h_t $側と加える必要があります。
smpleRNN4.png

変更後のプログラムです。
$ dh_t $と$ dhs[t] $を加えています。

def RNN_back(dh_n, dhs, u, hs, h_0, Wx, Wh, b, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    # hs - (n, t, h)
    # h_0 - (n, h)
    # Wx - (d, h)
    # Wh - (h, h)
    # b  - h
    # 初期設定
    dh_t  = dh_n
    du  = np.zeros_like(u)
    dWx = np.zeros_like(Wx)
    dWh = np.zeros_like(Wh)
    db  = np.zeros_like(b)
    for t in reversed(range(ut)):
        # RNN勾配計算
        u_t = u[:,t,:]
        h_t = hs[:,t,:]
        if t == 0:
            h_t_1 = h_0
        else:
            h_t_1 = hs[:,t-1,:]
        # 両方向の勾配を加える
        dh_t = dh_t + dhs[:,t,:]
        # tanhの勾配
        dp = propagation_back_func(dh_t, None, h_t)
        # 各勾配
        # h =>  (n, h)
        db  += np.sum(dp, axis=0)
        # (h, h) => (h, n) @ (n, h)
        dWh += np.dot(h_t_1.T, dp)
        # (d, h) => (d, n) @ (n, h)
        dWx += np.dot(u_t.T, dp)
        # (n, h) => (n, h) @ (h, h)
        dh_t_1 = np.dot(dp, Wh.T)
        # (n, d) => (n, h) @ (h, d)
        du_t = np.dot(dp, Wx.T)
        dh_t = dh_t_1
        # uの勾配の値保持
        du[:,t,:] = du_t
    # 後設定
    dh_0 = dh_t_1
    return du, dWx, dWh, db, dh_0

フレームワークの対応

RNN複数層に対応させます。
RNNの最終層では、$ h_n $を次の層に渡します。最終層以外では、$ hs $を次に渡します。
RNN3.png
返却するデータを区別するため、return_hsパラメータを用います。$ hs $を返す場合は、Trueを設定します。Falseの場合は、$ h_n $を返します。

「ディープラーニングを実装から学ぶ(8)実装変更」にRNNを組み込みます。

初期化

重みの初期化は、$ \mathrm{tanh} $を利用するため、既定値をglorot_normalにしています。

def RNN_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=glorot_normal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
    t, d = d_prev
    d_next = (t, h)
    if not params.get("return_hs"):
        d_next = h
    Wx = Wx_init_func(d, h, **Wx_init_params)
    Wh = Wh_init_func(h, h, **Wh_init_params)
    b = bias_init_func(h, **bias_init_params)
    return d_next, {"Wx":Wx, "Wh":Wh, "b":b}

def RNN_init_optimizer():
    sWx = {}
    sWh = {}
    sb = {}
    return {"sWx":sWx, "sWh":sWh, "sb":sb}

順伝播

return_hsに応じて、$ hs $または$ h_n $を返却するようにしています。

def RNN_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = np.zeros((u.shape[0], weights["Wh"].shape[0]))
    # RNN
    h_n, hs = func(u, h_0, weights["Wx"], weights["Wh"], weights["b"], propagation_func)
    # RNN最下層以外    
    if params.get("return_hs"):
        z = hs
    # RNN最下層
    else:
        z = h_n
    # 重み減衰対応
    weight_decay_r = 0
    if weight_decay is not None:
        weight_decay_r += weight_decay["func"](weights["Wx"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wh"], **weight_decay["params"])
    return {"u":u, "z":z, "h_0":h_0, "hs":hs}, weight_decay_r

逆伝播

return_hsに応じて、どちらからの勾配を利用するか判定しています。
return_hsがTrueの場合、$ dhs $、Flaseの場合は、$ dh_n $とします。

def RNN_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # RNN最下層以外
    if return_hs:
        dh_n = np.zeros_like(us["h_0"])
        dhs = dz
    # RNN最下層
    else:
        dh_n = dz
        dhs = np.zeros_like(us["hs"])
    # RNNの勾配計算
    du, dWx, dWh, db, dh_0 = back_func(dh_n, dhs, us["u"], us["hs"], us["h_0"], weights["Wx"], weights["Wh"], weights["b"], propagation_func)
    # 重み減衰対応
    if weight_decay is not None:
        dWx += weight_decay["back_func"](weights["Wx"], **weight_decay["params"])
        dWh += weight_decay["back_func"](weights["Wh"], **weight_decay["params"])
    return {"du":du, "dWx":dWx, "dWh":dWh, "db":db}

重みの更新

$ W_x $、$ W_h $、$ b $を更新します。

def RNN_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wx"], optimizer_stats["sWx"] = func(weights["Wx"], du["dWx"], **params, **optimizer_stats["sWx"])
    weights["Wh"], optimizer_stats["sWh"] = func(weights["Wh"], du["dWh"], **params, **optimizer_stats["sWh"])
    weights["b"],  optimizer_stats["sb"]  = func(weights["b"],  du["db"],  **params, **optimizer_stats["sb"])
    return weights, optimizer_stats

性能改善

RNNは、時刻単位に順番に実行していきます。$ u $は、時系列長の次元が2つめの次元です。アクセスは、$ u[:,t,:] $のように行います。しかし、このアクセス方法では時間がかかります。あらかじめ、transposeにより時系列長が最初の次元となるように並び替えを行います。transpose自体には時間がかかりますが、その後のアクセスは早くなります。時系列長回のアクセスを考えると、最初に並び替えたほうがトータルで早くなります。
 transpose前 : (バッチサイズ, 時系列長, 説明変数の数)
 transpose後 : (時系列長, バッチサイズ, 説明変数の数)
他の変数も同様にtransposeします。

def RNN(u, h_0, Wx, Wh, b, propagation_func=tanh):
    # u - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # Wx - (d, h)
    # Wh - (h, h)
    # b  - h
    # 初期設定
    h_t_1 = h_0
    hs = np.zeros((ut, hn, hh))
    for t in range(ut):
        # RNN
        u_t = u[t]
        h_t = propagation_func(np.dot(h_t_1, Wh) + np.dot(u_t, Wx) + b)
        h_t_1 = h_t
        # hの値保持
        hs[t] = h_t
    # 後設定
    h_n = h_t
    hs = hs.transpose(1,0,2)
    return h_n, hs

def RNN_back(dh_n, dhs, u, hs, h_0, Wx, Wh, b, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    dhs = dhs.transpose(1,0,2)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # hs - (n, t, h)
    hs = hs.transpose(1,0,2)
    # h_0 - (n, h)
    # Wx - (d, h)
    # Wh - (h, h)
    # b  - h
    # 初期設定
    dh_t  = dh_n
    du  = np.zeros_like(u)
    dWx = np.zeros_like(Wx)
    dWh = np.zeros_like(Wh)
    db  = np.zeros_like(b)
    for t in reversed(range(ut)):
        # RNN勾配計算
        u_t = u[t]
        h_t = hs[t]
        if t == 0:
            h_t_1 = h_0
        else:
            h_t_1 = hs[t-1]
        # 両方向の勾配を加える
        dh_t = dh_t + dhs[t]
        # tanhの勾配
        dp = propagation_back_func(dh_t, None, h_t)
        # 各勾配
        # h =>  (n, h)
        db  += np.sum(dp, axis=0)
        # (h, h) => (h, n) @ (n, h)
        dWh += np.dot(h_t_1.T, dp)
        # (d, h) => (d, n) @ (n, h)
        dWx += np.dot(u_t.T, dp)
        # (n, h) => (n, h) @ (h, h)
        dh_t_1 = np.dot(dp, Wh.T)
        # (n, d) => (n, h) @ (h, d)
        du_t = np.dot(dp, Wx.T)
        dh_t = dh_t_1
        # uの勾配の値保持
        du[t] = du_t
    # 後設定
    dh_0 = dh_t_1
    du = du.transpose(1,0,2)
    return du, dWx, dWh, db, dh_0

実行例

「ディープラーニングを実装から学ぶ(8)実装変更」の「参考」-「プログラム全体」のプログラムを事前に実行しておきます。
MNISTのデータを読み込みます。(MNISTのデータファイルが、c:\mnistに格納されている例)

x_train, t_train, x_test, t_test = load_mnist('c:\\mnist\\')
data_normalizer = create_data_normalizer(min_max)
nx_train, data_normalizer_stats = train_data_normalize(data_normalizer, x_train)
nx_test                         = test_data_normalize(data_normalizer, x_test)

時系列データとなるようにreshapeします。

nx_train = nx_train.reshape((nx_train.shape[0], 28, 28))
nx_test  = nx_test.reshape((nx_test.shape[0], 28, 28))

モデルを定義します。出力次元数は100にします。
RNNの1層です。

model = create_model((28,28))
model = add_layer(model, "RNN1", RNN, 100)
model = add_layer(model, "affine", affine, 10)
model = set_output(model, softmax)
model = set_error(model, cross_entropy_error)
optimizer = create_optimizer(SGD, lr=0.01)

30エポック実行します。

epoch = 30
batch_size = 100
np.random.seed(10)
model, optimizer, learn_info = learn(model, nx_train, t_train, nx_test, t_test, batch_size=batch_size, epoch=epoch, optimizer=optimizer)

結果です。

input - 0 (28, 28)
RNN1 RNN (28, 28) 100
affine affine 100 10
output softmax 10
error cross_entropy_error
0 0.03518333333333333 2.4909296350219856 0.0337 2.5033102068320403 
1 0.6705 1.0635964290350581 0.7834 0.7116426687585262 
2 0.8527333333333333 0.4957272344480599 0.8576 0.46864695385136285 
3 0.9009833333333334 0.34436341886017446 0.9249 0.2647796760544915 
4 0.9206 0.2730971587148404 0.9314 0.22924864598081818 
5 0.9312666666666667 0.23532134718192238 0.9367 0.21693257443611919 
6 0.9402166666666667 0.2058781614525009 0.9441 0.19102061343584434 
7 0.94545 0.18925935088869394 0.9545 0.16137079638287335 
8 0.9502 0.17225559680178326 0.9551 0.15719143962718626 
9 0.9547 0.15923950747415935 0.9557 0.14887872338275715 
10 0.95815 0.1491376022136765 0.9587 0.14576495429428063 
11 0.9593166666666667 0.14198676706280589 0.9601 0.13665127388759685 
12 0.9624 0.13373739559712328 0.9626 0.13021487041349578 
13 0.9637833333333333 0.12716385239117425 0.9641 0.1250842946191389 
14 0.9644833333333334 0.12442839761307002 0.9637 0.1237039736621491 
15 0.9665333333333334 0.11724673298424199 0.9608 0.13436264886904448 
16 0.9686166666666667 0.11296855258609663 0.9657 0.12175114987449233 
17 0.9693333333333334 0.10905979233467651 0.9682 0.11534971758957534 
18 0.9698166666666667 0.10555725650696138 0.9677 0.11360613475409218 
19 0.9704166666666667 0.10316100002412107 0.9669 0.11402197839216774 
20 0.9718833333333333 0.0973303874617204 0.9699 0.10450096085064706 
21 0.97125 0.09675419492450306 0.968 0.10589076380677737 
22 0.9729166666666667 0.09542962153214306 0.9687 0.10788647858006394 
23 0.9737833333333333 0.09139111488900863 0.9689 0.10552085247884538 
24 0.97455 0.08763460395473652 0.97 0.1053904978983531 
25 0.9757333333333333 0.08342315854838207 0.9664 0.1142366242901633 
26 0.9762833333333333 0.08293232935406979 0.9719 0.09742873526946 
27 0.97695 0.07981604296888609 0.9718 0.09424186783390053 
28 0.9774 0.07762622005906694 0.9665 0.11621772074040007 
29 0.9774166666666667 0.07861949031260665 0.9714 0.10051382095845361 
30 0.9785833333333334 0.07415927952523564 0.9735 0.09602163729199646 
所要時間 = 5 分 19 秒

ある程度の精度になりましたが、全結合の場合より精度が低いようです。
シンプルRNNでは、逆伝播時の勾配消失、勾配爆発により時系列長が長い場合は精度が出ないようです。
この後、精度を向上する手法を考えます。

LSTM(Long short-term memory)

LSTMの構造を図示します。シンプルRNNに比べてかなり複雑になりました。
シンプルRNNの構造に加えて、メモリセルと3つのゲートを持ちます。
ひとつひとつ分解して確認していきましょう。
LSTM.png

基本部

基本の部分は、シンプルRNNと同じです。
ここでは、シンプルRNNの出力を$ g $とします。

$ g_t = \mathrm{tanh}(h_{t-1} \cdot W_{hg}+u_t \cdot W_{xg} + b_g) $

LSTM2.png

メモリセル(CEC:Constant Error Carousel)

シンプルRNNは、誤差逆伝播時に、時刻をさかのぼるごとに勾配の値が小さくなる勾配消失が原因です。時刻ごとに勾配を掛け合わせるため、例えば、tanhの勾配が$ 0.5 $だとすると28時刻さかのぼれば、$ 0.5^{28}=0.00000000373… $となります。
そこで、CECと呼ばれる勾配を維持するためのメモリセルを用います。ここに勾配をとどまらせることによって、勾配消失を防ぎます。
シンプルRNNの結果とメモリセルの値を加えて、活性化関数と通します。
メモリセルに乗せる値については、次に紹介します。
LSTM3.png

忘却ゲート

メモリセルに過去からの情報を保持してます。過去からのメモリセルの値をどの程度、次に伝えるかを決めるのが忘却ゲートです。
過去のメモリセルの値をどの程度通すかをsigmoidにより、0~1の確率で求めます。0なら過去の情報をすべて忘れます(値を0にする)。0.5なら半分忘れます(0.5を掛けて値を半分にする)。

$ f_t = \mathrm{sigmoid}(h_{t-1} \cdot W_{hf}+u_t \cdot W_{xf} + b_f) $

この値をメモリセルにかけます。

LSTM4.png

入力ゲート

今度は、新たに現時刻のデータをどの程度メモリセルに乗せるか決定します。
忘却ゲート同様に、sigmoidにて確率を求めます。

$ i_t = \mathrm{sigmoid}(h_{t-1} \cdot W_{hi}+u_t \cdot W_{xi} + b_i) $

LSTM5.png

この値をシンプルRNNの結果にかけます。
結果をメモリセルに加えます。この値が次の時刻に渡すメモリセルの値となります。

出力ゲート

メモリセルと入力ゲートの値を足したものに活性化関数を通し、出力ゲートの確率を掛け最終的な出力とします。

$ o_t = \mathrm{sigmoid}(h_{t-1} \cdot W_{ho}+u_t \cdot W_{xo} + b_o) $

LSTM6.png

順伝播

LSTMの式をまとめると以下のようになります。

\begin{align}
f_t &= \mathrm{sigmoid}(h_{t-1} \cdot W_{hf}+u_t \cdot W_{xf} + b_f)\\
i_t &= \mathrm{sigmoid}(h_{t-1} \cdot W_{hi}+u_t \cdot W_{xi} + b_i)\\
g_t &= \mathrm{tanh}(h_{t-1} \cdot W_{hg}+u_t \cdot W_{xg} + b_g)\\
o_t &= \mathrm{sigmoid}(h_{t-1} \cdot W_{ho}+u_t \cdot W_{xo} + b_o)\\
c_t &= f_t \times c_{t-1} + i_t \times g_t\\
h_t &= o_t \times \mathrm{tanh}(c_t)
\end{align}

LSTM7.png

実装です。
性能改善のため、transposeを行っています。

def LSTM(u, h_0, c_0, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh):
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # c_0  - (n, h)
    cn, ch = c_0.shape
    # Wxf, Wxi, Wxg, Who - (d, h)
    # Whf, Whi, Whg, Who - (h, h)
    # bf,  bi,  bg,  bo  - h
    # 初期設定
    h_t_1 = h_0
    c_t_1 = c_0
    hs = np.zeros((ut, hn, hh))
    cs = np.zeros((ut, cn, ch))
    fs = np.zeros((ut, hn, hh))
    Is = np.zeros((ut, hn, hh))
    gs = np.zeros((ut, hn, hh))
    os = np.zeros((ut, hn, hh))
    for t in range(ut):
        # LSTM
        u_t = u[t]
        f_t = sigmoid(np.dot(h_t_1, Whf) + np.dot(u_t, Wxf) + bf)
        i_t = sigmoid(np.dot(h_t_1, Whi) + np.dot(u_t, Wxi) + bi)
        g_t = propagation_func(np.dot(h_t_1, Whg) + np.dot(u_t, Wxg) + bg)
        o_t = sigmoid(np.dot(h_t_1, Who) + np.dot(u_t, Wxo) + bo)

        c_t = f_t * c_t_1 + i_t * g_t
        h_t = o_t * propagation_func(c_t)
        # h,c,f,i,g,oの値保持
        hs[t] = h_t
        cs[t] = c_t
        fs[t] = f_t
        Is[t] = i_t
        gs[t] = g_t
        os[t] = o_t
        h_t_1 = h_t
        c_t_1 = c_t
    # 後設定
    h_n = h_t
    c_n = c_t
    hs = hs.transpose(1,0,2)
    cs = cs.transpose(1,0,2)
    return h_n, hs, c_n, cs, fs, Is, gs, os

逆伝播

逆伝播を考えます。
$ h_t $から逆向きの矢印にそって、$ u_t, h_{t-1}, c_{t-1}, W_{xf}, W_{hf}, b_f, W_{xi}, W_{hi}, b_i, W_{xg}, W_{hg}, b_g, W_{xo}, W_{ho}, b_o $まで赤字の勾配を掛けていきます。

LSTM8.png

各勾配の求め方は、以下を参考にしてください。

あとは、逆向きの矢印を順番に掛け算していくだけです。

def LSTM_back(dh_n, dhs, u, hs, h_0, cs, c_0, fs, Is, gs, os, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    dhs = dhs.transpose(1,0,2)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # hs - (n, t, h)
    hs = hs.transpose(1,0,2)
    # h_0 - (n, h)
    # cs - (n, t, h)
    cs = cs.transpose(1,0,2)
    # c_0 - (n, h)
    # fs, Is, gs, os - (n, t, h)
    # Wxf, Wxi, Wxg, Who - (d, h)
    # Whf, Whi, Whg, Who - (h, h)
    # bf,  bi,  bg,  bo  - h
    # 初期設定
    dh_t  = dh_n
    dc_t  = np.zeros_like(c_0)
    du  = np.zeros_like(u)
    dWxf = np.zeros_like(Wxf)
    dWhf = np.zeros_like(Whf)
    dbf  = np.zeros_like(bf)
    dWxi = np.zeros_like(Wxi)
    dWhi = np.zeros_like(Whi)
    dbi  = np.zeros_like(bi)
    dWxg = np.zeros_like(Wxg)
    dWhg = np.zeros_like(Whg)
    dbg  = np.zeros_like(bg)
    dWxo = np.zeros_like(Wxo)
    dWho = np.zeros_like(Who)
    dbo  = np.zeros_like(bo)
    for t in reversed(range(ut)):
        # LSTM勾配計算
        u_t = u[t]
        h_t = hs[t]
        c_t = cs[t]
        if t == 0:
            h_t_1 = h_0
            c_t_1 = c_0
        else:
            h_t_1 = hs[t-1]
            c_t_1 = cs[t-1]
        o_t = os[t]
        g_t = gs[t]
        i_t = Is[t]
        f_t = fs[t]
        # 両方向の勾配を加える
        dh_t = dh_t + dhs[t]        
        # tanhの勾配
        dp = dh_t * o_t
        p_t = propagation_func(c_t)
        dc_t = dc_t + propagation_back_func(dp, None, p_t)
        # 各勾配
        # o
        do = dh_t * p_t
        dpo = sigmoid_back(do, None, o_t)
        dbo  += np.sum(dpo, axis=0)
        dWho += np.dot(h_t_1.T, dpo) 
        dWxo += np.dot(u_t.T, dpo)
        # g
        dg = dc_t * i_t
        dpg = propagation_back_func(dg, None, g_t)
        dbg  += np.sum(dpg, axis=0)
        dWhg += np.dot(h_t_1.T, dpg) 
        dWxg += np.dot(u_t.T, dpg)
        # i
        di = dc_t * g_t
        dpi = sigmoid_back(di, None, i_t)
        dbi  += np.sum(dpi, axis=0)
        dWhi += np.dot(h_t_1.T, dpi) 
        dWxi += np.dot(u_t.T, dpi)
        # f
        df = dc_t * c_t_1
        dpf = sigmoid_back(df, None, f_t)
        dbf  += np.sum(dpf, axis=0)
        dWhf += np.dot(h_t_1.T, dpf) 
        dWxf += np.dot(u_t.T, dpf)
        # c
        dc_t_1 = dc_t * f_t
        # h
        dh_t_1 = np.dot(dpf, Whf.T) + np.dot(dpi, Whi.T) + np.dot(dpg, Whg.T) + np.dot(dpo, Who.T)
        # u
        du_t = np.dot(dpf, Wxf.T) + np.dot(dpi, Wxi.T) + np.dot(dpg, Wxg.T) + np.dot(dpo, Wxo.T)
        dh_t = dh_t_1
        dc_t = dc_t_1
        # uの勾配の値保持
        du[t] = du_t
    # 後設定
    dh_0 = dh_t_1
    dc_0 = dc_t_1
    du = du.transpose(1,0,2)

    return du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0

フレームワーク対応

「ディープラーニングを実装から学ぶ(8)実装変更」にLSTMを組み込みます。
シンプルRNNと同様に対応します。

def LSTM_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=glorot_normal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
    t, d = d_prev
    d_next = (t, h)
    if not params.get("return_hs"):
        d_next = h
    Wxf = Wx_init_func(d, h, **Wx_init_params)
    Whf = Wh_init_func(h, h, **Wh_init_params)
    bf = bias_init_func(h, **bias_init_params)
    Wxi = Wx_init_func(d, h, **Wx_init_params)
    Whi = Wh_init_func(h, h, **Wh_init_params)
    bi = bias_init_func(h, **bias_init_params)
    Wxg = Wx_init_func(d, h, **Wx_init_params)
    Whg = Wh_init_func(h, h, **Wh_init_params)
    bg = bias_init_func(h, **bias_init_params)
    Wxo = Wx_init_func(d, h, **Wx_init_params)
    Who = Wh_init_func(h, h, **Wh_init_params)
    bo = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxf":Wxf, "Whf":Whf, "bf":bf, "Wxi":Wxi, "Whi":Whi, "bi":bi, "Wxg":Wxg, "Whg":Whg, "bg":bg, "Wxo":Wxo, "Who":Who, "bo":bo}

def LSTM_init_optimizer():
    sWxf = {}
    sWhf = {}
    sbf = {}
    sWxi = {}
    sWhi = {}
    sbi = {}
    sWxg = {}
    sWhg = {}
    sbg = {}
    sWxo = {}
    sWho = {}
    sbo = {}
    return {"sWxf":sWxf, "sWhf":sWhf, "sbf":sbf,
             "sWxi":sWxi, "sWhi":sWhi, "sbi":sbi,
             "sWxg":sWxg, "sWhg":sWhg, "sbg":sbg,
             "sWxo":sWxo, "sWho":sWho, "sbo":sbo}

def LSTM_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = np.zeros((u.shape[0], weights["Whf"].shape[0]))
    c_0 = params.get("c_0")
    if c_0 is None:
        c_0 = np.zeros((u.shape[0], weights["Whf"].shape[0]))
    # LSTM
    h_n, hs, c_n, cs, fs, Is, gs, os = func(u, h_0, c_0,
                                        weights["Wxf"], weights["Whf"], weights["bf"], 
                                        weights["Wxi"], weights["Whi"], weights["bi"], 
                                        weights["Wxg"], weights["Whg"], weights["bg"], 
                                        weights["Wxo"], weights["Who"], weights["bo"], propagation_func)
    # LSTM最下層以外 
    if params.get("return_hs"):
        z = hs
    # LSTM最下層
    else:
        z = h_n
    # 重み減衰対応
    weight_decay_r = 0
    if weight_decay is not None:
        weight_decay_r += weight_decay["func"](weights["Wxf"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whf"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxi"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whi"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxg"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whg"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxo"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Who"], **weight_decay["params"])

    return {"u":u, "z":z, "h_0":h_0, "hs":hs, "c_0":c_0, "cs":cs,
             "fs":fs, "Is":Is, "gs":gs, "os":os}, weight_decay_r

def LSTM_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # LSTM最下層以外
    if return_hs:
        dh_n = np.zeros_like(us["h_0"]) 
        dhs = dz
    # LSTM最下層
    else:
        dh_n = dz
        dhs = np.zeros_like(us["hs"]) 
    # LSTMの勾配計算
    du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0 = back_func(dh_n, dhs,
                                        us["u"], us["hs"], us["h_0"], us["cs"], us["c_0"],
                                        us["fs"], us["Is"], us["gs"], us["os"],
                                        weights["Wxf"], weights["Whf"], weights["bf"],
                                        weights["Wxi"], weights["Whi"], weights["bi"],
                                        weights["Wxg"], weights["Whg"], weights["bg"],
                                        weights["Wxo"], weights["Who"], weights["bo"],
                                        propagation_func)

    # 重み減衰対応
    if weight_decay is not None:
        dWxf += weight_decay["back_func"](weights["Wxf"], **weight_decay["params"])
        dWhf += weight_decay["back_func"](weights["Whf"], **weight_decay["params"])
        dWxi += weight_decay["back_func"](weights["Wxi"], **weight_decay["params"])
        dWhi += weight_decay["back_func"](weights["Whi"], **weight_decay["params"])
        dWxg += weight_decay["back_func"](weights["Wxg"], **weight_decay["params"])
        dWhg += weight_decay["back_func"](weights["Whg"], **weight_decay["params"])
        dWxo += weight_decay["back_func"](weights["Wxo"], **weight_decay["params"])
        dWho += weight_decay["back_func"](weights["Who"], **weight_decay["params"])

    return {"du":du, "dWxf":dWxf, "dWhf":dWhf, "dbf":dbf, "dWxi":dWxi, "dWhi":dWhi, "dbi":dbi, "dWxg":dWxg, "dWhg":dWhg, "dbg":dbg, "dWxo":dWxo, "dWho":dWho, "dbo":dbo}

def LSTM_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wxf"], optimizer_stats["sWxf"] = func(weights["Wxf"], du["dWxf"], **params, **optimizer_stats["sWxf"])
    weights["Whf"], optimizer_stats["sWhf"] = func(weights["Whf"], du["dWhf"], **params, **optimizer_stats["sWhf"])
    weights["bf"],  optimizer_stats["sbf"]  = func(weights["bf"],  du["dbf"],  **params, **optimizer_stats["sbf"])
    weights["Wxi"], optimizer_stats["sWxi"] = func(weights["Wxi"], du["dWxi"], **params, **optimizer_stats["sWxi"])
    weights["Whi"], optimizer_stats["sWhi"] = func(weights["Whi"], du["dWhi"], **params, **optimizer_stats["sWhi"])
    weights["bi"],  optimizer_stats["sbi"]  = func(weights["bi"],  du["dbi"],  **params, **optimizer_stats["sbi"])
    weights["Wxg"], optimizer_stats["sWxg"] = func(weights["Wxg"], du["dWxg"], **params, **optimizer_stats["sWxg"])
    weights["Whg"], optimizer_stats["sWhg"] = func(weights["Whg"], du["dWhg"], **params, **optimizer_stats["sWhg"])
    weights["bg"],  optimizer_stats["sbg"]  = func(weights["bg"],  du["dbg"],  **params, **optimizer_stats["sbg"])
    weights["Wxo"], optimizer_stats["sWxo"] = func(weights["Wxo"], du["dWxo"], **params, **optimizer_stats["sWxo"])
    weights["Who"], optimizer_stats["sWho"] = func(weights["Who"], du["dWho"], **params, **optimizer_stats["sWho"])
    weights["bo"],  optimizer_stats["sbo"]  = func(weights["bo"],  du["dbo"],  **params, **optimizer_stats["sbo"])
    return weights, optimizer_stats

実行例

「ディープラーニングを実装から学ぶ(8)実装変更」の「参考」-「プログラム全体」のプログラムを事前に実行しておきます。
MNISTのデータを読み込みます。(MNISTのデータファイルが、c:\mnistに格納されている例)

x_train, t_train, x_test, t_test = load_mnist('c:\\mnist\\')
data_normalizer = create_data_normalizer(min_max)
nx_train, data_normalizer_stats = train_data_normalize(data_normalizer, x_train)
nx_test                         = test_data_normalize(data_normalizer, x_test)

時系列データとなるようにreshapeします。

nx_train = nx_train.reshape((nx_train.shape[0], 28, 28))
nx_test  = nx_test.reshape((nx_test.shape[0], 28, 28))

モデルを定義します。出力次元数は100にします。
LSTMの1層です。

model = create_model((28,28))
model = add_layer(model, "LSTM1", LSTM, 100)
model = add_layer(model, "affine", affine, 10)
model = set_output(model, softmax)
model = set_error(model, cross_entropy_error)
optimizer = create_optimizer(SGD, lr=1.0)

30エポック実行します。

epoch = 30
batch_size = 100
np.random.seed(10)
model, optimizer, learn_info = learn(model, nx_train, t_train, nx_test, t_test, batch_size=batch_size, epoch=epoch, optimizer=optimizer)

結果です。

input - 0 (28, 28)
LSTM1 LSTM (28, 28) 100
affine affine 100 10
output softmax 10
error cross_entropy_error
0 0.1496 2.302723152253055 0.1467 2.303348297247249 
1 0.8047 0.5757895576715077 0.962 0.1242534791640316 
2 0.9653333333333334 0.11828992821395005 0.9777 0.07833614483687754 
3 0.97785 0.07495387615148981 0.977 0.07639111708284944 
4 0.9834166666666667 0.05717556976747084 0.9822 0.06077722622599866 
5 0.9858166666666667 0.04674912358019147 0.9837 0.0538999396528086 
6 0.98755 0.039783345196157664 0.9838 0.05324526767631925 
7 0.99085 0.031220831213503143 0.9843 0.05004490986776939 
8 0.99185 0.026974026814119555 0.9856 0.05290982319592889 
9 0.9918 0.02558985887151185 0.9847 0.05532029465065351 
10 0.9933333333333333 0.02217085763261192 0.9883 0.04032651088784058 
11 0.9948666666666667 0.016906421923649634 0.9855 0.049887262906119965 
12 0.9950833333333333 0.015078015409687288 0.986 0.0537759347953789 
13 0.99385 0.01898173349973801 0.9873 0.04832162987819463 
14 0.9956666666666667 0.013622460813894649 0.9853 0.05445241757461725 
15 0.9963166666666666 0.012382367801236664 0.9857 0.055004828731155454 
16 0.9962833333333333 0.012106772885966649 0.9857 0.05422881894170901 
17 0.9969 0.010083173271658495 0.9874 0.047076658170164265 
18 0.9976166666666667 0.007982039923283439 0.9872 0.050934940750158315 
19 0.9975833333333334 0.007557775232857585 0.9867 0.054406590726153455 
20 0.9976166666666667 0.00812110538679415 0.9874 0.05195093051821152 
21 0.9980166666666667 0.0069773298472639664 0.9879 0.05065196349876068 
22 0.9987 0.004442628666219698 0.9878 0.04847283645632794 
23 0.9992 0.0032568888944515635 0.9888 0.04768979233270545 
24 0.9979333333333333 0.006423923054656333 0.9859 0.05835013131779499 
25 0.9968166666666667 0.01076117260218625 0.9871 0.051060638086830076 
26 0.99855 0.005145495263978198 0.988 0.0503883324544851 
27 0.99925 0.002995247903887565 0.9881 0.054130633886431294 
28 0.9970333333333333 0.00931009006470593 0.9862 0.056396419318201586 
29 0.9987833333333334 0.004270356662851149 0.9893 0.046484103595236234 
30 0.9992833333333333 0.002611983867067573 0.989 0.04604504192029379 
所要時間 = 22 分 34 秒

テストデータの正解率が98.9%になりました。非常に良い結果となりました。

GRU

GRUは、LSTMの改良して、CECを不要、ゲート数を1つ減らして2つにしています。
構造は、下図です。
GRU.png

基本部

基本の構造は、シンプルRNNと同じです。
GRU2.png

LSTMのCECの役割を$ h_t $に兼ねさせています。
下図のように、$ h_{t-1} $から$ h_t $への接続を追加することで、CECと同様に逆伝播時の勾配を留めておくことが可能となります。

GRU3.png

更新ゲート

更新ゲートは、LSTMの忘却ゲート、入力ゲートを兼ねた働きをしています。$ h_{t-1} $と活性化関数後のデータから最終的な$ h_t $を決めます。$ h_{t-1} $方向と活性化関数方向の確率を足して1になるようにします。$ h_{t-1} $には、$ (1-z_t) $を掛けます。
GRU4.png

リセットゲート

リセットゲートも忘却ゲートに近い位置づけです。今までのゲートと異なり、過去からの情報にのみ作用します。値が0であれば、過去からの情報は無視されます。
GRU5.png

順伝播

GRUの式は、以下のようになります。

\begin{align}
z_t &= \mathrm{sigmoid}(h_{t-1} \cdot W_{hz}+u_t \cdot W_{xz} + b_z)\\
r_t &= \mathrm{sigmoid}(h_{t-1} \cdot W_{hr}+u_t \cdot W_{xr} + b_r)\\
n_t &= \mathrm{tanh}(r_t \times (h_{t-1} \cdot W_{hg})+u_t \cdot W_{xg} + b_g)\\
h_t &= (1-z_t) \times h_{t-1} + z_t \times n_t
\end{align}

GRU6.png
実装です。
性能改善のため、transposeを行っています。

def GRU(u, h_0, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh):
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # Wxz, Wxr, Wxn - (d, h)
    # Whz, Whr, Whn - (h, h)
    # bz,  br,  bn  - h
    # 初期設定
    h_t_1 = h_0
    hs = np.zeros((ut, hn, hh))
    zs = np.zeros((ut, hn, hh))
    rs = np.zeros((ut, hn, hh))
    ns = np.zeros((ut, hn, hh))
    for t in range(ut):
        # GRU
        u_t = u[t]
        z_t = sigmoid(np.dot(h_t_1, Whz) + np.dot(u_t, Wxz) + bz)
        r_t = sigmoid(np.dot(h_t_1, Whr) + np.dot(u_t, Wxr) + br)
        n_t = propagation_func(r_t * np.dot(h_t_1, Whn) + np.dot(u_t, Wxn) + bn)

        h_t = (1-z_t) * h_t_1 + z_t * n_t
        # h,z,r,nの値保持
        hs[t] = h_t
        zs[t] = z_t
        rs[t] = r_t
        ns[t] = n_t
        h_t_1 = h_t
    # 後設定
    h_n = h_t
    hs = hs.transpose(1,0,2)
    return h_n, hs, zs, rs, ns

逆伝播

逆伝播を考えます。
$ h_t $から逆向きの矢印にそって、$ u_t, h_{t-1}, W_{xz}, W_{hz}, b_z, W_{xr}, W_{hr}, b_r, W_{xn}, W_{hn}, b_n $まで赤字の勾配を掛けていきます。
GRU7.png

各勾配の求め方は、LSTMを参考にしてください。
$ 1-z_t $の勾配は、$ z_t $で微分した$ -1 $になります。

逆向きの矢印を順番に掛け算していきます。

def GRU_back(dh_n, dhs, u, hs, h_0, zs, rs, ns, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    dhs = dhs.transpose(1,0,2)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # hs - (n, t, h)
    hs = hs.transpose(1,0,2)
    # h_0 - (n, h)
    # zs, rs, ns - (n, t, h)
    # Wxz, Wxr, Wxn - (d, h)
    # Whz, Whr, Whn - (h, h)
    # bz,  br,  bn  - h
    # 初期設定
    dh_t  = dh_n
    du  = np.zeros_like(u)
    dWxz = np.zeros_like(Wxz)
    dWhz = np.zeros_like(Whz)
    dbz  = np.zeros_like(bz)
    dWxr = np.zeros_like(Wxr)
    dWhr = np.zeros_like(Whr)
    dbr  = np.zeros_like(br)
    dWxn = np.zeros_like(Wxn)
    dWhn = np.zeros_like(Whn)
    dbn  = np.zeros_like(bn)
    for t in reversed(range(ut)):
        # GRU勾配計算
        u_t = u[t]
        h_t = hs[t]
        if t == 0:
            h_t_1 = h_0
        else:
            h_t_1 = hs[t-1]
        z_t = zs[t]
        r_t = rs[t]
        n_t = ns[t]
        # 両方向の勾配を加える
        dh_t = dh_t + dhs[t]
        # 各勾配
        # n
        dn = dh_t * z_t
        dpn = propagation_back_func(dn, None, n_t)
        dbn  += np.sum(dpn, axis=0)
        dWhn += np.dot(h_t_1.T, dpn * r_t)
        dWxn += np.dot(u_t.T, dpn)
        # r
        dr = dpn * r_t * np.dot(h_t_1, Whn)
        dpr = sigmoid_back(dr, None, r_t)
        dbr  += np.sum(dpr, axis=0)
        dWhr += np.dot(h_t_1.T, dpr) 
        dWxr += np.dot(u_t.T, dpr)
        # z
        dz = dh_t * n_t - dh_t * h_t_1
        dpz = sigmoid_back(dz, None, z_t)
        dbz  += np.sum(dpz, axis=0)
        dWhz += np.dot(h_t_1.T, dpz) 
        dWxz += np.dot(u_t.T, dpz)
        # h
        dh_t_1 = np.dot(dpz, Whz.T) + np.dot(dpr, Whr.T) + np.dot(dpn * r_t, Whn.T) + dh_t * (1 - z_t)
        # u
        du_t = np.dot(dpz, Wxz.T) + np.dot(dpr, Wxr.T) + np.dot(dpn, Wxn.T)
        dh_t = dh_t_1
        # uの勾配の値保持
        du[t] = du_t
    # 後設定
    dh_0 = dh_t_1
    du = du.transpose(1,0,2)

    return du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0

フレームワーク対応

「ディープラーニングを実装から学ぶ(8)実装変更」にGRUを組み込みます。
シンプルRNN、LSTMと同様に対応します。

def GRU_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=glorot_normal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
    t, d = d_prev
    d_next = (t, h)
    if not params.get("return_hs"):
        d_next = h
    Wxz = Wx_init_func(d, h, **Wx_init_params)
    Whz = Wh_init_func(h, h, **Wh_init_params)
    bz = bias_init_func(h, **bias_init_params)
    Wxr = Wx_init_func(d, h, **Wx_init_params)
    Whr = Wh_init_func(h, h, **Wh_init_params)
    br = bias_init_func(h, **bias_init_params)
    Wxn = Wx_init_func(d, h, **Wx_init_params)
    Whn = Wh_init_func(h, h, **Wh_init_params)
    bn = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxz":Wxz, "Whz":Whz, "bz":bz, "Wxr":Wxr, "Whr":Whr, "br":br, "Wxn":Wxn, "Whn":Whn, "bn":bn}

def GRU_init_optimizer():
    sWxz = {}
    sWhz = {}
    sbz = {}
    sWxr = {}
    sWhr = {}
    sbr = {}
    sWxn = {}
    sWhn = {}
    sbn = {}
    return {"sWxz":sWxz, "sWhz":sWhz, "sbz":sbz,
             "sWxr":sWxr, "sWhr":sWhr, "sbr":sbr,
             "sWxn":sWxn, "sWhn":sWhn, "sbn":sbn}

def GRU_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = np.zeros((u.shape[0], weights["Whz"].shape[0]))
    # GRU
    h_n, hs, zs, rs, ns = func(u, h_0,
                               weights["Wxz"], weights["Whz"], weights["bz"], 
                               weights["Wxr"], weights["Whr"], weights["br"], 
                               weights["Wxn"], weights["Whn"], weights["bn"], 
                               propagation_func)
    # GRU最下層以外
    if params.get("return_hs"):
        z = hs
    # GRU最下層
    else:
        z = h_n
    # 重み減衰対応
    weight_decay_r = 0
    if weight_decay is not None:
        weight_decay_r += weight_decay["func"](weights["Wxz"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whz"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxr"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whr"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxn"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whn"], **weight_decay["params"])

    return {"u":u, "z":z, "h_0":h_0, "hs":hs, "zs":zs, "rs":rs, "ns":ns}, weight_decay_r

def GRU_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # GRU最下層以外
    if return_hs:
        dh_n = np.zeros_like(us["h_0"]) 
        dhs = dz
    # GRU最下層
    else:
        dh_n = dz
        dhs = np.zeros_like(us["hs"]) 
    # GRUの勾配計算
    du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0 = back_func(dh_n, dhs,
                                        us["u"], us["hs"], us["h_0"],
                                        us["zs"], us["rs"], us["ns"],
                                        weights["Wxz"], weights["Whz"], weights["bz"],
                                        weights["Wxr"], weights["Whr"], weights["br"],
                                        weights["Wxn"], weights["Whn"], weights["bn"],
                                        propagation_func)

    # 重み減衰対応
    if weight_decay is not None:
        dWxz += weight_decay["back_func"](weights["Wxz"], **weight_decay["params"])
        dWhz += weight_decay["back_func"](weights["Whz"], **weight_decay["params"])
        dWxr += weight_decay["back_func"](weights["Wxr"], **weight_decay["params"])
        dWhr += weight_decay["back_func"](weights["Whr"], **weight_decay["params"])
        dWxn += weight_decay["back_func"](weights["Wxn"], **weight_decay["params"])
        dWhn += weight_decay["back_func"](weights["Whn"], **weight_decay["params"])

    return {"du":du, "dWxz":dWxz, "dWhz":dWhz, "dbz":dbz, "dWxr":dWxr, "dWhr":dWhr, "dbr":dbr, "dWxn":dWxn, "dWhn":dWhn, "dbn":dbn}

def GRU_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wxz"], optimizer_stats["sWxz"] = func(weights["Wxz"], du["dWxz"], **params, **optimizer_stats["sWxz"])
    weights["Whz"], optimizer_stats["sWhz"] = func(weights["Whz"], du["dWhz"], **params, **optimizer_stats["sWhz"])
    weights["bz"],  optimizer_stats["sbz"]  = func(weights["bz"],  du["dbz"],  **params, **optimizer_stats["sbz"])
    weights["Wxr"], optimizer_stats["sWxr"] = func(weights["Wxr"], du["dWxr"], **params, **optimizer_stats["sWxr"])
    weights["Whr"], optimizer_stats["sWhr"] = func(weights["Whr"], du["dWhr"], **params, **optimizer_stats["sWhr"])
    weights["br"],  optimizer_stats["sbr"]  = func(weights["br"],  du["dbr"],  **params, **optimizer_stats["sbr"])
    weights["Wxn"], optimizer_stats["sWxn"] = func(weights["Wxn"], du["dWxn"], **params, **optimizer_stats["sWxn"])
    weights["Whn"], optimizer_stats["sWhn"] = func(weights["Whn"], du["dWhn"], **params, **optimizer_stats["sWhn"])
    weights["bn"],  optimizer_stats["sbn"]  = func(weights["bn"],  du["dbn"],  **params, **optimizer_stats["sbn"])
    return weights, optimizer_stats

実行例

「ディープラーニングを実装から学ぶ(8)実装変更」の「参考」-「プログラム全体」のプログラムを事前に実行しておきます。
MNISTのデータを読み込みます。(MNISTのデータファイルが、c:\mnistに格納されている例)

x_train, t_train, x_test, t_test = load_mnist('c:\\mnist\\')
data_normalizer = create_data_normalizer(min_max)
nx_train, data_normalizer_stats = train_data_normalize(data_normalizer, x_train)
nx_test                         = test_data_normalize(data_normalizer, x_test)

時系列データとなるようにreshapeします。

nx_train = nx_train.reshape((nx_train.shape[0], 28, 28))
nx_test  = nx_test.reshape((nx_test.shape[0], 28, 28))

モデルを定義します。出力次元数は100にします。
GRUの1層です。

model = create_model((28,28))
model = add_layer(model, "GRU1", GRU, 100)
model = add_layer(model, "affine", affine, 10)
model = set_output(model, softmax)
model = set_error(model, cross_entropy_error)
optimizer = create_optimizer(SGD, lr=0.5)

30エポック実行します。

epoch = 30
batch_size = 100
np.random.seed(10)
model, optimizer, learn_info = learn(model, nx_train, t_train, nx_test, t_test, batch_size=batch_size, epoch=epoch, optimizer=optimizer)

結果です。

input - 0 (28, 28)
GRU1 GRU (28, 28) 100
affine affine 100 10
output softmax 10
error cross_entropy_error
0 0.10721666666666667 2.307061198415644 0.1101 2.306898267580451 
1 0.82675 0.511526219928285 0.9635 0.11362437468398816 
2 0.9676833333333333 0.10288470755828825 0.9768 0.07542646599764968 
3 0.9792666666666666 0.06695928743992582 0.9774 0.06960452252927776 
4 0.9842 0.051180737471969766 0.9823 0.05597069929362484 
5 0.9872666666666666 0.041436770227346216 0.9836 0.052709572823328596 
6 0.9885 0.03643804579236557 0.9875 0.04349536721163362 
7 0.9912166666666666 0.028428202556455142 0.9867 0.044733059747913924 
8 0.9923666666666666 0.02460958644499428 0.9876 0.03991215243673695 
9 0.9934333333333333 0.021230599283107825 0.9883 0.04347800317454341 
10 0.9944166666666666 0.01776144629697566 0.9867 0.041893483495723575 
11 0.9949 0.016518808143304944 0.9874 0.04536614405376183 
12 0.9959833333333333 0.013041504974422278 0.9882 0.04449759201360954 
13 0.9961833333333333 0.012409084573850067 0.9895 0.03990957338832359 
14 0.99685 0.010596160995199305 0.9872 0.04594857842476447 
15 0.9974 0.009251340895075295 0.9882 0.04591789337108098 
16 0.9976666666666667 0.008386733084813818 0.9895 0.04258350902712381 
17 0.9979333333333333 0.007238187314384925 0.9898 0.04240039691192806 
18 0.9985166666666667 0.006119401627858776 0.9885 0.045598807564551626 
19 0.9987666666666667 0.00526031851158144 0.9883 0.04853607898008629 
20 0.9988833333333333 0.0047030283644943155 0.9897 0.042247764684459545 
21 0.9994666666666666 0.0028767811759262936 0.9897 0.041643354883898524 
22 0.9997166666666667 0.0019103182051926498 0.9905 0.040680397434331175 
23 0.9995166666666667 0.002273298935623554 0.989 0.04508413592309316 
24 0.9994 0.002625169443319339 0.9895 0.04497260231402457 
25 0.9998166666666667 0.0014577561763508475 0.9896 0.04607536126924782 
26 0.9980833333333333 0.00678067231148907 0.9882 0.051510202169538846 
27 0.99735 0.008093889267855566 0.9885 0.0485818785663622 
28 0.9994666666666666 0.002663875703525538 0.9907 0.042034645596668634 
29 0.9989333333333333 0.004130722570596547 0.9885 0.04660574722972972 
30 0.9986833333333334 0.0043773158680422315 0.9887 0.04676857282237599 
所要時間 = 14 分 38 秒

一時的ですが、テストデータの正解率が99%を超えました。LSTMに比べて性能がよく、GRUの精度で問題がない場合は、GRUを利用する方が良いかと思います。

次回、ハイパーパラメータを変更し精度を確認していきます。

参考

RNN全体のプログラムです。
ここで、propagation_funcとしてReLU関数も指定できるように変更しておきます。

ReLU

propagation_funcの逆伝播関数では、uの値は利用できません。逆伝播時に、zのみ利用するように変更します。変更は、勾配計算のuをzに変更するのみです。

def relu_back(dz, u, z):
    return dz * np.where(z > 0, 1, 0)

関数仕様

# 層追加関数
model = add_layer(model, name, func, d=None, **kwargs)
# 引数
#  model : モデル
#  name  : レイヤの名前
#  func  : 中間層の関数
#           affine,sigmoid,tanh,relu,leaky_relu,prelu,rrelu,relun,srelu,elu,maxout,
#           identity,softplus,softsign,step,dropout,batch_normalization,
#           convolution2d,max_pooling2d,average_pooling2d,flatten2d,
#           RNN,LSTM,GRU
#  d     : ノード数
#           affine,maxout,convolution2d,RNN,LSTM,GRUの場合指定
#           convolution2dの場合は、フィルタ数
#           RNN,LSTM,GRUの場合は、出力次元数
#  kwargs: 中間層の関数のパラメータ
#           affine - weight_init_func=he_normal, weight_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#                     weight_init_func - lecun_normal,lecun_uniform,glorot_normal,glorot_uniform,he_normal,he_uniform,normal_w,uniform_w,zeros_w,ones_w
#                     weight_init_params
#                       normal_w - mean=0, var=1
#                       uniform_w - min=0, max=1
#                     bias_init_func - normal_b,uniform_b,zeros_b,ones_b
#                     bias_init_params
#                       normal_b - mean=0, var=1
#                       uniform_b - min=0, max=1
#           leaky_relu - alpha
#           rrelu - min=0.0, max=0.1
#           relun - n
#           srelu - alpha
#           elu - alpha
#           maxout - unit=1, weight_init_func=he_normal, weight_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#           dropout - dropout_ratio=0.9
#           batch_normalization - batch_norm_node=True, use_gamma_beta=True, use_train_stats=True, alpha=0.1
#           convolution2d - padding=0, strides=1, weight_init_func=he_normal, weight_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#           max_pooling2d - pool_size=2, padding=0, strides=None
#           average_pooling2d - pool_size=2, padding=0, strides=None
#           RNN - propagation_func=tanh, return_hs=False, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=glorot_normal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#                  propagation_func - tanh,sigmoid,relu
#                  Wx_init_func, Wh_init_funcは、affineのweight_init_funcに同じ
#           LSTM - propagation_func=tanh, return_hs=False, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=glorot_normal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#                  propagation_func - tanh,sigmoid,relu
#                  Wx_init_func, Wh_init_funcは、affineのweight_init_funcに同じ
#           GRU - propagation_func=tanh, return_hs=False, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=glorot_normal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#                  propagation_func - tanh,sigmoid,relu
#                  Wx_init_func, Wh_init_funcは、affineのweight_init_funcに同じ
# 戻り値
#  モデル

プログラム

「ディープラーニングを実装から学ぶ(8)実装変更」の追加分です。

def RNN(u, h_0, Wx, Wh, b, propagation_func=tanh):
    # u - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # Wx - (d, h)
    # Wh - (h, h)
    # b  - h
    # 初期設定
    h_t_1 = h_0
    hs = np.zeros((ut, hn, hh))
    for t in range(ut):
        # RNN
        u_t = u[t]
        h_t = propagation_func(np.dot(h_t_1, Wh) + np.dot(u_t, Wx) + b)
        h_t_1 = h_t
        # hの値保持
        hs[t] = h_t
    # 後設定
    h_n = h_t
    hs = hs.transpose(1,0,2)
    return h_n, hs

def RNN_back(dh_n, dhs, u, hs, h_0, Wx, Wh, b, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    dhs = dhs.transpose(1,0,2)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # hs - (n, t, h)
    hs = hs.transpose(1,0,2)
    # h_0 - (n, h)
    # Wx - (d, h)
    # Wh - (h, h)
    # b  - h
    # 初期設定
    dh_t  = dh_n
    du  = np.zeros_like(u)
    dWx = np.zeros_like(Wx)
    dWh = np.zeros_like(Wh)
    db  = np.zeros_like(b)
    for t in reversed(range(ut)):
        # RNN勾配計算
        u_t = u[t]
        h_t = hs[t]
        if t == 0:
            h_t_1 = h_0
        else:
            h_t_1 = hs[t-1]
        # 両方向の勾配を加える
        dh_t = dh_t + dhs[t]
        # tanhの勾配
        dp = propagation_back_func(dh_t, None, h_t)
        # 各勾配
        # h =>  (n, h)
        db  += np.sum(dp, axis=0)
        # (h, h) => (h, n) @ (n, h)
        dWh += np.dot(h_t_1.T, dp)
        # (d, h) => (d, n) @ (n, h)
        dWx += np.dot(u_t.T, dp)
        # (n, h) => (n, h) @ (h, h)
        dh_t_1 = np.dot(dp, Wh.T)
        # (n, d) => (n, h) @ (h, d)
        du_t = np.dot(dp, Wx.T)
        dh_t = dh_t_1
        # uの勾配の値保持
        du[t] = du_t
    # 後設定
    dh_0 = dh_t_1
    du = du.transpose(1,0,2)
    return du, dWx, dWh, db, dh_0

def LSTM(u, h_0, c_0, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh):
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # c_0  - (n, h)
    cn, ch = c_0.shape
    # Wxf, Wxi, Wxg, Who - (d, h)
    # Whf, Whi, Whg, Who - (h, h)
    # bf,  bi,  bg,  bo  - h
    # 初期設定
    h_t_1 = h_0
    c_t_1 = c_0
    hs = np.zeros((ut, hn, hh))
    cs = np.zeros((ut, cn, ch))
    fs = np.zeros((ut, hn, hh))
    Is = np.zeros((ut, hn, hh))
    gs = np.zeros((ut, hn, hh))
    os = np.zeros((ut, hn, hh))
    for t in range(ut):
        # LSTM
        u_t = u[t]
        f_t = sigmoid(np.dot(h_t_1, Whf) + np.dot(u_t, Wxf) + bf)
        i_t = sigmoid(np.dot(h_t_1, Whi) + np.dot(u_t, Wxi) + bi)
        g_t = propagation_func(np.dot(h_t_1, Whg) + np.dot(u_t, Wxg) + bg)
        o_t = sigmoid(np.dot(h_t_1, Who) + np.dot(u_t, Wxo) + bo)

        c_t = f_t * c_t_1 + i_t * g_t
        h_t = o_t * propagation_func(c_t)
        # h,c,f,i,g,oの値保持
        hs[t] = h_t
        cs[t] = c_t
        fs[t] = f_t
        Is[t] = i_t
        gs[t] = g_t
        os[t] = o_t
        h_t_1 = h_t
        c_t_1 = c_t
    # 後設定
    h_n = h_t
    c_n = c_t
    hs = hs.transpose(1,0,2)
    cs = cs.transpose(1,0,2)
    return h_n, hs, c_n, cs, fs, Is, gs, os

def LSTM_back(dh_n, dhs, u, hs, h_0, cs, c_0, fs, Is, gs, os, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    dhs = dhs.transpose(1,0,2)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # hs - (n, t, h)
    hs = hs.transpose(1,0,2)
    # h_0 - (n, h)
    # cs - (n, t, h)
    cs = cs.transpose(1,0,2)
    # c_0 - (n, h)
    # fs, Is, gs, os - (n, t, h)
    # Wxf, Wxi, Wxg, Who - (d, h)
    # Whf, Whi, Whg, Who - (h, h)
    # bf,  bi,  bg,  bo  - h
    # 初期設定
    dh_t  = dh_n
    dc_t  = np.zeros_like(c_0)
    du  = np.zeros_like(u)
    dWxf = np.zeros_like(Wxf)
    dWhf = np.zeros_like(Whf)
    dbf  = np.zeros_like(bf)
    dWxi = np.zeros_like(Wxi)
    dWhi = np.zeros_like(Whi)
    dbi  = np.zeros_like(bi)
    dWxg = np.zeros_like(Wxg)
    dWhg = np.zeros_like(Whg)
    dbg  = np.zeros_like(bg)
    dWxo = np.zeros_like(Wxo)
    dWho = np.zeros_like(Who)
    dbo  = np.zeros_like(bo)
    for t in reversed(range(ut)):
        # LSTM勾配計算
        u_t = u[t]
        h_t = hs[t]
        c_t = cs[t]
        if t == 0:
            h_t_1 = h_0
            c_t_1 = c_0
        else:
            h_t_1 = hs[t-1]
            c_t_1 = cs[t-1]
        o_t = os[t]
        g_t = gs[t]
        i_t = Is[t]
        f_t = fs[t]
        # 両方向の勾配を加える
        dh_t = dh_t + dhs[t]        
        # tanhの勾配
        dp = dh_t * o_t
        p_t = propagation_func(c_t)
        dc_t = dc_t + propagation_back_func(dp, None, p_t)
        # 各勾配
        # o
        do = dh_t * p_t
        dpo = sigmoid_back(do, None, o_t)
        dbo  += np.sum(dpo, axis=0)
        dWho += np.dot(h_t_1.T, dpo) 
        dWxo += np.dot(u_t.T, dpo)
        # g
        dg = dc_t * i_t
        dpg = propagation_back_func(dg, None, g_t)
        dbg  += np.sum(dpg, axis=0)
        dWhg += np.dot(h_t_1.T, dpg) 
        dWxg += np.dot(u_t.T, dpg)
        # i
        di = dc_t * g_t
        dpi = sigmoid_back(di, None, i_t)
        dbi  += np.sum(dpi, axis=0)
        dWhi += np.dot(h_t_1.T, dpi) 
        dWxi += np.dot(u_t.T, dpi)
        # f
        df = dc_t * c_t_1
        dpf = sigmoid_back(df, None, f_t)
        dbf  += np.sum(dpf, axis=0)
        dWhf += np.dot(h_t_1.T, dpf) 
        dWxf += np.dot(u_t.T, dpf)
        # c
        dc_t_1 = dc_t * f_t
        # h
        dh_t_1 = np.dot(dpf, Whf.T) + np.dot(dpi, Whi.T) + np.dot(dpg, Whg.T) + np.dot(dpo, Who.T)
        # u
        du_t = np.dot(dpf, Wxf.T) + np.dot(dpi, Wxi.T) + np.dot(dpg, Wxg.T) + np.dot(dpo, Wxo.T)
        dh_t = dh_t_1
        dc_t = dc_t_1
        # uの勾配の値保持
        du[t] = du_t
    # 後設定
    dh_0 = dh_t_1
    dc_0 = dc_t_1
    du = du.transpose(1,0,2)

    return du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0

def GRU(u, h_0, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh):
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # Wxz, Wxr, Wxn - (d, h)
    # Whz, Whr, Whn - (h, h)
    # bz,  br,  bn  - h
    # 初期設定
    h_t_1 = h_0
    hs = np.zeros((ut, hn, hh))
    zs = np.zeros((ut, hn, hh))
    rs = np.zeros((ut, hn, hh))
    ns = np.zeros((ut, hn, hh))
    for t in range(ut):
        # GRU
        u_t = u[t]
        z_t = sigmoid(np.dot(h_t_1, Whz) + np.dot(u_t, Wxz) + bz)
        r_t = sigmoid(np.dot(h_t_1, Whr) + np.dot(u_t, Wxr) + br)
        n_t = propagation_func(r_t * np.dot(h_t_1, Whn) + np.dot(u_t, Wxn) + bn)

        h_t = (1-z_t) * h_t_1 + z_t * n_t
        # h,z,r,nの値保持
        hs[t] = h_t
        zs[t] = z_t
        rs[t] = r_t
        ns[t] = n_t
        h_t_1 = h_t
    # 後設定
    h_n = h_t
    hs = hs.transpose(1,0,2)
    return h_n, hs, zs, rs, ns

def GRU_back(dh_n, dhs, u, hs, h_0, zs, rs, ns, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    dhs = dhs.transpose(1,0,2)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # hs - (n, t, h)
    hs = hs.transpose(1,0,2)
    # h_0 - (n, h)
    # zs, rs, ns - (n, t, h)
    # Wxz, Wxr, Wxn - (d, h)
    # Whz, Whr, Whn - (h, h)
    # bz,  br,  bn  - h
    # 初期設定
    dh_t  = dh_n
    du  = np.zeros_like(u)
    dWxz = np.zeros_like(Wxz)
    dWhz = np.zeros_like(Whz)
    dbz  = np.zeros_like(bz)
    dWxr = np.zeros_like(Wxr)
    dWhr = np.zeros_like(Whr)
    dbr  = np.zeros_like(br)
    dWxn = np.zeros_like(Wxn)
    dWhn = np.zeros_like(Whn)
    dbn  = np.zeros_like(bn)
    for t in reversed(range(ut)):
        # GRU勾配計算
        u_t = u[t]
        h_t = hs[t]
        if t == 0:
            h_t_1 = h_0
        else:
            h_t_1 = hs[t-1]
        z_t = zs[t]
        r_t = rs[t]
        n_t = ns[t]
        # 両方向の勾配を加える
        dh_t = dh_t + dhs[t]
        # 各勾配
        # n
        dn = dh_t * z_t
        dpn = propagation_back_func(dn, None, n_t)
        dbn  += np.sum(dpn, axis=0)
        dWhn += np.dot(h_t_1.T, dpn * r_t)
        dWxn += np.dot(u_t.T, dpn)
        # r
        dr = dpn * r_t * np.dot(h_t_1, Whn)
        dpr = sigmoid_back(dr, None, r_t)
        dbr  += np.sum(dpr, axis=0)
        dWhr += np.dot(h_t_1.T, dpr) 
        dWxr += np.dot(u_t.T, dpr)
        # z
        dz = dh_t * n_t - dh_t * h_t_1
        dpz = sigmoid_back(dz, None, z_t)
        dbz  += np.sum(dpz, axis=0)
        dWhz += np.dot(h_t_1.T, dpz) 
        dWxz += np.dot(u_t.T, dpz)
        # h
        dh_t_1 = np.dot(dpz, Whz.T) + np.dot(dpr, Whr.T) + np.dot(dpn * r_t, Whn.T) + dh_t * (1 - z_t)
        # u
        du_t = np.dot(dpz, Wxz.T) + np.dot(dpr, Wxr.T) + np.dot(dpn, Wxn.T)
        dh_t = dh_t_1
        # uの勾配の値保持
        du[t] = du_t
    # 後設定
    dh_0 = dh_t_1
    du = du.transpose(1,0,2)

    return du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0
def RNN_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=glorot_normal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
    t, d = d_prev
    d_next = (t, h)
    if not params.get("return_hs"):
        d_next = h
    Wx = Wx_init_func(d, h, **Wx_init_params)
    Wh = Wh_init_func(h, h, **Wh_init_params)
    b = bias_init_func(h, **bias_init_params)
    return d_next, {"Wx":Wx, "Wh":Wh, "b":b}

def RNN_init_optimizer():
    sWx = {}
    sWh = {}
    sb = {}
    return {"sWx":sWx, "sWh":sWh, "sb":sb}

def RNN_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = np.zeros((u.shape[0], weights["Wh"].shape[0]))
    # RNN
    h_n, hs = func(u, h_0, weights["Wx"], weights["Wh"], weights["b"], propagation_func)
    # RNN最下層以外    
    if params.get("return_hs"):
        z = hs
    # RNN最下層
    else:
        z = h_n
    # 重み減衰対応
    weight_decay_r = 0
    if weight_decay is not None:
        weight_decay_r += weight_decay["func"](weights["Wx"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wh"], **weight_decay["params"])
    return {"u":u, "z":z, "h_0":h_0, "hs":hs}, weight_decay_r

def RNN_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # RNN最下層以外
    if return_hs:
        dh_n = np.zeros_like(us["h_0"])
        dhs = dz
    # RNN最下層
    else:
        dh_n = dz
        dhs = np.zeros_like(us["hs"])
    # RNNの勾配計算
    du, dWx, dWh, db, dh_0 = back_func(dh_n, dhs, us["u"], us["hs"], us["h_0"], weights["Wx"], weights["Wh"], weights["b"], propagation_func)
    # 重み減衰対応
    if weight_decay is not None:
        dWx += weight_decay["back_func"](weights["Wx"], **weight_decay["params"])
        dWh += weight_decay["back_func"](weights["Wh"], **weight_decay["params"])
    return {"du":du, "dWx":dWx, "dWh":dWh, "db":db}

def RNN_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wx"], optimizer_stats["sWx"] = func(weights["Wx"], du["dWx"], **params, **optimizer_stats["sWx"])
    weights["Wh"], optimizer_stats["sWh"] = func(weights["Wh"], du["dWh"], **params, **optimizer_stats["sWh"])
    weights["b"],  optimizer_stats["sb"]  = func(weights["b"],  du["db"],  **params, **optimizer_stats["sb"])
    return weights, optimizer_stats

def LSTM_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=glorot_normal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
    t, d = d_prev
    d_next = (t, h)
    if not params.get("return_hs"):
        d_next = h
    Wxf = Wx_init_func(d, h, **Wx_init_params)
    Whf = Wh_init_func(h, h, **Wh_init_params)
    bf = bias_init_func(h, **bias_init_params)
    Wxi = Wx_init_func(d, h, **Wx_init_params)
    Whi = Wh_init_func(h, h, **Wh_init_params)
    bi = bias_init_func(h, **bias_init_params)
    Wxg = Wx_init_func(d, h, **Wx_init_params)
    Whg = Wh_init_func(h, h, **Wh_init_params)
    bg = bias_init_func(h, **bias_init_params)
    Wxo = Wx_init_func(d, h, **Wx_init_params)
    Who = Wh_init_func(h, h, **Wh_init_params)
    bo = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxf":Wxf, "Whf":Whf, "bf":bf, "Wxi":Wxi, "Whi":Whi, "bi":bi, "Wxg":Wxg, "Whg":Whg, "bg":bg, "Wxo":Wxo, "Who":Who, "bo":bo}

def LSTM_init_optimizer():
    sWxf = {}
    sWhf = {}
    sbf = {}
    sWxi = {}
    sWhi = {}
    sbi = {}
    sWxg = {}
    sWhg = {}
    sbg = {}
    sWxo = {}
    sWho = {}
    sbo = {}
    return {"sWxf":sWxf, "sWhf":sWhf, "sbf":sbf,
             "sWxi":sWxi, "sWhi":sWhi, "sbi":sbi,
             "sWxg":sWxg, "sWhg":sWhg, "sbg":sbg,
             "sWxo":sWxo, "sWho":sWho, "sbo":sbo}

def LSTM_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = np.zeros((u.shape[0], weights["Whf"].shape[0]))
    c_0 = params.get("c_0")
    if c_0 is None:
        c_0 = np.zeros((u.shape[0], weights["Whf"].shape[0]))
    # LSTM
    h_n, hs, c_n, cs, fs, Is, gs, os = func(u, h_0, c_0,
                                        weights["Wxf"], weights["Whf"], weights["bf"], 
                                        weights["Wxi"], weights["Whi"], weights["bi"], 
                                        weights["Wxg"], weights["Whg"], weights["bg"], 
                                        weights["Wxo"], weights["Who"], weights["bo"], propagation_func)
    # LSTM最下層以外 
    if params.get("return_hs"):
        z = hs
    # LSTM最下層
    else:
        z = h_n
    # 重み減衰対応
    weight_decay_r = 0
    if weight_decay is not None:
        weight_decay_r += weight_decay["func"](weights["Wxf"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whf"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxi"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whi"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxg"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whg"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxo"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Who"], **weight_decay["params"])

    return {"u":u, "z":z, "h_0":h_0, "hs":hs, "c_0":c_0, "cs":cs,
             "fs":fs, "Is":Is, "gs":gs, "os":os}, weight_decay_r

def LSTM_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # LSTM最下層以外
    if return_hs:
        dh_n = np.zeros_like(us["h_0"]) 
        dhs = dz
    # LSTM最下層
    else:
        dh_n = dz
        dhs = np.zeros_like(us["hs"]) 
    # LSTMの勾配計算
    du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0 = back_func(dh_n, dhs,
                                        us["u"], us["hs"], us["h_0"], us["cs"], us["c_0"],
                                        us["fs"], us["Is"], us["gs"], us["os"],
                                        weights["Wxf"], weights["Whf"], weights["bf"],
                                        weights["Wxi"], weights["Whi"], weights["bi"],
                                        weights["Wxg"], weights["Whg"], weights["bg"],
                                        weights["Wxo"], weights["Who"], weights["bo"],
                                        propagation_func)

    # 重み減衰対応
    if weight_decay is not None:
        dWxf += weight_decay["back_func"](weights["Wxf"], **weight_decay["params"])
        dWhf += weight_decay["back_func"](weights["Whf"], **weight_decay["params"])
        dWxi += weight_decay["back_func"](weights["Wxi"], **weight_decay["params"])
        dWhi += weight_decay["back_func"](weights["Whi"], **weight_decay["params"])
        dWxg += weight_decay["back_func"](weights["Wxg"], **weight_decay["params"])
        dWhg += weight_decay["back_func"](weights["Whg"], **weight_decay["params"])
        dWxo += weight_decay["back_func"](weights["Wxo"], **weight_decay["params"])
        dWho += weight_decay["back_func"](weights["Who"], **weight_decay["params"])

    return {"du":du, "dWxf":dWxf, "dWhf":dWhf, "dbf":dbf, "dWxi":dWxi, "dWhi":dWhi, "dbi":dbi, "dWxg":dWxg, "dWhg":dWhg, "dbg":dbg, "dWxo":dWxo, "dWho":dWho, "dbo":dbo}

def LSTM_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wxf"], optimizer_stats["sWxf"] = func(weights["Wxf"], du["dWxf"], **params, **optimizer_stats["sWxf"])
    weights["Whf"], optimizer_stats["sWhf"] = func(weights["Whf"], du["dWhf"], **params, **optimizer_stats["sWhf"])
    weights["bf"],  optimizer_stats["sbf"]  = func(weights["bf"],  du["dbf"],  **params, **optimizer_stats["sbf"])
    weights["Wxi"], optimizer_stats["sWxi"] = func(weights["Wxi"], du["dWxi"], **params, **optimizer_stats["sWxi"])
    weights["Whi"], optimizer_stats["sWhi"] = func(weights["Whi"], du["dWhi"], **params, **optimizer_stats["sWhi"])
    weights["bi"],  optimizer_stats["sbi"]  = func(weights["bi"],  du["dbi"],  **params, **optimizer_stats["sbi"])
    weights["Wxg"], optimizer_stats["sWxg"] = func(weights["Wxg"], du["dWxg"], **params, **optimizer_stats["sWxg"])
    weights["Whg"], optimizer_stats["sWhg"] = func(weights["Whg"], du["dWhg"], **params, **optimizer_stats["sWhg"])
    weights["bg"],  optimizer_stats["sbg"]  = func(weights["bg"],  du["dbg"],  **params, **optimizer_stats["sbg"])
    weights["Wxo"], optimizer_stats["sWxo"] = func(weights["Wxo"], du["dWxo"], **params, **optimizer_stats["sWxo"])
    weights["Who"], optimizer_stats["sWho"] = func(weights["Who"], du["dWho"], **params, **optimizer_stats["sWho"])
    weights["bo"],  optimizer_stats["sbo"]  = func(weights["bo"],  du["dbo"],  **params, **optimizer_stats["sbo"])
    return weights, optimizer_stats

def GRU_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=glorot_normal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
    t, d = d_prev
    d_next = (t, h)
    if not params.get("return_hs"):
        d_next = h
    Wxz = Wx_init_func(d, h, **Wx_init_params)
    Whz = Wh_init_func(h, h, **Wh_init_params)
    bz = bias_init_func(h, **bias_init_params)
    Wxr = Wx_init_func(d, h, **Wx_init_params)
    Whr = Wh_init_func(h, h, **Wh_init_params)
    br = bias_init_func(h, **bias_init_params)
    Wxn = Wx_init_func(d, h, **Wx_init_params)
    Whn = Wh_init_func(h, h, **Wh_init_params)
    bn = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxz":Wxz, "Whz":Whz, "bz":bz, "Wxr":Wxr, "Whr":Whr, "br":br, "Wxn":Wxn, "Whn":Whn, "bn":bn}

def GRU_init_optimizer():
    sWxz = {}
    sWhz = {}
    sbz = {}
    sWxr = {}
    sWhr = {}
    sbr = {}
    sWxn = {}
    sWhn = {}
    sbn = {}
    return {"sWxz":sWxz, "sWhz":sWhz, "sbz":sbz,
             "sWxr":sWxr, "sWhr":sWhr, "sbr":sbr,
             "sWxn":sWxn, "sWhn":sWhn, "sbn":sbn}

def GRU_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = np.zeros((u.shape[0], weights["Whz"].shape[0]))
    # GRU
    h_n, hs, zs, rs, ns = func(u, h_0,
                               weights["Wxz"], weights["Whz"], weights["bz"], 
                               weights["Wxr"], weights["Whr"], weights["br"], 
                               weights["Wxn"], weights["Whn"], weights["bn"], 
                               propagation_func)
    # GRU最下層以外
    if params.get("return_hs"):
        z = hs
    # GRU最下層
    else:
        z = h_n
    # 重み減衰対応
    weight_decay_r = 0
    if weight_decay is not None:
        weight_decay_r += weight_decay["func"](weights["Wxz"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whz"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxr"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whr"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxn"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whn"], **weight_decay["params"])

    return {"u":u, "z":z, "h_0":h_0, "hs":hs, "zs":zs, "rs":rs, "ns":ns}, weight_decay_r

def GRU_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # GRU最下層以外
    if return_hs:
        dh_n = np.zeros_like(us["h_0"]) 
        dhs = dz
    # GRU最下層
    else:
        dh_n = dz
        dhs = np.zeros_like(us["hs"]) 
    # GRUの勾配計算
    du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0 = back_func(dh_n, dhs,
                                        us["u"], us["hs"], us["h_0"],
                                        us["zs"], us["rs"], us["ns"],
                                        weights["Wxz"], weights["Whz"], weights["bz"],
                                        weights["Wxr"], weights["Whr"], weights["br"],
                                        weights["Wxn"], weights["Whn"], weights["bn"],
                                        propagation_func)

    # 重み減衰対応
    if weight_decay is not None:
        dWxz += weight_decay["back_func"](weights["Wxz"], **weight_decay["params"])
        dWhz += weight_decay["back_func"](weights["Whz"], **weight_decay["params"])
        dWxr += weight_decay["back_func"](weights["Wxr"], **weight_decay["params"])
        dWhr += weight_decay["back_func"](weights["Whr"], **weight_decay["params"])
        dWxn += weight_decay["back_func"](weights["Wxn"], **weight_decay["params"])
        dWhn += weight_decay["back_func"](weights["Whn"], **weight_decay["params"])

    return {"du":du, "dWxz":dWxz, "dWhz":dWhz, "dbz":dbz, "dWxr":dWxr, "dWhr":dWhr, "dbr":dbr, "dWxn":dWxn, "dWhn":dWhn, "dbn":dbn}

def GRU_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wxz"], optimizer_stats["sWxz"] = func(weights["Wxz"], du["dWxz"], **params, **optimizer_stats["sWxz"])
    weights["Whz"], optimizer_stats["sWhz"] = func(weights["Whz"], du["dWhz"], **params, **optimizer_stats["sWhz"])
    weights["bz"],  optimizer_stats["sbz"]  = func(weights["bz"],  du["dbz"],  **params, **optimizer_stats["sbz"])
    weights["Wxr"], optimizer_stats["sWxr"] = func(weights["Wxr"], du["dWxr"], **params, **optimizer_stats["sWxr"])
    weights["Whr"], optimizer_stats["sWhr"] = func(weights["Whr"], du["dWhr"], **params, **optimizer_stats["sWhr"])
    weights["br"],  optimizer_stats["sbr"]  = func(weights["br"],  du["dbr"],  **params, **optimizer_stats["sbr"])
    weights["Wxn"], optimizer_stats["sWxn"] = func(weights["Wxn"], du["dWxn"], **params, **optimizer_stats["sWxn"])
    weights["Whn"], optimizer_stats["sWhn"] = func(weights["Whn"], du["dWhn"], **params, **optimizer_stats["sWhn"])
    weights["bn"],  optimizer_stats["sbn"]  = func(weights["bn"],  du["dbn"],  **params, **optimizer_stats["sbn"])
    return weights, optimizer_stats
7
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
13