
More than 3 years have passed since last update.

Learning Deep Learning from Implementation (10-2): Implementing RNNs (Bidirectional RNN, orthogonal weight initialization)


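Before checking the hyperparameters of RNN, LSTM, and GRU as planned, this article first looks at bidirectional RNNs and at orthogonal weight initialization.
It continues "Learning Deep Learning from Implementation (10-1): Implementing RNNs (RNN, LSTM, GRU)".
As usual, MNIST is used for the experiments.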

orthogonal (weight initialization)

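Because the activation function is $ \mathrm{tanh} $, glorot_normal has been used as the weight initializer so far. In an RNN, the hidden state is multiplied by the same weights once per time step.
Multiplying by the same matrix over and over is a cause of exploding or vanishing gradients.
Using an orthogonal matrix as the initial value mitigates exploding and vanishing gradients.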

Orthogonal matrix

An orthogonal matrix is a matrix whose transpose equals its inverse.
That is, $ UU^T=U^TU=I $ holds ($ I $ is the identity matrix).
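The reason orthogonal matrices help can be written out in a few lines. For a square matrix, $ UU^T=I $ and $ U^TU=I $ are equivalent, and $ U^TU=I $ means that multiplying by $ U $ preserves the length of any vector. The same holds for every power $ U^k $, which is what repeated multiplication over time steps produces:

```latex
\|Ux\|^2 = (Ux)^T(Ux) = x^T U^T U x = x^T x = \|x\|^2

(U^k)^T U^k = (U^T)^{k-1}\,(U^T U)\,U^{k-1} = (U^T)^{k-1} U^{k-1} = \cdots = I
\quad\Rightarrow\quad \|U^k x\| = \|x\| \ \text{for every } k
```

This covers only the linear part of the recurrence: in a tanh RNN the derivative of tanh (at most 1) still enters at every step, which is why orthogonal initialization mitigates, rather than eliminates, vanishing gradients.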
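An orthogonal matrix can be obtained by singular value decomposition (SVD).
SVD decomposes a matrix A into orthogonal matrices U and V and a diagonal matrix D: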

$ A=UDV^T $

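numpy's linalg.svd performs singular value decomposition.
linalg.svd returns U, the singular values (as a 1-D array rather than the matrix D), and $ V^T $; the first return value (U) is taken. The original matrix is filled with random numbers.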
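A quick check of what np.linalg.svd actually returns, since the return values differ slightly from the $ A=UDV^T $ notation: the middle value is a 1-D array of singular values rather than the matrix $ D $, and the third value is already $ V^T $. For a non-square input, the default full_matrices=True also makes U square ($ d_1 \times d_1 $), which matters if orthogonal is ever called with d_1 != d. The 4×3 shape and the seed below are arbitrary choices for the demo.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(0.0, 1.0, (4, 3))

# Returns (U, S, Vh): S is a 1-D array of singular values, Vh is already V^T
U, S, Vh = np.linalg.svd(A)          # full_matrices=True by default
print(U.shape, S.shape, Vh.shape)    # (4, 4) (3,) (3, 3)

# Rebuild A = U D V^T; with full_matrices=True, D must be padded to (4, 3)
D = np.zeros((4, 3))
np.fill_diagonal(D, S)
A_rebuilt = U @ D @ Vh

# The economy form keeps U at (4, 3), so D is simply diag(S)
U_e, S_e, Vh_e = np.linalg.svd(A, full_matrices=False)
A_rebuilt_e = U_e @ np.diag(S_e) @ Vh_e
```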
The following function creates an orthogonal matrix to use as initial weights.

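def orthogonal(d_1, d):
    # Draw a standard-normal random matrix and take an orthogonal factor from its SVD
    rand = np.random.normal(0., 1., (d_1, d))
    u, _, vh = np.linalg.svd(rand, full_matrices=False)
    # u is (d_1, min(d_1, d)) and vh is (min(d_1, d), d): return whichever has shape (d_1, d)
    return u if u.shape == (d_1, d) else vh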

Checking the matrix

Let's generate a $ 3 \times 3 $ matrix.

orth = orthogonal(3,3)
orth

The following matrix was generated.

array([[-0.60393137, -0.69761452,  0.38548786],
       [ 0.78180171, -0.42438299,  0.45682071],
       [-0.15509027,  0.57726342,  0.80169442]])

To check that it really is an orthogonal matrix, multiply it by its transpose.

np.dot(orth, orth.T)

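Here is the result. It is indeed the identity matrix (the off-diagonal entries, around $ 10^{-16} $ to $ 10^{-17} $, are floating-point rounding error).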

array([[ 1.00000000e+00, -3.53711475e-17, -2.99764250e-16],
       [-3.53711475e-17,  1.00000000e+00, -8.39352672e-17],
       [-2.99764250e-16, -8.39352672e-17,  1.00000000e+00]])
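Instead of checking by eye, the test can be automated with np.allclose, which tolerates exactly this kind of rounding error. A small sketch; is_orthogonal is a hypothetical helper, not part of the framework. For a non-square matrix it checks whichever of $ W^TW $ or $ WW^T $ is the smaller identity (orthonormal columns or orthonormal rows).

```python
import numpy as np

def is_orthogonal(W, atol=1e-8):
    """Return True if W has orthonormal columns (tall/square W) or orthonormal rows (wide W)."""
    n_rows, n_cols = W.shape
    if n_rows >= n_cols:
        return np.allclose(W.T @ W, np.eye(n_cols), atol=atol)
    return np.allclose(W @ W.T, np.eye(n_rows), atol=atol)

rng = np.random.default_rng(1)
Q = np.linalg.svd(rng.normal(size=(3, 3)))[0]
print(is_orthogonal(Q))                        # True
print(is_orthogonal(rng.normal(size=(3, 3))))  # False for a generic random matrix
```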

Checking for exploding/vanishing gradients

Generate a $ 3 \times 3 $ orthogonal matrix and multiply by it 28 times, corresponding to the 28 time steps of MNIST.
Repeat this 5 times.

for i in range(5):
    orth = orthogonal(3,3)
    x = orth
    print("init ", i+1)
    print(x)
    for t in range(28):
        x = np.dot(orth, x)
    print("end ", i+1)
    print(x)

Even after 28 multiplications, the values stay in a reasonable range.

init  1
[[-0.80372635  0.51087605 -0.30500757]
 [-0.36638453 -0.02103647  0.93022569]
 [ 0.46881375  0.85939695  0.20408465]]
end  1
[[-0.76245687 -0.40800029  0.50219048]
 [ 0.54918583  0.00232497  0.83569703]
 [-0.34213221  0.91297884  0.22229529]]
init  2
[[-0.89425394 -0.41274405  0.17306717]
 [ 0.44739469 -0.83489191  0.32061424]
 [ 0.01216076  0.36413988  0.93126487]]
end  2
[[-0.82989029 -0.5243296   0.19068451]
 [ 0.55780287 -0.77254528  0.30336405]
 [-0.01175033  0.35812325  0.93360038]]
init  3
[[-0.80110215  0.35867688 -0.47915157]
 [ 0.21217589  0.91874468  0.33300092]
 [ 0.55965769  0.16510335 -0.81211092]]
end  3
[[ 0.86466244 -0.048639    0.49999311]
 [ 0.09153374  0.99389435 -0.06160838]
 [-0.49394375  0.09903669  0.86383523]]
init  4
[[-0.55059322  0.73411648  0.39739161]
 [-0.81460012 -0.36847745 -0.44794086]
 [-0.18241092 -0.57034846  0.80089256]]
end  4
[[ 0.99731877 -0.06860368 -0.02547166]
 [ 0.06846451  0.99763368 -0.0062972 ]
 [ 0.0258434   0.00453641  0.99965571]]
init  5
[[ 0.0411211   0.88763312 -0.45871178]
 [-0.7350467  -0.28409247 -0.61562799]
 [-0.67676835  0.36248988  0.6407696 ]]
end  5
[[-0.1254555  -0.6334718  -0.7635276 ]
 [ 0.81256556 -0.50716522  0.28726407]
 [-0.56920833 -0.58437737  0.57836404]]
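Why the values stay bounded: the loop above computes $ Q^{29} $ (the initial Q plus 28 multiplications), and a product of orthogonal matrices is again orthogonal, so every column keeps unit length and every entry stays within [-1, 1]. A small check with an arbitrary seed (the matrices will differ from the printed runs):

```python
import numpy as np

rng = np.random.default_rng(2)
Q = np.linalg.svd(rng.normal(size=(3, 3)))[0]

# Same loop as above: start from Q and multiply by Q 28 more times, giving Q^29
x = Q
for _ in range(28):
    x = Q @ x

print(np.allclose(x, np.linalg.matrix_power(Q, 29)))   # True
# The product is still orthogonal, so x^T x is the identity
print(np.allclose(x.T @ x, np.eye(3)))                 # True
# Column norms stay 1, and all eigenvalues of Q lie on the unit circle
print(np.linalg.norm(x, axis=0))
print(np.abs(np.linalg.eigvals(Q)))
```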

Next, let's check glorot_normal.
Again, repeat it 5 times.

for i in range(5):
    glorot = glorot_normal(3,3)
    x = glorot
    print("init ", i+1)
    print(x)
    for t in range(28):
        x = np.dot(glorot, x)
    print("end ", i+1)
    print(x)

Run 1 produced very large values.
Runs 2 and 3 produced small values.
Runs 4 and 5 produced moderate values.
Depending on the random initial values, exploding or vanishing gradients can occur.

init  1
[[ 0.36042833  0.04654217 -0.27553851]
 [-0.1003273   0.99703076 -1.2884802 ]
 [-0.48491302 -0.78957295  1.07022627]]
end  1
[[ 4.27113723e+07  1.31897155e+08 -1.79150247e+08]
 [ 2.60558339e+08  8.04631220e+08 -1.09289607e+09]
 [-2.23244480e+08 -6.89402146e+08  9.36385362e+08]]
init  2
[[-0.54241806  0.13369207 -0.32437354]
 [ 0.50866409 -0.82871844 -0.16365881]
 [ 0.83189728 -0.44774997 -0.67062721]]
end  2
[[-9.18237793e-04  8.41680925e-04 -5.67926628e-04]
 [-4.33933433e-04  7.32992609e-04 -8.52003416e-04]
 [-4.22035207e-05  6.39512261e-04 -1.06920879e-03]]
init  3
[[-0.23330056  0.48466714  0.91377187]
 [-1.65115943  0.01144839 -0.56982934]
 [ 0.71582823  0.12634886  0.10495938]]
end  3
[[-3.59039953e-05 -2.56355152e-05 -5.89428986e-05]
 [ 8.69562293e-05 -6.92291595e-05 -1.05409878e-05]
 [-4.59737230e-05  2.65928349e-06 -3.61461589e-05]]
init  4
[[ 0.74781561 -0.14005785 -0.51149799]
 [-0.34702255 -0.63936539  0.45400338]
 [-0.14042456 -0.64410887 -0.76022184]]
end  4
[[ 0.00129885  0.00275162  0.00203506]
 [ 0.00096465  0.00442726 -0.0051897 ]
 [ 0.00216138  0.00916843  0.00627747]]
init  5
[[ 0.03147282 -0.77472248  0.06427335]
 [-0.38427955  0.57295038  0.24094615]
 [-0.31341085  0.73580383  0.05397258]]
end  5
[[ 1.25692346 -2.88591353 -0.60767576]
 [-1.8194483   4.17747828  0.87963559]
 [-1.71327251  3.93369721  0.82830355]]
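Why the glorot_normal results scatter: for a non-orthogonal W, the size of $ W^{29} $ is governed by its spectral radius ρ (the largest eigenvalue magnitude). Its spectral norm is at least ρ^29 and, for large powers, tends to track ρ^29, so ρ above 1 explodes and ρ below 1 vanishes. A sketch, assuming glorot_normal draws from a normal distribution with standard deviation $ \sqrt{2/(d_1+d)} $; glorot_normal_like is a hypothetical stand-in and the framework's glorot_normal may differ. The seed is arbitrary, so the numbers will not match the runs above.

```python
import numpy as np

def glorot_normal_like(d_1, d, rng):
    # Assumption: Glorot normal with std = sqrt(2 / (d_1 + d)); the framework's version may differ
    return rng.normal(0.0, np.sqrt(2.0 / (d_1 + d)), (d_1, d))

rng = np.random.default_rng(3)
for trial in range(5):
    W = glorot_normal_like(3, 3, rng)
    rho = np.max(np.abs(np.linalg.eigvals(W)))   # spectral radius
    W29 = np.linalg.matrix_power(W, 29)          # same power as the loop above
    print(f"trial {trial + 1}: rho = {rho:.3f}, rho**29 = {rho**29:.2e}, "
          f"||W^29||_2 = {np.linalg.norm(W29, 2):.2e}")
```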

Verification with an RNN

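orthogonal is applied only to $ Wh $.
Let's try it by setting the initialization function to orthogonal in the simple RNN example program from "Learning Deep Learning from Implementation (10-1): Implementing RNNs (RNN, LSTM, GRU)".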

model = create_model((28,28))
model = add_layer(model, "RNN1", RNN, 100, Wh_init_func=orthogonal)
model = add_layer(model, "affine", affine, 10)
model = set_output(model, softmax)
model = set_error(model, cross_entropy_error)
optimizer = create_optimizer(SGD, lr=0.01)

epoch = 30
batch_size = 100
np.random.seed(10)
model, optimizer, learn_info = learn(model, nx_train, t_train, nx_test, t_test, batch_size=batch_size, epoch=epoch, optimizer=optimizer)

Comparison of the results:

$ Wh $ initializer   Train acc. (%)   Test acc. (%)   Best test acc. (%)
glorot_normal        97.86            97.35           97.35
orthogonal           98.27            97.34           97.54

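The test accuracy after 30 epochs was about the same, but judging from the accuracy and loss during training, orthogonal looks better, so from now on the default initialization function for $ Wh $ will be orthogonal.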

Bidirectional RNN

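So far, MNIST images have been treated as time series read from the top, one row of pixels per time step.
For some time series, it is better to read them in reverse order (starting from the newest end).
For digits, which captures the features better: reading from the top or reading from the bottom?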

Reading from the top
MNIST_time.png
Reading from the bottom (in reverse)
MNIST_time2.png

A bidirectional RNN reads the sequence from both directions.

Bidirectional simple RNN

Forward propagation

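The RNN is run in both the forward and the reverse direction, and the two results are combined by some operation.
Possible operations include concatenation, sum, average, and multiplication.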

双方向RNN.png

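The RNN is the simple RNN from before, used as-is. The dimensions of $ u $ are (n: batch size, t: time steps, d: number of features). The forward RNN receives $ u $ and the reverse RNN receives $ u[:, ::-1] $, which reverses the second dimension, t. Finally, the results of the two directions are combined.
This time, concatenation and sum are implemented.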

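  • Concatenation
    Concatenate the results of the two directions, $ fh_n $ and $ rh_n $, using numpy's concatenate function. Because the results are concatenated, the final output dimension is doubled.

  • Sum
    Add the results of the two directions, $ fh_n $ and $ rh_n $, element-wise (with the + operator).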
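A toy-sized check of the two operations used here: reversing the time axis with u[:, ::-1], and merging by concatenation or sum. The array contents are placeholders; only the shapes and the time order matter.

```python
import numpy as np

n, t, h = 2, 4, 3
u = np.arange(n * t).reshape(n, t, 1)   # toy input: values encode the time step
print(u[0, :, 0])                       # [0 1 2 3]
print(u[:, ::-1][0, :, 0])              # [3 2 1 0]: time axis reversed, batch axis untouched

fh = np.ones((n, t, h))       # stand-in for forward hidden states
rh = 2 * np.ones((n, t, h))   # stand-in for reverse hidden states (already realigned)
print(np.concatenate([fh, rh], axis=2).shape)   # (2, 4, 6): concat doubles the last axis
print((fh + rh).shape)                          # (2, 4, 3): sum keeps it
```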

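Here is the bidirectional simple RNN program.
A merge parameter is added: "concat" means concatenation and "sum" means sum.
Weights and other variables are dictionaries, where "f" denotes the forward direction and "r" the reverse direction.
For example, Wx["f"] is the forward weight and Wx["r"] is the reverse weight. Likewise, h_n["f"] and hs["f"] belong to the forward direction, and h_n["r"] and hs["r"] to the reverse direction.
The forward and reverse RNNs are called, and finally h_n and hs of the two directions are concatenated or summed.
Note: because the reverse direction runs on time-reversed input, its time dimension must be reversed again before concatenating or summing.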

def BiRNN(u, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat"):
    # Set up the inputs (the reverse direction gets time-reversed input)
    u_dr = {}
    u_dr["f"] = u
    u_dr["r"] = u[:, ::-1]
    # Initialize
    h_n, hs = {}, {}
    # Forward direction
    h_n["f"], hs["f"] = RNN(u_dr["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
    h_n["r"], hs["r"] = RNN(u_dr["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
    # h_n - (n, h)
    # hs - (n, t, h)
    if merge == "concat":
        # Concatenate (the reverse hs is flipped back to the original time order first)
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # Sum (the reverse hs is flipped back to the original time order first)
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    return h_n_merge, hs_merge, h_n, hs
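Why the reverse outputs are flipped back before merging: a sketch that replaces RNN with a hypothetical toy recurrence, toy_rnn (a running sum over time). Like an RNN, its output at step k depends on steps 0..k of whatever sequence it is fed, so without the flip the forward and reverse states at the same index would describe different time steps.

```python
import numpy as np

def toy_rnn(u):
    # Hypothetical stand-in for RNN: hs[:, k] depends on u[:, :k+1] (a running sum)
    hs = np.cumsum(u, axis=1)
    return hs[:, -1], hs          # (h_n, hs), mirroring RNN's return order

u = np.arange(1, 5, dtype=float).reshape(1, 4, 1)   # one sequence: 1, 2, 3, 4

_, hs_f = toy_rnn(u)            # forward: 1, 3, 6, 10
_, hs_r = toy_rnn(u[:, ::-1])   # reverse, in reversed time order: 4, 7, 9, 10
hs_r_aligned = hs_r[:, ::-1]    # back to the original time order: 10, 9, 7, 4

# At time k, the forward state has seen u[0..k] and the reverse state u[k..end]
print(hs_f[0, :, 0])            # [ 1.  3.  6. 10.]
print(hs_r_aligned[0, :, 0])    # [10.  9.  7.  4.]
# So after alignment, the merged state at every k covers the whole sequence
print((hs_f + hs_r_aligned - u)[0, :, 0])   # [10. 10. 10. 10.]
```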

Backpropagation

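Backpropagation is done in the same way as for the simple RNN. A few notes on the concatenation and sum parts:

  • Concatenation
    Concatenation only places the forward and reverse outputs side by side, so the incoming gradient is split back into its forward and reverse parts, and each part is passed through unchanged.

  • Sum
    The gradient of addition is 1, so the incoming gradient is passed unchanged to both directions.

Note: as in forward propagation, the time dimension of the reverse direction is reversed.
Because u branches into both directions, du is the sum of the forward-direction and reverse-direction gradients. Don't forget to reverse the time dimension here as well.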

def BiRNN_back(dh_n_merge, dhs_merge, u, hs, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat"):
    # Set up the inputs (the reverse direction gets time-reversed input)
    u_dr = {}
    u_dr["f"] = u
    u_dr["r"] = u[:, ::-1]
    h = Wh["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h*2) for concat, (n, t, h) for sum
    dh_n, dhs = {}, {}
    if merge == "concat":
        # Concatenate: split the gradient into its forward and reverse halves
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # Sum: both directions receive the same gradient
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    # Initialize
    du = {}
    dWx, dWh, db = {}, {}, {}
    dh_0 = {}
    # Forward direction
    du["f"], dWx["f"], dWh["f"], db["f"], dh_0["f"] = RNN_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
    du["r"], dWx["r"], dWh["r"], db["r"], dh_0["r"] = RNN_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
    # du - (n, t, d)
    du_merge = du["f"] + du["r"][:, ::-1]
    return du_merge, dWx, dWh, db, dh_0
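The bookkeeping above (flip the reverse branch's upstream gradient, then flip its du back and add it to the forward du) can be verified with a numerical gradient check. A sketch under simplifying assumptions: toy_rnn and toy_rnn_back are hypothetical stand-ins for RNN/RNN_back, using a linear running sum over time so that the time ordering matters but the gradient stays easy to write; only the "sum" merge is checked.

```python
import numpy as np

def toy_rnn(u, W):
    # Hypothetical stand-in for RNN: hs[:, k] = sum over j <= k of u[:, j] @ W
    return np.cumsum(u @ W, axis=1)

def toy_rnn_back(dhs, u, W):
    # The gradient of a cumsum over time is a reverse cumsum of the upstream gradient
    dpre = np.cumsum(dhs[:, ::-1], axis=1)[:, ::-1]
    du = dpre @ W.T
    dW = np.einsum("ntd,nth->dh", u, dpre)
    return du, dW

def bi_forward(u, W):
    # "sum" merge; the reverse output is flipped back to the original time order
    return toy_rnn(u, W["f"]) + toy_rnn(u[:, ::-1], W["r"])[:, ::-1]

def bi_backward(dhs_merge, u, W):
    du_f, _ = toy_rnn_back(dhs_merge, u, W["f"])
    # The reverse branch saw u[:, ::-1], so its upstream gradient is flipped as well
    du_r, _ = toy_rnn_back(dhs_merge[:, ::-1], u[:, ::-1], W["r"])
    # u feeds both branches: flip du_r back to the original order and add
    return du_f + du_r[:, ::-1]

rng = np.random.default_rng(4)
n, t, d, h = 2, 5, 3, 4
u = rng.normal(size=(n, t, d))
W = {"f": rng.normal(size=(d, h)), "r": rng.normal(size=(d, h))}
G = rng.normal(size=(n, t, h))   # upstream gradient dL/dhs_merge for L = sum(hs_merge * G)

du = bi_backward(G, u, W)

# Central-difference check of dL/du
eps = 1e-6
du_num = np.zeros_like(u)
for idx in np.ndindex(*u.shape):
    up, um = u.copy(), u.copy()
    up[idx] += eps
    um[idx] -= eps
    du_num[idx] = (np.sum(bi_forward(up, W) * G) - np.sum(bi_forward(um, W) * G)) / (2 * eps)

print(np.max(np.abs(du - du_num)))   # expected to be at floating-point rounding level
```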

Multiple layers

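Consider a bidirectional RNN with two layers.
As in the figure below, one approach is to merge (concatenate or sum) the outputs at every layer. (The figure is simplified by collapsing the time steps.) That approach is also possible, but here the forward and reverse directions are each stacked independently, and the merge is performed only at the end.
多階層双方向RNN.png
The implementation is changed so that it also supports the following structure.
多階層双方向RNN2.png
In this figure, the first layer returns the forward and reverse outputs separately; no merge is applied to the first layer's output.
"dictionary" is added as an output method (merge). With "dictionary", the forward and reverse outputs are each stored in a dictionary.
The second layer's input is therefore a dictionary, so an input_dictionary parameter is added; when it is True, the layer accepts a dictionary as input.
The same changes are made for backpropagation.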

def BiRNN(u, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the inputs (the reverse direction gets time-reversed input)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    # Initialize
    h_n, hs = {}, {}
    # Forward direction
    h_n["f"], hs["f"] = RNN(u_dr["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
    h_n["r"], hs["r"] = RNN(u_dr["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
    # h_n - (n, h)
    # hs - (n, t, h)
    if merge == "concat":
        # Concatenate
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # Sum
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    elif merge == "dictionary":
        # Dictionary: keep the two directions separate
        h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
        hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
    return h_n_merge, hs_merge, h_n, hs

def BiRNN_back(dh_n_merge, dhs_merge, u, hs, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the inputs (the reverse direction gets time-reversed input)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    h = Wh["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h*2) for concat, (n, t, h) for sum
    dh_n, dhs = {}, {}
    if merge == "concat":
        # Concatenate: split the gradient into its forward and reverse halves
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # Sum: both directions receive the same gradient
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    elif merge == "dictionary":
        # Dictionary: each direction receives its own gradient
        dh_n["f"] = dh_n_merge["f"]
        dh_n["r"] = dh_n_merge["r"]
        dhs["f"] = dhs_merge["f"]
        dhs["r"] = dhs_merge["r"][:, ::-1]
    # Initialize
    du = {}
    dWx, dWh, db = {}, {}, {}
    dh_0 = {}
    # Forward direction
    du["f"], dWx["f"], dWh["f"], db["f"], dh_0["f"] = RNN_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
    du["r"], dWx["r"], dWh["r"], db["r"], dh_0["r"] = RNN_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
    # du - (n, t, d)
    if not input_dictionary:
        du_merge = du["f"] + du["r"][:, ::-1]
    else:
        du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
    return du_merge, dWx, dWh, db, dh_0
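With merge="dictionary" the reverse stream is stored in original time order between layers and flipped again on input, so each direction effectively forms its own stack. A sketch that checks this with the hypothetical running-sum stand-in toy_rnn; bi_layer mirrors only the data handling of BiRNN, not its weights.

```python
import numpy as np

def toy_rnn(u):
    # Hypothetical stand-in for RNN: running sum over time, returns hs (n, t, h)
    return np.cumsum(u, axis=1)

def bi_layer(u, merge="dictionary", input_dictionary=False):
    # Mirrors BiRNN's input handling: the reverse input is always flipped in time
    u_f = u["f"] if input_dictionary else u
    u_r = (u["r"] if input_dictionary else u)[:, ::-1]
    hs_f, hs_r = toy_rnn(u_f), toy_rnn(u_r)
    if merge == "dictionary":
        # Keep the streams separate; the reverse one is stored in original time order
        return {"f": hs_f, "r": hs_r[:, ::-1]}
    return hs_f + hs_r[:, ::-1]   # "sum"

u = np.arange(1, 5, dtype=float).reshape(1, 4, 1)
layer1 = bi_layer(u, merge="dictionary")
layer2 = bi_layer(layer1, merge="sum", input_dictionary=True)

# Expected: each direction is a two-layer stack of its own, merged only at the end
f_stack = toy_rnn(toy_rnn(u))
r_stack = toy_rnn(toy_rnn(u[:, ::-1]))[:, ::-1]
print(np.allclose(layer2, f_stack + r_stack))   # True
```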

Framework support

Add the bidirectional RNN to the framework from "Learning Deep Learning from Implementation (8): Implementation changes".

Initialization

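This is basically the same as for the simple RNN. The difference is that the forward and reverse weights and biases are each stored in dictionaries.
One more point: when the merge operation is "concat", the outputs are concatenated, so the output dimension is doubled.
The initialization function for $ Wh $ has been changed to orthogonal.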

def BiRNN_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
    t, d = d_prev
    h_next = h
    # With concat, the output size is doubled
    if merge == "concat":
        h_next = h * 2
    d_next = (t, h_next)
    if not params.get("return_hs"):
        d_next = h_next
    Wx, Wh, b = {}, {}, {}
    for dr in ["f", "r"]:
        Wx[dr] = Wx_init_func(d, h, **Wx_init_params)
        Wh[dr] = Wh_init_func(h, h, **Wh_init_params)
        b[dr]  = bias_init_func(h, **bias_init_params)
    return d_next, {"Wx":Wx, "Wh":Wh, "b":b}

def BiRNN_init_optimizer():
    sWx, sWh, sb = {}, {}, {}
    for dr in ["f", "r"]:
        sWx[dr] = {}
        sWh[dr] = {}
        sb[dr] = {}
    return {"sWx":sWx, "sWh":sWh, "sb":sb}
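The output-shape rule in BiRNN_init_layer, summarized as a tiny hypothetical helper (birnn_output_shape is not part of the framework). With merge="dictionary", h_next stays h: the layer actually outputs a dictionary of two arrays, and (t, h) is the per-direction shape the next layer sees.

```python
def birnn_output_shape(t, h, merge="concat", return_hs=False):
    # Hypothetical helper mirroring the shape logic of BiRNN_init_layer
    h_next = h * 2 if merge == "concat" else h
    return (t, h_next) if return_hs else h_next

print(birnn_output_shape(28, 100, "concat"))                  # 200
print(birnn_output_shape(28, 100, "sum"))                     # 100
print(birnn_output_shape(28, 100, "concat", return_hs=True))  # (28, 200)
```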

Forward propagation

Run forward propagation. Basically, each variable is just adapted to the forward/reverse dictionaries.

def BiRNN_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                h_0[dr] = np.zeros((u.shape[0], weights["Wh"][dr].shape[0]))
            else:
                h_0[dr] = np.zeros((u[dr].shape[0], weights["Wh"][dr].shape[0]))
    # RNN
    h_n_merge, hs_merge, h_n, hs = func(u, h_0, weights["Wx"], weights["Wh"], weights["b"], propagation_func, merge, input_dictionary)
    # Not the last RNN layer: pass the full sequence hs on
    if params.get("return_hs"):
        z = hs_merge
    # Last RNN layer: pass only the final state h_n on
    else:
        z = h_n_merge
    # Weight decay
    weight_decay_r = 0
    if weight_decay is not None:
        for dr in ["f", "r"]:
            weight_decay_r += weight_decay["func"](weights["Wx"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wh"][dr], **weight_decay["params"])
    return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs}, weight_decay_r

Backpropagation

Run backpropagation. Basically, each variable is just adapted to the forward/reverse dictionaries.

def BiRNN_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):
    # Not the last RNN layer: the gradient arrives for hs, so dh_n is zero
    if return_hs:
        if merge != "dictionary":
            dh_n_merge = np.zeros_like(us["h_n_merge"])
        else:
            dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
        dhs_merge = dz
    # Last RNN layer: the gradient arrives for h_n, so dhs is zero
    else:
        dh_n_merge = dz
        if merge != "dictionary":
            dhs_merge = np.zeros_like(us["hs_merge"])
        else:
            dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # Compute the RNN gradients
    du, dWx, dWh, db, dh_0 = back_func(dh_n_merge, dhs_merge, us["u"], us["hs"], us["h_0"], weights["Wx"], weights["Wh"], weights["b"], propagation_func, merge, input_dictionary)
    # Weight decay
    if weight_decay is not None:
        for dr in ["f", "r"]:
            dWx[dr] += weight_decay["back_func"](weights["Wx"][dr], **weight_decay["params"])
            dWh[dr] += weight_decay["back_func"](weights["Wh"][dr], **weight_decay["params"])
    return {"du":du, "dWx":dWx, "dWh":dWh, "db":db}

Weight update

Update the weights and biases of the forward and reverse directions.

def BiRNN_update_weight(func, du, weights, optimizer_stats, **params):
    for dr in ["f", "r"]:
        weights["Wx"][dr], optimizer_stats["sWx"][dr] = func(weights["Wx"][dr], du["dWx"][dr], **params, **optimizer_stats["sWx"][dr])
        weights["Wh"][dr], optimizer_stats["sWh"][dr] = func(weights["Wh"][dr], du["dWh"][dr], **params, **optimizer_stats["sWh"][dr])
        weights["b"][dr],  optimizer_stats["sb"][dr]  = func(weights["b"][dr],  du["db"][dr],  **params, **optimizer_stats["sb"][dr])
    return weights, optimizer_stats
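BiRNN_update_weight expects an optimizer function of the form func(w, dw, **params, **stats) that returns (new_w, new_stats). A minimal sketch of that contract with a hypothetical stateless sgd (the framework's actual SGD implementation may differ), applied per direction exactly as in the loop above:

```python
import numpy as np

def sgd(w, dw, lr=0.01):
    # Hypothetical plain SGD with the (w, dw, **params, **stats) -> (w, stats) shape;
    # plain SGD keeps no state, so the returned stats dictionary is empty
    return w - lr * dw, {}

weights = {"Wx": {"f": np.ones((2, 2)), "r": np.ones((2, 2))}}
grads = {"dWx": {"f": np.ones((2, 2)), "r": np.zeros((2, 2))}}
stats = {"sWx": {"f": {}, "r": {}}}

for dr in ["f", "r"]:
    weights["Wx"][dr], stats["sWx"][dr] = sgd(weights["Wx"][dr], grads["dWx"][dr],
                                              lr=0.5, **stats["sWx"][dr])

print(weights["Wx"]["f"][0, 0], weights["Wx"]["r"][0, 0])   # 0.5 1.0
```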

Example run

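Run the program from "Reference" - "Full program" in "Learning Deep Learning from Implementation (8): Implementation changes" beforehand.
Load the MNIST data (this example assumes the MNIST data files are stored in c:\mnist).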

x_train, t_train, x_test, t_test = load_mnist('c:\\mnist\\')
data_normalizer = create_data_normalizer(min_max)
nx_train, data_normalizer_stats = train_data_normalize(data_normalizer, x_train)
nx_test                         = test_data_normalize(data_normalizer, x_test)

Reshape the data into time-series form.

nx_train = nx_train.reshape((nx_train.shape[0], 28, 28))
nx_test  = nx_test.reshape((nx_test.shape[0], 28, 28))
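What the reshape does: each flattened 784-pixel image becomes 28 time steps of 28 pixels, where time step k is row k of the image. A check on toy data (the arange values stand in for pixel intensities):

```python
import numpy as np

flat = np.arange(2 * 784).reshape(2, 784)        # two flattened 28x28 "images"
seq = flat.reshape((flat.shape[0], 28, 28))      # (n, t, d) = (batch, row = time step, pixels per row)

print(seq.shape)                                   # (2, 28, 28)
# Time step k is image row k: elements 28*k .. 28*k+27 of the flattened vector
print(np.array_equal(seq[0, 1], flat[0, 28:56]))   # True
```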

Define the model. The output dimension (hidden size) is 100.
It has one bidirectional RNN layer, with sum ("sum") as the merge operation.

model = create_model((28,28))
model = add_layer(model, "BiRNN1", BiRNN, 100, merge="sum")
model = add_layer(model, "affine", affine, 10)
model = set_output(model, softmax)
model = set_error(model, cross_entropy_error)
optimizer = create_optimizer(SGD, lr=0.01)

Run 30 epochs.

epoch = 30
batch_size = 100
np.random.seed(10)
model, optimizer, learn_info = learn(model, nx_train, t_train, nx_test, t_test, batch_size=batch_size, epoch=epoch, optimizer=optimizer)

Results:

input - 0 (28, 28)
BiRNN1 BiRNN (28, 28) 100
affine affine 100 10
output softmax 10
error cross_entropy_error
0 0.08276666666666667 2.4694932161948393 0.0807 2.467711539524275 
1 0.7976333333333333 0.7224299245059749 0.9011 0.36007855521236326 
2 0.9129 0.3071644615683738 0.9297 0.24448251744407254 
3 0.933 0.23339284779792896 0.9418 0.20460460055846733 
4 0.9428833333333333 0.19330278050141367 0.9452 0.18235088618707465 
5 0.9502833333333334 0.16817265239467252 0.9556 0.14892486442113187 
6 0.9550666666666666 0.1496513755834529 0.9563 0.14608596831733958 
7 0.9593833333333334 0.1357794529971745 0.9584 0.13615785606170594 
8 0.9623666666666667 0.12634386000411943 0.961 0.13028779104323354 
9 0.96525 0.11460733561373888 0.9657 0.11348966105739382 
10 0.9682333333333333 0.1079471845883365 0.9668 0.10858853202882413 
11 0.9692333333333333 0.10160679157369244 0.9691 0.10127266935092118 
12 0.97065 0.09646661669902329 0.9707 0.09741699781003078 
13 0.9738833333333333 0.08874902808825438 0.9709 0.0974751791352071 
14 0.97465 0.08513319550104065 0.9736 0.08820080097152737 
15 0.9753 0.08127818526156802 0.9723 0.08705366351988436 
16 0.9768333333333333 0.07730725580142364 0.9758 0.08434469601550207 
17 0.9783 0.07244720897536845 0.9762 0.08135833169989594 
18 0.9787333333333333 0.07012015726493016 0.9693 0.10346379775865691 
19 0.9800333333333333 0.06777993291484356 0.9765 0.07907273898044402 
20 0.98015 0.0648363240542574 0.9758 0.07887806162399832 
21 0.98125 0.06262620142131756 0.9773 0.07744372451315322 
22 0.9823 0.05932344435748826 0.9678 0.10134595089422335 
23 0.9833 0.057475974001274625 0.9749 0.08343724319233303 
24 0.9832333333333333 0.05571428411124262 0.9761 0.07660213665776289 
25 0.9844 0.05324665768972804 0.9781 0.0747334290192525 
26 0.9855166666666667 0.05124096084976474 0.9789 0.06781155342892754 
27 0.9848833333333333 0.050314958583714584 0.978 0.07235791581025655 
28 0.9863166666666666 0.04735377729986719 0.9787 0.06785586351774157 
29 0.9861333333333333 0.04710562378222973 0.9779 0.07280146686467145 
30 0.9862833333333333 0.045595901126329566 0.9793 0.06660518314597848 
Elapsed time = 10 min 27 s

As one would expect, the run time was about twice that of the RNN.

The table below compares the bidirectional RNN with the ordinary RNN. The bidirectional RNN is more accurate.

Model                        Train acc. (%)   Test acc. (%)   Best test acc. (%)
Bidirectional RNN (sum)      98.63            97.93           97.93
Bidirectional RNN (concat)   98.65            97.80           97.80
RNN (forward)                98.27            97.34           97.54
RNN (reverse)                98.30            97.46           97.80

Bidirectional LSTM

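The bidirectional LSTM is handled the same way as the bidirectional simple RNN: the weights and other variables are turned into dictionaries that hold the forward and reverse data.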

def BiLSTM(u, h_0, c_0, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the inputs (the reverse direction gets time-reversed input)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    # Initialize
    h_n, hs, c_n, cs = {}, {}, {}, {}
    fs, Is, gs, os = {}, {}, {}, {}
    # Forward direction
    h_n["f"], hs["f"], c_n["f"], cs["f"], fs["f"], Is["f"], gs["f"], os["f"] = LSTM(u_dr["f"], h_0["f"], c_0["f"], Wxf["f"], Whf["f"], bf["f"], Wxi["f"], Whi["f"], bi["f"], Wxg["f"], Whg["f"], bg["f"], Wxo["f"], Who["f"], bo["f"], propagation_func)
    # Reverse direction
    h_n["r"], hs["r"], c_n["r"], cs["r"], fs["r"], Is["r"], gs["r"], os["r"] = LSTM(u_dr["r"], h_0["r"], c_0["r"], Wxf["r"], Whf["r"], bf["r"], Wxi["r"], Whi["r"], bi["r"], Wxg["r"], Whg["r"], bg["r"], Wxo["r"], Who["r"], bo["r"], propagation_func)
    # h_n, c_n - (n, h)
    # hs, cs - (n, t, h)
    if merge == "concat":
        # Concatenate
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # Sum
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    elif merge == "dictionary":
        # Dictionary: keep the two directions separate
        h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
        hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
    return h_n_merge, hs_merge, h_n, hs, c_n, cs, fs, Is, gs, os

def BiLSTM_back(dh_n_merge, dhs_merge, u, hs, h_0, cs, c_0, fs, Is, gs, os, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the inputs (the reverse direction gets time-reversed input)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    h = Whf["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h*2) for concat, (n, t, h) for sum
    dh_n, dhs = {}, {}
    if merge == "concat":
        # Concatenate: split the gradient into its forward and reverse halves
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # Sum: both directions receive the same gradient
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    elif merge == "dictionary":
        # Dictionary: each direction receives its own gradient
        dh_n["f"] = dh_n_merge["f"]
        dh_n["r"] = dh_n_merge["r"]
        dhs["f"] = dhs_merge["f"]
        dhs["r"] = dhs_merge["r"][:, ::-1]
    # Initialize
    du = {}
    dWxf, dWhf, dbf = {}, {}, {}
    dWxi, dWhi, dbi = {}, {}, {}
    dWxg, dWhg, dbg = {}, {}, {}
    dWxo, dWho, dbo = {}, {}, {}
    dh_0, dc_0 = {}, {}
    # Forward direction
    du["f"], dWxf["f"], dWhf["f"], dbf["f"], dWxi["f"], dWhi["f"], dbi["f"], dWxg["f"], dWhg["f"], dbg["f"], dWxo["f"], dWho["f"], dbo["f"], dh_0["f"], dc_0["f"] = \
        LSTM_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], cs["f"], c_0["f"], fs["f"], Is["f"], gs["f"], os["f"], Wxf["f"], Whf["f"], bf["f"], Wxi["f"], Whi["f"], bi["f"], Wxg["f"], Whg["f"], bg["f"], Wxo["f"], Who["f"], bo["f"], propagation_func)
    # Reverse direction
    du["r"], dWxf["r"], dWhf["r"], dbf["r"], dWxi["r"], dWhi["r"], dbi["r"], dWxg["r"], dWhg["r"], dbg["r"], dWxo["r"], dWho["r"], dbo["r"], dh_0["r"], dc_0["r"] = \
        LSTM_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], cs["r"], c_0["r"], fs["r"], Is["r"], gs["r"], os["r"], Wxf["r"], Whf["r"], bf["r"], Wxi["r"], Whi["r"], bi["r"], Wxg["r"], Whg["r"], bg["r"], Wxo["r"], Who["r"], bo["r"], propagation_func)
    # du - (n, t, d)
    if not input_dictionary:
        du_merge = du["f"] + du["r"][:, ::-1]
    else:
        du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
    return du_merge, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0

def BiLSTM_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
    t, d = d_prev
    h_next = h
    # With concat, the output size is doubled
    if merge == "concat":
        h_next = h * 2
    d_next = (t, h_next)
    if not params.get("return_hs"):
        d_next = h_next
    Wxf, Whf, bf = {}, {}, {}
    Wxi, Whi, bi = {}, {}, {}
    Wxg, Whg, bg = {}, {}, {}
    Wxo, Who, bo = {}, {}, {}
    for dr in ["f", "r"]:
        Wxf[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whf[dr] = Wh_init_func(h, h, **Wh_init_params)
        bf[dr]  = bias_init_func(h, **bias_init_params)
        Wxi[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whi[dr] = Wh_init_func(h, h, **Wh_init_params)
        bi[dr]  = bias_init_func(h, **bias_init_params)
        Wxg[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whg[dr] = Wh_init_func(h, h, **Wh_init_params)
        bg[dr]  = bias_init_func(h, **bias_init_params)
        Wxo[dr] = Wx_init_func(d, h, **Wx_init_params)
        Who[dr] = Wh_init_func(h, h, **Wh_init_params)
        bo[dr]  = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxf":Wxf, "Whf":Whf, "bf":bf, "Wxi":Wxi, "Whi":Whi, "bi":bi, "Wxg":Wxg, "Whg":Whg, "bg":bg, "Wxo":Wxo, "Who":Who, "bo":bo}

def BiLSTM_init_optimizer():
    sWxf, sWhf, sbf = {}, {}, {}
    sWxi, sWhi, sbi = {}, {}, {}
    sWxg, sWhg, sbg = {}, {}, {}
    sWxo, sWho, sbo = {}, {}, {}
    for dr in ["f", "r"]:
        sWxf[dr] = {}
        sWhf[dr] = {}
        sbf[dr]  = {}
        sWxi[dr] = {}
        sWhi[dr] = {}
        sbi[dr]  = {}
        sWxg[dr] = {}
        sWhg[dr] = {}
        sbg[dr]  = {}
        sWxo[dr] = {}
        sWho[dr] = {}
        sbo[dr]  = {}
    return {"sWxf":sWxf, "sWhf":sWhf, "sbf":sbf,
             "sWxi":sWxi, "sWhi":sWhi, "sbi":sbi,
             "sWxg":sWxg, "sWhg":sWhg, "sbg":sbg,
             "sWxo":sWxo, "sWho":sWho, "sbo":sbo}

def BiLSTM_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                h_0[dr] = np.zeros((u.shape[0], weights["Whf"][dr].shape[0]))
            else:
                h_0[dr] = np.zeros((u[dr].shape[0], weights["Whf"][dr].shape[0]))
    c_0 = params.get("c_0")
    if c_0 is None:
        c_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                c_0[dr] = np.zeros((u.shape[0], weights["Whf"][dr].shape[0]))
            else:
                c_0[dr] = np.zeros((u[dr].shape[0], weights["Whf"][dr].shape[0]))
    # LSTM
    h_n_merge, hs_merge, h_n, hs, c_n, cs, fs, Is, gs, os = func(u, h_0, c_0,
                                                                 weights["Wxf"], weights["Whf"], weights["bf"],
                                                                 weights["Wxi"], weights["Whi"], weights["bi"],
                                                                 weights["Wxg"], weights["Whg"], weights["bg"],
                                                                 weights["Wxo"], weights["Who"], weights["bo"],
                                                                 propagation_func, merge, input_dictionary)
    # Not the last LSTM layer: pass the full sequence hs on
    if params.get("return_hs"):
        z = hs_merge
    # Last LSTM layer: pass only the final state h_n on
    else:
        z = h_n_merge
    # Weight decay
    weight_decay_r = 0
    if weight_decay is not None:
        for dr in ["f", "r"]:
            weight_decay_r += weight_decay["func"](weights["Wxf"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whf"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxi"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whi"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxg"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whg"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxo"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Who"][dr], **weight_decay["params"])

    return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs, "c_0":c_0, "c_n":c_n, "cs":cs,
             "fs":fs, "Is":Is, "gs":gs, "os":os}, weight_decay_r

def BiLSTM_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):    
    # Not the last LSTM layer: the gradient arrives for hs, so dh_n is zero
    if return_hs:
        if merge != "dictionary":
            dh_n_merge = np.zeros_like(us["h_n_merge"])
        else:
            dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
        dhs_merge = dz
    # Last LSTM layer: the gradient arrives for h_n, so dhs is zero
    else:
        dh_n_merge = dz
        if merge != "dictionary":
            dhs_merge = np.zeros_like(us["hs_merge"])
        else:
            dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # Compute the LSTM gradients
    du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0 = back_func(dh_n_merge, dhs_merge,
                                        us["u"], us["hs"], us["h_0"], us["cs"], us["c_0"],
                                        us["fs"], us["Is"], us["gs"], us["os"],
                                        weights["Wxf"], weights["Whf"], weights["bf"],
                                        weights["Wxi"], weights["Whi"], weights["bi"],
                                        weights["Wxg"], weights["Whg"], weights["bg"],
                                        weights["Wxo"], weights["Who"], weights["bo"],
                                        propagation_func, merge, input_dictionary)

    # Weight decay
    if weight_decay is not None:
        for dr in ["f", "r"]:
            dWxf[dr] += weight_decay["back_func"](weights["Wxf"][dr], **weight_decay["params"])
            dWhf[dr] += weight_decay["back_func"](weights["Whf"][dr], **weight_decay["params"])
            dWxi[dr] += weight_decay["back_func"](weights["Wxi"][dr], **weight_decay["params"])
            dWhi[dr] += weight_decay["back_func"](weights["Whi"][dr], **weight_decay["params"])
            dWxg[dr] += weight_decay["back_func"](weights["Wxg"][dr], **weight_decay["params"])
            dWhg[dr] += weight_decay["back_func"](weights["Whg"][dr], **weight_decay["params"])
            dWxo[dr] += weight_decay["back_func"](weights["Wxo"][dr], **weight_decay["params"])
            dWho[dr] += weight_decay["back_func"](weights["Who"][dr], **weight_decay["params"])

    return {"du":du, "dWxf":dWxf, "dWhf":dWhf, "dbf":dbf, "dWxi":dWxi, "dWhi":dWhi, "dbi":dbi, "dWxg":dWxg, "dWhg":dWhg, "dbg":dbg, "dWxo":dWxo, "dWho":dWho, "dbo":dbo}

def BiLSTM_update_weight(func, du, weights, optimizer_stats, **params):
    for dr in ["f", "r"]:
        weights["Wxf"][dr], optimizer_stats["sWxf"][dr] = func(weights["Wxf"][dr], du["dWxf"][dr], **params, **optimizer_stats["sWxf"][dr])
        weights["Whf"][dr], optimizer_stats["sWhf"][dr] = func(weights["Whf"][dr], du["dWhf"][dr], **params, **optimizer_stats["sWhf"][dr])
        weights["bf"][dr],  optimizer_stats["sbf"][dr]  = func(weights["bf"][dr],  du["dbf"][dr],  **params, **optimizer_stats["sbf"][dr])
        weights["Wxi"][dr], optimizer_stats["sWxi"][dr] = func(weights["Wxi"][dr], du["dWxi"][dr], **params, **optimizer_stats["sWxi"][dr])
        weights["Whi"][dr], optimizer_stats["sWhi"][dr] = func(weights["Whi"][dr], du["dWhi"][dr], **params, **optimizer_stats["sWhi"][dr])
        weights["bi"][dr],  optimizer_stats["sbi"][dr]  = func(weights["bi"][dr],  du["dbi"][dr],  **params, **optimizer_stats["sbi"][dr])
        weights["Wxg"][dr], optimizer_stats["sWxg"][dr] = func(weights["Wxg"][dr], du["dWxg"][dr], **params, **optimizer_stats["sWxg"][dr])
        weights["Whg"][dr], optimizer_stats["sWhg"][dr] = func(weights["Whg"][dr], du["dWhg"][dr], **params, **optimizer_stats["sWhg"][dr])
        weights["bg"][dr],  optimizer_stats["sbg"][dr]  = func(weights["bg"][dr],  du["dbg"][dr],  **params, **optimizer_stats["sbg"][dr])
        weights["Wxo"][dr], optimizer_stats["sWxo"][dr] = func(weights["Wxo"][dr], du["dWxo"][dr], **params, **optimizer_stats["sWxo"][dr])
        weights["Who"][dr], optimizer_stats["sWho"][dr] = func(weights["Who"][dr], du["dWho"][dr], **params, **optimizer_stats["sWho"][dr])
        weights["bo"][dr],  optimizer_stats["sbo"][dr]  = func(weights["bo"][dr],  du["dbo"][dr],  **params, **optimizer_stats["sbo"][dr])
    return weights, optimizer_stats
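BiRNN, BiLSTM and BiGRU repeat the same data handling and differ only in the cell function they call. As a design alternative (not what this article implements), the forward pass could be factored into one hypothetical wrapper, bidirectional, that accepts any cell returning (h_n, hs, ...). The sketch covers forward propagation only, uses a hypothetical toy tanh recurrence (toy_cell) instead of the framework's RNN/LSTM/GRU, and returns the reverse hs already realigned; the article's functions additionally keep the unflipped hs for backpropagation.

```python
import numpy as np

def bidirectional(cell, u, params, merge="concat"):
    """Hypothetical wrapper: run `cell` on u and on the time-reversed u, then merge.

    `cell(u, **p)` must return (h_n, hs, *cache) with hs shaped (n, t, h).
    """
    out = {"f": cell(u, **params["f"]), "r": cell(u[:, ::-1], **params["r"])}
    h_n = {dr: out[dr][0] for dr in out}
    hs = {"f": out["f"][1], "r": out["r"][1][:, ::-1]}   # realign the reverse direction in time
    cache = {dr: out[dr][2:] for dr in out}              # extra outputs (e.g. cell states, gates)
    if merge == "concat":
        return (np.concatenate([h_n["f"], h_n["r"]], axis=1),
                np.concatenate([hs["f"], hs["r"]], axis=2), cache)
    if merge == "sum":
        return h_n["f"] + h_n["r"], hs["f"] + hs["r"], cache
    return h_n, hs, cache   # "dictionary"

def toy_cell(u, W):
    # Hypothetical toy recurrence h_k = tanh(h_{k-1} + u_k @ W), standing in for RNN/LSTM/GRU
    n, t, _ = u.shape
    h = np.zeros((n, W.shape[1]))
    hs = np.empty((n, t, W.shape[1]))
    for k in range(t):
        h = np.tanh(h + u[:, k] @ W)
        hs[:, k] = h
    return h, hs

rng = np.random.default_rng(5)
u = rng.normal(size=(2, 6, 3))
params = {dr: {"W": rng.normal(size=(3, 4))} for dr in ["f", "r"]}
h_n, hs, _ = bidirectional(toy_cell, u, params, merge="concat")
print(h_n.shape, hs.shape)   # (2, 8) (2, 6, 8)
```

After realignment, the forward half of hs at the last time step and the reverse half at time step 0 both equal their direction's final state, which is a quick way to confirm the time order is right.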

Bidirectional GRU

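The bidirectional GRU is handled the same way as the bidirectional simple RNN and bidirectional LSTM: the weights and other variables are turned into dictionaries that hold the forward and reverse data.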
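A quick parameter count makes the cost of the bidirectional variants concrete. Going by the weight lists in this article (one Wx, Wh and b per gate: 4 sets for LSTM f/i/g/o, 3 for GRU z/r/n) and assuming each bias has shape (h,), a bidirectional layer has twice the parameters of a one-directional one. n_params is a hypothetical helper, and the counts cover the recurrent layer only, not the following affine layer.

```python
def n_params(d, h, gates, bidirectional=True):
    # Each gate (or the single simple-RNN transform) has Wx (d, h), Wh (h, h) and b (h,)
    per_direction = gates * (d * h + h * h + h)
    return per_direction * (2 if bidirectional else 1)

d, h = 28, 100   # one MNIST row per time step, hidden size as in the example run
print(n_params(d, h, 1, bidirectional=False))  # RNN:    12900
print(n_params(d, h, 1))                       # BiRNN:  25800
print(n_params(d, h, 4))                       # BiLSTM: 103200
print(n_params(d, h, 3))                       # BiGRU:  77400
```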

def BiGRU(u, h_0, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the inputs (the reverse direction uses the time-reversed sequence)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    # Initialize
    h_n, hs = {}, {}
    zs, rs, ns = {}, {}, {}
    # Forward direction
    h_n["f"], hs["f"], zs["f"], rs["f"], ns["f"] = GRU(u_dr["f"], h_0["f"], Wxz["f"], Whz["f"], bz["f"], Wxr["f"], Whr["f"], br["f"], Wxn["f"], Whn["f"], bn["f"], propagation_func)
    # Reverse direction
    h_n["r"], hs["r"], zs["r"], rs["r"], ns["r"] = GRU(u_dr["r"], h_0["r"], Wxz["r"], Whz["r"], bz["r"], Wxr["r"], Whr["r"], br["r"], Wxn["r"], Whn["r"], bn["r"], propagation_func)
    # h_n - (n, h)
    # hs - (n, t, h)
    if merge == "concat":
        # Concatenate
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # Sum
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    elif merge == "dictionary":
        # Keep as a dictionary
        h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
        hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
    return h_n_merge, hs_merge, h_n, hs, zs, rs, ns

def BiGRU_back(dh_n_merge, dhs_merge, u, hs, h_0, zs, rs, ns, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the inputs (the reverse direction uses the time-reversed sequence)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    h = Whz["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h*2) for concat, (n, t, h) for sum
    dh_n, dhs = {}, {}
    if merge == "concat":
        # Concatenate: split into the forward and reverse halves
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # Sum: both directions receive the same gradient
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    elif merge == "dictionary":
        # Dictionary
        dh_n["f"] = dh_n_merge["f"]
        dh_n["r"] = dh_n_merge["r"]
        dhs["f"] = dhs_merge["f"]
        dhs["r"] = dhs_merge["r"][:, ::-1]
    # Initialize
    du = {}
    dWxz, dWhz, dbz = {}, {}, {}
    dWxr, dWhr, dbr = {}, {}, {}
    dWxn, dWhn, dbn = {}, {}, {}
    dh_0 = {}
    # Forward direction
    du["f"], dWxz["f"], dWhz["f"], dbz["f"], dWxr["f"], dWhr["f"], dbr["f"], dWxn["f"], dWhn["f"], dbn["f"], dh_0["f"] = \
        GRU_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], zs["f"], rs["f"], ns["f"], Wxz["f"], Whz["f"], bz["f"], Wxr["f"], Whr["f"], br["f"], Wxn["f"], Whn["f"], bn["f"], propagation_func)
    # Reverse direction
    du["r"], dWxz["r"], dWhz["r"], dbz["r"], dWxr["r"], dWhr["r"], dbr["r"], dWxn["r"], dWhn["r"], dbn["r"], dh_0["r"] = \
        GRU_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], zs["r"], rs["r"], ns["r"], Wxz["r"], Whz["r"], bz["r"], Wxr["r"], Whr["r"], br["r"], Wxn["r"], Whn["r"], bn["r"], propagation_func)
    # du - (n, t, d)
    if not input_dictionary:
        du_merge = du["f"] + du["r"][:, ::-1]
    else:
        du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
    return du_merge, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0

def BiGRU_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
    t, d = d_prev
    h_next = h
    # With concat, the output dimension is doubled
    if merge == "concat":
        h_next = h * 2
    d_next = (t, h_next)
    if not params.get("return_hs"):
        d_next = h_next
    Wxz, Whz, bz = {}, {}, {}
    Wxr, Whr, br = {}, {}, {}
    Wxn, Whn, bn = {}, {}, {}
    for dr in ["f", "r"]:
        Wxz[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whz[dr] = Wh_init_func(h, h, **Wh_init_params)
        bz[dr]  = bias_init_func(h, **bias_init_params)
        Wxr[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whr[dr] = Wh_init_func(h, h, **Wh_init_params)
        br[dr]  = bias_init_func(h, **bias_init_params)
        Wxn[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whn[dr] = Wh_init_func(h, h, **Wh_init_params)
        bn[dr]  = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxz":Wxz, "Whz":Whz, "bz":bz, "Wxr":Wxr, "Whr":Whr, "br":br, "Wxn":Wxn, "Whn":Whn, "bn":bn}

def BiGRU_init_optimizer():
    sWxz, sWhz, sbz = {}, {}, {}
    sWxr, sWhr, sbr = {}, {}, {}
    sWxn, sWhn, sbn = {}, {}, {}
    for dr in ["f", "r"]:
        sWxz[dr] = {}
        sWhz[dr] = {}
        sbz[dr]  = {}
        sWxr[dr] = {}
        sWhr[dr] = {}
        sbr[dr]  = {}
        sWxn[dr] = {}
        sWhn[dr] = {}
        sbn[dr]  = {}
    return {"sWxz":sWxz, "sWhz":sWhz, "sbz":sbz,
             "sWxr":sWxr, "sWhr":sWhr, "sbr":sbr,
             "sWxn":sWxn, "sWhn":sWhn, "sbn":sbn}

def BiGRU_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                h_0[dr] = np.zeros((u.shape[0], weights["Whz"][dr].shape[0]))
            else:
                h_0[dr] = np.zeros((u[dr].shape[0], weights["Whz"][dr].shape[0]))
    # GRU
    h_n_merge, hs_merge, h_n, hs, zs, rs, ns = func(u, h_0,
                                                    weights["Wxz"], weights["Whz"], weights["bz"],
                                                    weights["Wxr"], weights["Whr"], weights["br"],
                                                    weights["Wxn"], weights["Whn"], weights["bn"],
                                                    propagation_func, merge, input_dictionary)
    # Not the last GRU layer: return the hidden states at every time step
    if params.get("return_hs"):
        z = hs_merge
    # Last GRU layer: return only the final hidden state
    else:
        z = h_n_merge
    # Weight decay
    weight_decay_r = 0
    if weight_decay is not None:
        for dr in ["f", "r"]:
            weight_decay_r += weight_decay["func"](weights["Wxz"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whz"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxr"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whr"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxn"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whn"][dr], **weight_decay["params"])

    return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs, "zs":zs, "rs":rs, "ns":ns}, weight_decay_r

def BiGRU_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):    
    # Not the last GRU layer: gradients arrive for the hidden states at every time step
    if return_hs:
        if merge != "dictionary":
            dh_n_merge = np.zeros_like(us["h_n_merge"])
        else:
            dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
        dhs_merge = dz
    # Last GRU layer: the gradient arrives only for the final hidden state
    else:
        dh_n_merge = dz
        if merge != "dictionary":
            dhs_merge = np.zeros_like(us["hs_merge"])
        else:
            dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # Compute the GRU gradients
    du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0 = back_func(dh_n_merge, dhs_merge,
                                        us["u"], us["hs"], us["h_0"],
                                        us["zs"], us["rs"], us["ns"],
                                        weights["Wxz"], weights["Whz"], weights["bz"],
                                        weights["Wxr"], weights["Whr"], weights["br"],
                                        weights["Wxn"], weights["Whn"], weights["bn"],
                                        propagation_func, merge, input_dictionary)

    # Weight decay
    if weight_decay is not None:
        for dr in ["f", "r"]:
            dWxz[dr] += weight_decay["back_func"](weights["Wxz"][dr], **weight_decay["params"])
            dWhz[dr] += weight_decay["back_func"](weights["Whz"][dr], **weight_decay["params"])
            dWxr[dr] += weight_decay["back_func"](weights["Wxr"][dr], **weight_decay["params"])
            dWhr[dr] += weight_decay["back_func"](weights["Whr"][dr], **weight_decay["params"])
            dWxn[dr] += weight_decay["back_func"](weights["Wxn"][dr], **weight_decay["params"])
            dWhn[dr] += weight_decay["back_func"](weights["Whn"][dr], **weight_decay["params"])

    return {"du":du, "dWxz":dWxz, "dWhz":dWhz, "dbz":dbz, "dWxr":dWxr, "dWhr":dWhr, "dbr":dbr, "dWxn":dWxn, "dWhn":dWhn, "dbn":dbn}

def BiGRU_update_weight(func, du, weights, optimizer_stats, **params):
    for dr in ["f", "r"]:
        weights["Wxz"][dr], optimizer_stats["sWxz"][dr] = func(weights["Wxz"][dr], du["dWxz"][dr], **params, **optimizer_stats["sWxz"][dr])
        weights["Whz"][dr], optimizer_stats["sWhz"][dr] = func(weights["Whz"][dr], du["dWhz"][dr], **params, **optimizer_stats["sWhz"][dr])
        weights["bz"][dr],  optimizer_stats["sbz"][dr]  = func(weights["bz"][dr],  du["dbz"][dr],  **params, **optimizer_stats["sbz"][dr])
        weights["Wxr"][dr], optimizer_stats["sWxr"][dr] = func(weights["Wxr"][dr], du["dWxr"][dr], **params, **optimizer_stats["sWxr"][dr])
        weights["Whr"][dr], optimizer_stats["sWhr"][dr] = func(weights["Whr"][dr], du["dWhr"][dr], **params, **optimizer_stats["sWhr"][dr])
        weights["br"][dr],  optimizer_stats["sbr"][dr]  = func(weights["br"][dr],  du["dbr"][dr],  **params, **optimizer_stats["sbr"][dr])
        weights["Wxn"][dr], optimizer_stats["sWxn"][dr] = func(weights["Wxn"][dr], du["dWxn"][dr], **params, **optimizer_stats["sWxn"][dr])
        weights["Whn"][dr], optimizer_stats["sWhn"][dr] = func(weights["Whn"][dr], du["dWhn"][dr], **params, **optimizer_stats["sWhn"][dr])
        weights["bn"][dr],  optimizer_stats["sbn"][dr]  = func(weights["bn"][dr],  du["dbn"][dr],  **params, **optimizer_stats["sbn"][dr])
    return weights, optimizer_stats
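The forward/reverse bookkeeping in BiGRU (flip the time axis for "r", run the cell, flip the states back, then merge) is independent of the cell type. Below is a minimal sketch of that pattern with a plain tanh RNN standing in for GRU. `simple_rnn` and `bidirectional` are hypothetical helpers for illustration, not functions from this series. The sketch shows why hs["r"] is flipped back before merging. After the flip, the reverse direction's final state h_n["r"] sits at time index 0 rather than at the last index.

```python
import numpy as np

def simple_rnn(u, Wx, Wh, b):
    """Minimal tanh RNN: u is (n, t, d); returns h_n (n, h) and hs (n, t, h)."""
    n, t, _ = u.shape
    h = np.zeros((n, Wh.shape[0]))
    hs = np.zeros((n, t, Wh.shape[0]))
    for i in range(t):
        h = np.tanh(h @ Wh + u[:, i] @ Wx + b)
        hs[:, i] = h
    return h, hs

def bidirectional(u, params, merge="concat"):
    """params is a dict keyed by direction ("f", "r"), mirroring BiGRU's layout."""
    h_n, hs = {}, {}
    h_n["f"], hs["f"] = simple_rnn(u, *params["f"])
    # The reverse direction sees the sequence with the time axis flipped
    h_n["r"], hs["r"] = simple_rnn(u[:, ::-1], *params["r"])
    # Flip the reverse states back so that index t lines up with input step t
    hs_r = hs["r"][:, ::-1]
    if merge == "concat":
        return (np.concatenate([h_n["f"], h_n["r"]], axis=1),
                np.concatenate([hs["f"], hs_r], axis=2))
    if merge == "sum":
        return h_n["f"] + h_n["r"], hs["f"] + hs_r
    raise ValueError(merge)

rng = np.random.default_rng(0)
n, t, d, h = 2, 5, 3, 4
params = {dr: (rng.normal(size=(d, h)), rng.normal(size=(h, h)) * 0.5, np.zeros(h))
          for dr in ["f", "r"]}
u = rng.normal(size=(n, t, d))
h_n_merge, hs_merge = bidirectional(u, params)
print(h_n_merge.shape, hs_merge.shape)  # (2, 8) (2, 5, 8)
```

With concat, the forward half of h_n_merge equals hs_merge[:, -1, :h], while the reverse half equals hs_merge[:, 0, h:].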

Reference

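This is the complete program for the bidirectional RNN and the orthogonal weight initialization.
The default initializer for the recurrent weights of the RNN layers has changed to orthogonal, so the unidirectional RNN, LSTM, and GRU code is listed as well.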
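Because the recurrent weights now default to orthogonal, a quick sanity check is cheap. The sketch below re-implements orthogonal locally and calls np.linalg.svd with full_matrices=False, an assumption on my part, so that non-square shapes also come back as (d_1, d). For the square (h, h) matrices used for Wh, this returns the same first SVD factor as before. The sketch checks orthonormality and confirms that 28 repeated products (one per MNIST row) preserve a vector's norm.

```python
import numpy as np

def orthogonal(d_1, d):
    # Orthogonal init via the SVD of a Gaussian matrix (sketch)
    rand = np.random.normal(0., 1., (d_1, d))
    u, _, vt = np.linalg.svd(rand, full_matrices=False)
    # u is (d_1, k) and vt is (k, d) with k = min(d_1, d); one of them is (d_1, d)
    return u if u.shape == (d_1, d) else vt

np.random.seed(0)
square = orthogonal(4, 4)
tall = orthogonal(6, 3)
wide = orthogonal(3, 6)
print(np.allclose(square @ square.T, np.eye(4)))  # True
print(np.allclose(tall.T @ tall, np.eye(3)))      # True (orthonormal columns)
print(np.allclose(wide @ wide.T, np.eye(3)))      # True (orthonormal rows)

# Multiplying by an orthogonal matrix preserves vector norms, so 28 repeated
# products neither blow up nor shrink the vector
x = np.random.normal(size=4)
y = x.copy()
for _ in range(28):
    y = square @ y
print(np.isclose(np.linalg.norm(y), np.linalg.norm(x)))  # True
```

Norm preservation holds only for the linear part. The tanh and the gates inside each cell can still shrink gradients over time.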

Function specification

# Layer-adding function
model = add_layer(model, name, func, d=None, **kwargs)
# Arguments
#  model : model
#  name  : layer name
#  func  : hidden-layer function
#           affine,sigmoid,tanh,relu,leaky_relu,prelu,rrelu,relun,srelu,elu,maxout,
#           identity,softplus,softsign,step,dropout,batch_normalization,
#           convolution2d,max_pooling2d,average_pooling2d,flatten2d,
#           RNN,LSTM,GRU,BiRNN, BiLSTM,BiGRU
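#  d     : number of nodes
#           required for affine,maxout,convolution2d,RNN,LSTM,GRU,BiRNN, BiLSTM,BiGRU
#           for convolution2d, the number of filters
#           for RNN,LSTM,GRU,BiRNN, BiLSTM,BiGRU, the output dimension (with merge="concat", BiRNN, BiLSTM and BiGRU output twice this)
#  kwargs: parameters of the hidden-layer function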
#           affine - weight_init_func=he_normal, weight_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#                     weight_init_func - lecun_normal,lecun_uniform,glorot_normal,glorot_uniform,he_normal,he_uniform,normal_w,uniform_w,zeros_w,ones_w
#                     weight_init_params
#                       normal_w - mean=0, var=1
#                       uniform_w - min=0, max=1
#                     bias_init_func - normal_b,uniform_b,zeros_b,ones_b
#                     bias_init_params
#                       normal_b - mean=0, var=1
#                       uniform_b - min=0, max=1
#           leaky_relu - alpha
#           rrelu - min=0.0, max=0.1
#           relun - n
#           srelu - alpha
#           elu - alpha
#           maxout - unit=1, weight_init_func=he_normal, weight_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#           dropout - dropout_ratio=0.9
#           batch_normalization - batch_norm_node=True, use_gamma_beta=True, use_train_stats=True, alpha=0.1
#           convolution2d - padding=0, strides=1, weight_init_func=he_normal, weight_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#           max_pooling2d - pool_size=2, padding=0, strides=None
#           average_pooling2d - pool_size=2, padding=0, strides=None
#           RNN, LSTM, GRU - propagation_func=tanh, return_hs=False, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#                            propagation_func - tanh,sigmoid,relu
#                            Wx_init_func and Wh_init_func accept the same functions as affine's weight_init_func; Wh_init_func also accepts orthogonal
#           BiRNN, BiLSTM, BiGRU - propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}
#                                  propagation_func - tanh,sigmoid,relu
#                                  merge - "concat", "sum", "dictionary"
#                                  Wx_init_func and Wh_init_func accept the same functions as affine's weight_init_func; Wh_init_func also accepts orthogonal
# Returns
#  model
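Hand-written backward passes such as GRU_back are easy to get subtly wrong, especially through the reset gate r. The gate multiplies h_{t-1} Whn inside the candidate activation, so the derivative of that pre-activation with respect to r is h_{t-1} Whn itself, with no extra r factor. Below is a minimal finite-difference check for a single GRU step using the same gate equations as the listing. It assumes tanh as propagation_func and a scalar loss L = sum(g * h_t) with a random upstream gradient g. gru_step and the locally defined sigmoid are hypothetical helpers, not the series' functions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(h_prev, u_t, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn):
    # Same gate equations as GRU in the listing (bn outside the reset product)
    z = sigmoid(h_prev @ Whz + u_t @ Wxz + bz)
    r = sigmoid(h_prev @ Whr + u_t @ Wxr + br)
    n = np.tanh(r * (h_prev @ Whn) + u_t @ Wxn + bn)
    return (1 - z) * h_prev + z * n, z, r, n

rng = np.random.default_rng(1)
N, D, H = 3, 2, 4
h_prev, u_t = rng.normal(size=(N, H)), rng.normal(size=(N, D))
W = dict(Wxz=rng.normal(size=(D, H)), Whz=rng.normal(size=(H, H)), bz=np.zeros(H),
         Wxr=rng.normal(size=(D, H)), Whr=rng.normal(size=(H, H)), br=np.zeros(H),
         Wxn=rng.normal(size=(D, H)), Whn=rng.normal(size=(H, H)), bn=np.zeros(H))
g = rng.normal(size=(N, H))  # dL/dh_t for the assumed loss L = sum(g * h_t)

# Analytic gradient of L with respect to Whr
h_t, z, r, n = gru_step(h_prev, u_t, **W)
dpn = (g * z) * (1 - n ** 2)           # back through tanh
dr = dpn * (h_prev @ W["Whn"])         # d(pre_n)/dr = h_prev @ Whn
dpr = dr * r * (1 - r)                 # back through sigmoid
dWhr = h_prev.T @ dpr

# Central finite differences on every entry of Whr
eps, num = 1e-6, np.zeros_like(W["Whr"])
for idx in np.ndindex(*W["Whr"].shape):
    for sign in (1, -1):
        Wp = {k: v.copy() for k, v in W.items()}
        Wp["Whr"][idx] += sign * eps
        num[idx] += sign * np.sum(g * gru_step(h_prev, u_t, **Wp)[0]) / (2 * eps)

print(np.allclose(dWhr, num, atol=1e-6))  # True
```

The same loop works for any other weight by changing the key, which makes it a cheap regression test for the gate gradients.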

Program

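def orthogonal(d_1, d):
    # Orthogonal init: take an orthonormal factor from the SVD of a Gaussian matrix.
    # With full_matrices=False, U is (d_1, k) and V^T is (k, d) with k = min(d_1, d),
    # so one of them always has the requested shape (d_1, d).
    rand = np.random.normal(0., 1., (d_1, d))
    u, _, vt = np.linalg.svd(rand, full_matrices=False)
    return u if u.shape == (d_1, d) else vt
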
def RNN(u, h_0, Wx, Wh, b, propagation_func=tanh):
    # u - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # Wx - (d, h)
    # Wh - (h, h)
    # b  - h
    # Initial setup
    h_t_1 = h_0
    hs = np.zeros((ut, hn, hh))
    for t in range(ut):
        # RNN
        u_t = u[t]
        h_t = propagation_func(np.dot(h_t_1, Wh) + np.dot(u_t, Wx) + b)
        h_t_1 = h_t
        # Store h
        hs[t] = h_t
    # Post-processing
    h_n = h_t
    hs = hs.transpose(1,0,2)
    return h_n, hs

def RNN_back(dh_n, dhs, u, hs, h_0, Wx, Wh, b, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    dhs = dhs.transpose(1,0,2)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # hs - (n, t, h)
    hs = hs.transpose(1,0,2)
    # h_0 - (n, h)
    # Wx - (d, h)
    # Wh - (h, h)
    # b  - h
    # Initial setup
    dh_t  = dh_n
    du  = np.zeros_like(u)
    dWx = np.zeros_like(Wx)
    dWh = np.zeros_like(Wh)
    db  = np.zeros_like(b)
    for t in reversed(range(ut)):
        # RNN gradient computation
        u_t = u[t]
        h_t = hs[t]
        if t == 0:
            h_t_1 = h_0
        else:
            h_t_1 = hs[t-1]
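        # Add the gradient from the next time step to the gradient from this step's output
        dh_t = dh_t + dhs[t]
        # Backprop through the activation (tanh by default)
        dp = propagation_back_func(dh_t, None, h_t)
        # Parameter gradients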
        # h =>  (n, h)
        db  += np.sum(dp, axis=0)
        # (h, h) => (h, n) @ (n, h)
        dWh += np.dot(h_t_1.T, dp)
        # (d, h) => (d, n) @ (n, h)
        dWx += np.dot(u_t.T, dp)
        # (n, h) => (n, h) @ (h, h)
        dh_t_1 = np.dot(dp, Wh.T)
        # (n, d) => (n, h) @ (h, d)
        du_t = np.dot(dp, Wx.T)
        dh_t = dh_t_1
        # Store the gradient with respect to u
        du[t] = du_t
    # Post-processing
    dh_0 = dh_t_1
    du = du.transpose(1,0,2)
    return du, dWx, dWh, db, dh_0

def LSTM(u, h_0, c_0, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh):
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # c_0  - (n, h)
    cn, ch = c_0.shape
    # Wxf, Wxi, Wxg, Wxo - (d, h)
    # Whf, Whi, Whg, Who - (h, h)
    # bf,  bi,  bg,  bo  - h
    # Initial setup
    h_t_1 = h_0
    c_t_1 = c_0
    hs = np.zeros((ut, hn, hh))
    cs = np.zeros((ut, cn, ch))
    fs = np.zeros((ut, hn, hh))
    Is = np.zeros((ut, hn, hh))
    gs = np.zeros((ut, hn, hh))
    os = np.zeros((ut, hn, hh))
    for t in range(ut):
        # LSTM
        u_t = u[t]
        f_t = sigmoid(np.dot(h_t_1, Whf) + np.dot(u_t, Wxf) + bf)
        i_t = sigmoid(np.dot(h_t_1, Whi) + np.dot(u_t, Wxi) + bi)
        g_t = propagation_func(np.dot(h_t_1, Whg) + np.dot(u_t, Wxg) + bg)
        o_t = sigmoid(np.dot(h_t_1, Who) + np.dot(u_t, Wxo) + bo)

        c_t = f_t * c_t_1 + i_t * g_t
        h_t = o_t * propagation_func(c_t)
        # Store h, c, f, i, g, o
        hs[t] = h_t
        cs[t] = c_t
        fs[t] = f_t
        Is[t] = i_t
        gs[t] = g_t
        os[t] = o_t
        h_t_1 = h_t
        c_t_1 = c_t
    # Post-processing
    h_n = h_t
    c_n = c_t
    hs = hs.transpose(1,0,2)
    cs = cs.transpose(1,0,2)
    return h_n, hs, c_n, cs, fs, Is, gs, os

def LSTM_back(dh_n, dhs, u, hs, h_0, cs, c_0, fs, Is, gs, os, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    dhs = dhs.transpose(1,0,2)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # hs - (n, t, h)
    hs = hs.transpose(1,0,2)
    # h_0 - (n, h)
    # cs - (n, t, h)
    cs = cs.transpose(1,0,2)
    # c_0 - (n, h)
    # fs, Is, gs, os - (t, n, h) (stored time-major by LSTM, not transposed)
    # Wxf, Wxi, Wxg, Wxo - (d, h)
    # Whf, Whi, Whg, Who - (h, h)
    # bf,  bi,  bg,  bo  - h
    # Initial setup
    dh_t  = dh_n
    dc_t  = np.zeros_like(c_0)
    du  = np.zeros_like(u)
    dWxf = np.zeros_like(Wxf)
    dWhf = np.zeros_like(Whf)
    dbf  = np.zeros_like(bf)
    dWxi = np.zeros_like(Wxi)
    dWhi = np.zeros_like(Whi)
    dbi  = np.zeros_like(bi)
    dWxg = np.zeros_like(Wxg)
    dWhg = np.zeros_like(Whg)
    dbg  = np.zeros_like(bg)
    dWxo = np.zeros_like(Wxo)
    dWho = np.zeros_like(Who)
    dbo  = np.zeros_like(bo)
    for t in reversed(range(ut)):
        # LSTM gradient computation
        u_t = u[t]
        h_t = hs[t]
        c_t = cs[t]
        if t == 0:
            h_t_1 = h_0
            c_t_1 = c_0
        else:
            h_t_1 = hs[t-1]
            c_t_1 = cs[t-1]
        o_t = os[t]
        g_t = gs[t]
        i_t = Is[t]
        f_t = fs[t]
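        # Add the gradient from the next time step to the gradient from this step's output
        dh_t = dh_t + dhs[t]
        # Backprop through the activation applied to c_t (tanh by default)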
        dp = dh_t * o_t
        p_t = propagation_func(c_t)
        dc_t = dc_t + propagation_back_func(dp, None, p_t)
        # Gate gradients
        # o
        do = dh_t * p_t
        dpo = sigmoid_back(do, None, o_t)
        dbo  += np.sum(dpo, axis=0)
        dWho += np.dot(h_t_1.T, dpo) 
        dWxo += np.dot(u_t.T, dpo)
        # g
        dg = dc_t * i_t
        dpg = propagation_back_func(dg, None, g_t)
        dbg  += np.sum(dpg, axis=0)
        dWhg += np.dot(h_t_1.T, dpg) 
        dWxg += np.dot(u_t.T, dpg)
        # i
        di = dc_t * g_t
        dpi = sigmoid_back(di, None, i_t)
        dbi  += np.sum(dpi, axis=0)
        dWhi += np.dot(h_t_1.T, dpi) 
        dWxi += np.dot(u_t.T, dpi)
        # f
        df = dc_t * c_t_1
        dpf = sigmoid_back(df, None, f_t)
        dbf  += np.sum(dpf, axis=0)
        dWhf += np.dot(h_t_1.T, dpf) 
        dWxf += np.dot(u_t.T, dpf)
        # c
        dc_t_1 = dc_t * f_t
        # h
        dh_t_1 = np.dot(dpf, Whf.T) + np.dot(dpi, Whi.T) + np.dot(dpg, Whg.T) + np.dot(dpo, Who.T)
        # u
        du_t = np.dot(dpf, Wxf.T) + np.dot(dpi, Wxi.T) + np.dot(dpg, Wxg.T) + np.dot(dpo, Wxo.T)
        dh_t = dh_t_1
        dc_t = dc_t_1
        # Store the gradient with respect to u
        du[t] = du_t
    # Post-processing
    dh_0 = dh_t_1
    dc_0 = dc_t_1
    du = du.transpose(1,0,2)

    return du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0

def GRU(u, h_0, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh):
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # h_0  - (n, h)
    hn, hh = h_0.shape
    # Wxz, Wxr, Wxn - (d, h)
    # Whz, Whr, Whn - (h, h)
    # bz,  br,  bn  - h
    # Initial setup
    h_t_1 = h_0
    hs = np.zeros((ut, hn, hh))
    zs = np.zeros((ut, hn, hh))
    rs = np.zeros((ut, hn, hh))
    ns = np.zeros((ut, hn, hh))
    for t in range(ut):
        # GRU
        u_t = u[t]
        z_t = sigmoid(np.dot(h_t_1, Whz) + np.dot(u_t, Wxz) + bz)
        r_t = sigmoid(np.dot(h_t_1, Whr) + np.dot(u_t, Wxr) + br)
        n_t = propagation_func(r_t * np.dot(h_t_1, Whn) + np.dot(u_t, Wxn) + bn)

        h_t = (1-z_t) * h_t_1 + z_t * n_t
        # Store h, z, r, n
        hs[t] = h_t
        zs[t] = z_t
        rs[t] = r_t
        ns[t] = n_t
        h_t_1 = h_t
    # Post-processing
    h_n = h_t
    hs = hs.transpose(1,0,2)
    return h_n, hs, zs, rs, ns

def GRU_back(dh_n, dhs, u, hs, h_0, zs, rs, ns, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh):
    propagation_back_func = eval(propagation_func.__name__ + "_back")
    # dh_n - (n, h)
    # dhs - (n, t, h)
    dhs = dhs.transpose(1,0,2)
    # u  - (n, t, d)
    un, ut, ud = u.shape
    u = u.transpose(1,0,2)
    # hs - (n, t, h)
    hs = hs.transpose(1,0,2)
    # h_0 - (n, h)
    # zs, rs, ns - (t, n, h) (stored time-major by GRU, not transposed)
    # Wxz, Wxr, Wxn - (d, h)
    # Whz, Whr, Whn - (h, h)
    # bz,  br,  bn  - h
    # Initial setup
    dh_t  = dh_n
    du  = np.zeros_like(u)
    dWxz = np.zeros_like(Wxz)
    dWhz = np.zeros_like(Whz)
    dbz  = np.zeros_like(bz)
    dWxr = np.zeros_like(Wxr)
    dWhr = np.zeros_like(Whr)
    dbr  = np.zeros_like(br)
    dWxn = np.zeros_like(Wxn)
    dWhn = np.zeros_like(Whn)
    dbn  = np.zeros_like(bn)
    for t in reversed(range(ut)):
        # GRU gradient computation
        u_t = u[t]
        h_t = hs[t]
        if t == 0:
            h_t_1 = h_0
        else:
            h_t_1 = hs[t-1]
        z_t = zs[t]
        r_t = rs[t]
        n_t = ns[t]
        # Add the gradient from the next time step to the gradient from this step's output
        dh_t = dh_t + dhs[t]
        # Gate gradients
        # n
        dn = dh_t * z_t
        dpn = propagation_back_func(dn, None, n_t)
        dbn  += np.sum(dpn, axis=0)
        dWhn += np.dot(h_t_1.T, dpn * r_t)
        dWxn += np.dot(u_t.T, dpn)
        # r
        # The pre-activation of n is r_t * (h_t_1 @ Whn) + ..., so its derivative w.r.t. r_t is h_t_1 @ Whn
        dr = dpn * np.dot(h_t_1, Whn)
        dpr = sigmoid_back(dr, None, r_t)
        dbr  += np.sum(dpr, axis=0)
        dWhr += np.dot(h_t_1.T, dpr) 
        dWxr += np.dot(u_t.T, dpr)
        # z
        dz = dh_t * n_t - dh_t * h_t_1
        dpz = sigmoid_back(dz, None, z_t)
        dbz  += np.sum(dpz, axis=0)
        dWhz += np.dot(h_t_1.T, dpz) 
        dWxz += np.dot(u_t.T, dpz)
        # h
        dh_t_1 = np.dot(dpz, Whz.T) + np.dot(dpr, Whr.T) + np.dot(dpn * r_t, Whn.T) + dh_t * (1 - z_t)
        # u
        du_t = np.dot(dpz, Wxz.T) + np.dot(dpr, Wxr.T) + np.dot(dpn, Wxn.T)
        dh_t = dh_t_1
        # Store the gradient with respect to u
        du[t] = du_t
    # Post-processing
    dh_0 = dh_t_1
    du = du.transpose(1,0,2)

    return du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0

def RNN_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
    t, d = d_prev
    d_next = (t, h)
    if not params.get("return_hs"):
        d_next = h
    Wx = Wx_init_func(d, h, **Wx_init_params)
    Wh = Wh_init_func(h, h, **Wh_init_params)
    b = bias_init_func(h, **bias_init_params)
    return d_next, {"Wx":Wx, "Wh":Wh, "b":b}

def RNN_init_optimizer():
    sWx = {}
    sWh = {}
    sb = {}
    return {"sWx":sWx, "sWh":sWh, "sb":sb}

def RNN_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = np.zeros((u.shape[0], weights["Wh"].shape[0]))
    # RNN
    h_n, hs = func(u, h_0, weights["Wx"], weights["Wh"], weights["b"], propagation_func)
    # Not the last RNN layer: return the hidden states at every time step
    if params.get("return_hs"):
        z = hs
    # Last RNN layer: return only the final hidden state
    else:
        z = h_n
    # Weight decay
    weight_decay_r = 0
    if weight_decay is not None:
        weight_decay_r += weight_decay["func"](weights["Wx"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wh"], **weight_decay["params"])
    return {"u":u, "z":z, "h_0":h_0, "hs":hs}, weight_decay_r

def RNN_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # Not the last RNN layer: gradients arrive for the hidden states at every time step
    if return_hs:
        dh_n = np.zeros_like(us["h_0"])
        dhs = dz
    # Last RNN layer: the gradient arrives only for the final hidden state
    else:
        dh_n = dz
        dhs = np.zeros_like(us["hs"])
    # Compute the RNN gradients
    du, dWx, dWh, db, dh_0 = back_func(dh_n, dhs, us["u"], us["hs"], us["h_0"], weights["Wx"], weights["Wh"], weights["b"], propagation_func)
    # Weight decay
    if weight_decay is not None:
        dWx += weight_decay["back_func"](weights["Wx"], **weight_decay["params"])
        dWh += weight_decay["back_func"](weights["Wh"], **weight_decay["params"])
    return {"du":du, "dWx":dWx, "dWh":dWh, "db":db}

def RNN_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wx"], optimizer_stats["sWx"] = func(weights["Wx"], du["dWx"], **params, **optimizer_stats["sWx"])
    weights["Wh"], optimizer_stats["sWh"] = func(weights["Wh"], du["dWh"], **params, **optimizer_stats["sWh"])
    weights["b"],  optimizer_stats["sb"]  = func(weights["b"],  du["db"],  **params, **optimizer_stats["sb"])
    return weights, optimizer_stats

def LSTM_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
    t, d = d_prev
    d_next = (t, h)
    if not params.get("return_hs"):
        d_next = h
    Wxf = Wx_init_func(d, h, **Wx_init_params)
    Whf = Wh_init_func(h, h, **Wh_init_params)
    bf = bias_init_func(h, **bias_init_params)
    Wxi = Wx_init_func(d, h, **Wx_init_params)
    Whi = Wh_init_func(h, h, **Wh_init_params)
    bi = bias_init_func(h, **bias_init_params)
    Wxg = Wx_init_func(d, h, **Wx_init_params)
    Whg = Wh_init_func(h, h, **Wh_init_params)
    bg = bias_init_func(h, **bias_init_params)
    Wxo = Wx_init_func(d, h, **Wx_init_params)
    Who = Wh_init_func(h, h, **Wh_init_params)
    bo = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxf":Wxf, "Whf":Whf, "bf":bf, "Wxi":Wxi, "Whi":Whi, "bi":bi, "Wxg":Wxg, "Whg":Whg, "bg":bg, "Wxo":Wxo, "Who":Who, "bo":bo}

def LSTM_init_optimizer():
    sWxf = {}
    sWhf = {}
    sbf = {}
    sWxi = {}
    sWhi = {}
    sbi = {}
    sWxg = {}
    sWhg = {}
    sbg = {}
    sWxo = {}
    sWho = {}
    sbo = {}
    return {"sWxf":sWxf, "sWhf":sWhf, "sbf":sbf,
             "sWxi":sWxi, "sWhi":sWhi, "sbi":sbi,
             "sWxg":sWxg, "sWhg":sWhg, "sbg":sbg,
             "sWxo":sWxo, "sWho":sWho, "sbo":sbo}

def LSTM_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = np.zeros((u.shape[0], weights["Whf"].shape[0]))
    c_0 = params.get("c_0")
    if c_0 is None:
        c_0 = np.zeros((u.shape[0], weights["Whf"].shape[0]))
    # LSTM
    h_n, hs, c_n, cs, fs, Is, gs, os = func(u, h_0, c_0,
                                        weights["Wxf"], weights["Whf"], weights["bf"], 
                                        weights["Wxi"], weights["Whi"], weights["bi"], 
                                        weights["Wxg"], weights["Whg"], weights["bg"], 
                                        weights["Wxo"], weights["Who"], weights["bo"], propagation_func)
    # Not the last LSTM layer: return the hidden states at every time step
    if params.get("return_hs"):
        z = hs
    # Last LSTM layer: return only the final hidden state
    else:
        z = h_n
    # Weight decay
    weight_decay_r = 0
    if weight_decay is not None:
        weight_decay_r += weight_decay["func"](weights["Wxf"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whf"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxi"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whi"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxg"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whg"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxo"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Who"], **weight_decay["params"])

    return {"u":u, "z":z, "h_0":h_0, "hs":hs, "c_0":c_0, "cs":cs,
             "fs":fs, "Is":Is, "gs":gs, "os":os}, weight_decay_r

def LSTM_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # Not the last LSTM layer: gradients arrive for the hidden states at every time step
    if return_hs:
        dh_n = np.zeros_like(us["h_0"]) 
        dhs = dz
    # Last LSTM layer: the gradient arrives only for the final hidden state
    else:
        dh_n = dz
        dhs = np.zeros_like(us["hs"]) 
    # Compute the LSTM gradients
    du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0 = back_func(dh_n, dhs,
                                        us["u"], us["hs"], us["h_0"], us["cs"], us["c_0"],
                                        us["fs"], us["Is"], us["gs"], us["os"],
                                        weights["Wxf"], weights["Whf"], weights["bf"],
                                        weights["Wxi"], weights["Whi"], weights["bi"],
                                        weights["Wxg"], weights["Whg"], weights["bg"],
                                        weights["Wxo"], weights["Who"], weights["bo"],
                                        propagation_func)

    # Weight decay
    if weight_decay is not None:
        dWxf += weight_decay["back_func"](weights["Wxf"], **weight_decay["params"])
        dWhf += weight_decay["back_func"](weights["Whf"], **weight_decay["params"])
        dWxi += weight_decay["back_func"](weights["Wxi"], **weight_decay["params"])
        dWhi += weight_decay["back_func"](weights["Whi"], **weight_decay["params"])
        dWxg += weight_decay["back_func"](weights["Wxg"], **weight_decay["params"])
        dWhg += weight_decay["back_func"](weights["Whg"], **weight_decay["params"])
        dWxo += weight_decay["back_func"](weights["Wxo"], **weight_decay["params"])
        dWho += weight_decay["back_func"](weights["Who"], **weight_decay["params"])

    return {"du":du, "dWxf":dWxf, "dWhf":dWhf, "dbf":dbf, "dWxi":dWxi, "dWhi":dWhi, "dbi":dbi, "dWxg":dWxg, "dWhg":dWhg, "dbg":dbg, "dWxo":dWxo, "dWho":dWho, "dbo":dbo}

def LSTM_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wxf"], optimizer_stats["sWxf"] = func(weights["Wxf"], du["dWxf"], **params, **optimizer_stats["sWxf"])
    weights["Whf"], optimizer_stats["sWhf"] = func(weights["Whf"], du["dWhf"], **params, **optimizer_stats["sWhf"])
    weights["bf"],  optimizer_stats["sbf"]  = func(weights["bf"],  du["dbf"],  **params, **optimizer_stats["sbf"])
    weights["Wxi"], optimizer_stats["sWxi"] = func(weights["Wxi"], du["dWxi"], **params, **optimizer_stats["sWxi"])
    weights["Whi"], optimizer_stats["sWhi"] = func(weights["Whi"], du["dWhi"], **params, **optimizer_stats["sWhi"])
    weights["bi"],  optimizer_stats["sbi"]  = func(weights["bi"],  du["dbi"],  **params, **optimizer_stats["sbi"])
    weights["Wxg"], optimizer_stats["sWxg"] = func(weights["Wxg"], du["dWxg"], **params, **optimizer_stats["sWxg"])
    weights["Whg"], optimizer_stats["sWhg"] = func(weights["Whg"], du["dWhg"], **params, **optimizer_stats["sWhg"])
    weights["bg"],  optimizer_stats["sbg"]  = func(weights["bg"],  du["dbg"],  **params, **optimizer_stats["sbg"])
    weights["Wxo"], optimizer_stats["sWxo"] = func(weights["Wxo"], du["dWxo"], **params, **optimizer_stats["sWxo"])
    weights["Who"], optimizer_stats["sWho"] = func(weights["Who"], du["dWho"], **params, **optimizer_stats["sWho"])
    weights["bo"],  optimizer_stats["sbo"]  = func(weights["bo"],  du["dbo"],  **params, **optimizer_stats["sbo"])
    return weights, optimizer_stats

def GRU_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
    t, d = d_prev
    d_next = (t, h)
    if not params.get("return_hs"):
        d_next = h
    Wxz = Wx_init_func(d, h, **Wx_init_params)
    Whz = Wh_init_func(h, h, **Wh_init_params)
    bz = bias_init_func(h, **bias_init_params)
    Wxr = Wx_init_func(d, h, **Wx_init_params)
    Whr = Wh_init_func(h, h, **Wh_init_params)
    br = bias_init_func(h, **bias_init_params)
    Wxn = Wx_init_func(d, h, **Wx_init_params)
    Whn = Wh_init_func(h, h, **Wh_init_params)
    bn = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxz":Wxz, "Whz":Whz, "bz":bz, "Wxr":Wxr, "Whr":Whr, "br":br, "Wxn":Wxn, "Whn":Whn, "bn":bn}
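
# Sketch (added for illustration, not part of the original listing): the recurrent
# weights above default to Wh_init_func=orthogonal. This hypothetical helper checks
# the two properties that motivate that choice, assuming the same SVD-based
# construction as the article's orthogonal(): W @ W.T is the identity, and
# multiplying by W once per time step (28 steps, as for MNIST rows) leaves a
# vector's norm unchanged, so it neither explodes nor vanishes.
import numpy as np

def _sketch_check_orthogonal(h=8, steps=28, seed=0):
    rng = np.random.default_rng(seed)
    # U from the SVD of a Gaussian matrix is orthogonal (same idea as orthogonal())
    W = np.linalg.svd(rng.normal(0., 1., (h, h)))[0]
    identity_error = np.max(np.abs(W @ W.T - np.eye(h)))
    x = rng.normal(0., 1., h)
    y = x.copy()
    for _ in range(steps):
        y = W @ y  # one recurrent multiplication per time step
    norm_ratio = np.linalg.norm(y) / np.linalg.norm(x)
    return identity_error, norm_ratio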

def GRU_init_optimizer():
    sWxz = {}
    sWhz = {}
    sbz = {}
    sWxr = {}
    sWhr = {}
    sbr = {}
    sWxn = {}
    sWhn = {}
    sbn = {}
    return {"sWxz":sWxz, "sWhz":sWhz, "sbz":sbz,
             "sWxr":sWxr, "sWhr":sWhr, "sbr":sbr,
             "sWxn":sWxn, "sWhn":sWhn, "sbn":sbn}

def GRU_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = np.zeros((u.shape[0], weights["Whz"].shape[0]))
    # GRU
    h_n, hs, zs, rs, ns = func(u, h_0,
                               weights["Wxz"], weights["Whz"], weights["bz"], 
                               weights["Wxr"], weights["Whr"], weights["br"], 
                               weights["Wxn"], weights["Whn"], weights["bn"], 
                               propagation_func)
    # Not the last GRU layer (return_hs=True): pass all hidden states to the next layer
    if params.get("return_hs"):
        z = hs
    # Last GRU layer: output only the final hidden state h_n
    else:
        z = h_n
    # Weight decay
    weight_decay_r = 0
    if weight_decay is not None:
        weight_decay_r += weight_decay["func"](weights["Wxz"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whz"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxr"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whr"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Wxn"], **weight_decay["params"])
        weight_decay_r += weight_decay["func"](weights["Whn"], **weight_decay["params"])

    return {"u":u, "z":z, "h_0":h_0, "hs":hs, "zs":zs, "rs":rs, "ns":ns}, weight_decay_r

def GRU_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # Not the last GRU layer (return_hs=True): the gradient arrives for every hidden state
    if return_hs:
        dh_n = np.zeros_like(us["h_0"]) 
        dhs = dz
    # Last GRU layer: the gradient arrives only for the final hidden state h_n
    else:
        dh_n = dz
        dhs = np.zeros_like(us["hs"]) 
    # Compute the GRU gradients
    du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0 = back_func(dh_n, dhs,
                                        us["u"], us["hs"], us["h_0"],
                                        us["zs"], us["rs"], us["ns"],
                                        weights["Wxz"], weights["Whz"], weights["bz"],
                                        weights["Wxr"], weights["Whr"], weights["br"],
                                        weights["Wxn"], weights["Whn"], weights["bn"],
                                        propagation_func)

    # Weight decay
    if weight_decay is not None:
        dWxz += weight_decay["back_func"](weights["Wxz"], **weight_decay["params"])
        dWhz += weight_decay["back_func"](weights["Whz"], **weight_decay["params"])
        dWxr += weight_decay["back_func"](weights["Wxr"], **weight_decay["params"])
        dWhr += weight_decay["back_func"](weights["Whr"], **weight_decay["params"])
        dWxn += weight_decay["back_func"](weights["Wxn"], **weight_decay["params"])
        dWhn += weight_decay["back_func"](weights["Whn"], **weight_decay["params"])

    return {"du":du, "dWxz":dWxz, "dWhz":dWhz, "dbz":dbz, "dWxr":dWxr, "dWhr":dWhr, "dbr":dbr, "dWxn":dWxn, "dWhn":dWhn, "dbn":dbn}

def GRU_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wxz"], optimizer_stats["sWxz"] = func(weights["Wxz"], du["dWxz"], **params, **optimizer_stats["sWxz"])
    weights["Whz"], optimizer_stats["sWhz"] = func(weights["Whz"], du["dWhz"], **params, **optimizer_stats["sWhz"])
    weights["bz"],  optimizer_stats["sbz"]  = func(weights["bz"],  du["dbz"],  **params, **optimizer_stats["sbz"])
    weights["Wxr"], optimizer_stats["sWxr"] = func(weights["Wxr"], du["dWxr"], **params, **optimizer_stats["sWxr"])
    weights["Whr"], optimizer_stats["sWhr"] = func(weights["Whr"], du["dWhr"], **params, **optimizer_stats["sWhr"])
    weights["br"],  optimizer_stats["sbr"]  = func(weights["br"],  du["dbr"],  **params, **optimizer_stats["sbr"])
    weights["Wxn"], optimizer_stats["sWxn"] = func(weights["Wxn"], du["dWxn"], **params, **optimizer_stats["sWxn"])
    weights["Whn"], optimizer_stats["sWhn"] = func(weights["Whn"], du["dWhn"], **params, **optimizer_stats["sWhn"])
    weights["bn"],  optimizer_stats["sbn"]  = func(weights["bn"],  du["dbn"],  **params, **optimizer_stats["sbn"])
    return weights, optimizer_stats

def BiRNN(u, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Input setup (the reverse direction reads the sequence with time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    # Initialize
    h_n, hs = {}, {}
    # Forward direction
    h_n["f"], hs["f"] = RNN(u_dr["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
    h_n["r"], hs["r"] = RNN(u_dr["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
    # h_n - (n, h)
    # hs - (n, t, h)
    if merge == "concat":
        # Concatenate
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # Sum
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    elif merge == "dictionary":
        # Dictionary (keep the two directions separate)
        h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
        hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
    return h_n_merge, hs_merge, h_n, hs
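
# Sketch (added for illustration): how the concat merge lines up the two directions.
# The reverse RNN reads u[:, ::-1], so its states come out in reversed time order;
# flipping hs["r"] back with [:, ::-1] places the backward state for original time t
# next to the forward state for time t. _sketch_rnn is a hypothetical stand-in for
# RNN(), assumed to return (h_n, hs) with hs of shape (n, t, h) holding h_1..h_t.
import numpy as np

def _sketch_rnn(u, h_0, Wx, Wh, b):
    n, t, _ = u.shape
    hs = np.zeros((n, t, Wh.shape[0]))
    h = h_0
    for i in range(t):
        h = np.tanh(u[:, i] @ Wx + h @ Wh + b)  # plain tanh RNN step
        hs[:, i] = h
    return h, hs

def _sketch_birnn_concat(u, params_f, params_r):
    # params_f / params_r are (Wx, Wh, b) tuples for each direction
    h = params_f[1].shape[0]
    h_0 = np.zeros((u.shape[0], h))
    h_n_f, hs_f = _sketch_rnn(u, h_0, *params_f)
    h_n_r, hs_r = _sketch_rnn(u[:, ::-1], h_0, *params_r)
    hs_merge = np.concatenate([hs_f, hs_r[:, ::-1]], axis=2)
    h_n_merge = np.concatenate([h_n_f, h_n_r], axis=1)
    return h_n_merge, hs_merge

# Consequence: hs_merge[:, -1, :h] is the forward h_n, while hs_merge[:, 0, h:]
# is the reverse h_n (the reverse direction finishes at original time 0).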

def BiRNN_back(dh_n_merge, dhs_merge, u, hs, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Input setup (the reverse direction reads the sequence with time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    h = Wh["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h*2) for concat, (n, t, h) for sum
    dh_n, dhs = {}, {}
    if merge == "concat":
        # Concatenate
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # Sum
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    elif merge == "dictionary":
        # Dictionary (keep the two directions separate)
        dh_n["f"] = dh_n_merge["f"]
        dh_n["r"] = dh_n_merge["r"]
        dhs["f"] = dhs_merge["f"]
        dhs["r"] = dhs_merge["r"][:, ::-1]
    # Initialize
    du = {}
    dWx, dWh, db = {}, {}, {}
    dh_0 = {}
    # Forward direction
    du["f"], dWx["f"], dWh["f"], db["f"], dh_0["f"] = RNN_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
    du["r"], dWx["r"], dWh["r"], db["r"], dh_0["r"] = RNN_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
    # du - (n, t, d)
    if not input_dictionary:
        du_merge = du["f"] + du["r"][:, ::-1]
    else:
        du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
    return du_merge, dWx, dWh, db, dh_0
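
# Sketch (added for illustration): why BiRNN_back flips du["r"] before summing.
# The reverse direction consumed u[:, ::-1], so its input gradient is also in
# reversed time order. For a hypothetical linear loss
#   L(u) = sum(Gf * u) + sum(Gr * u[:, ::-1])
# the exact gradient is Gf + Gr[:, ::-1], the same merge as du["f"] + du["r"][:, ::-1].
# The helper compares that (and the unflipped Gf + Gr) with central finite differences.
import numpy as np

def _sketch_flip_gradient_check(n=2, t=4, d=3, eps=1e-6, seed=0):
    rng = np.random.default_rng(seed)
    u = rng.normal(size=(n, t, d))
    Gf = rng.normal(size=(n, t, d))  # plays the role of the forward-direction gradient
    Gr = rng.normal(size=(n, t, d))  # plays the role of the reverse-direction gradient

    def loss(x):
        return np.sum(Gf * x) + np.sum(Gr * x[:, ::-1])

    flipped = Gf + Gr[:, ::-1]       # correct merge
    unflipped = Gf + Gr              # what you would get without the flip
    numeric = np.zeros_like(u)
    for idx in np.ndindex(u.shape):
        up, um = u.copy(), u.copy()
        up[idx] += eps
        um[idx] -= eps
        numeric[idx] = (loss(up) - loss(um)) / (2 * eps)
    return np.max(np.abs(flipped - numeric)), np.max(np.abs(unflipped - numeric))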

def BiLSTM(u, h_0, c_0, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Input setup (the reverse direction reads the sequence with time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    # Initialize
    h_n, hs, c_n, cs = {}, {}, {}, {}
    fs, Is, gs, os = {}, {}, {}, {}
    # Forward direction
    h_n["f"], hs["f"], c_n["f"], cs["f"], fs["f"], Is["f"], gs["f"], os["f"] = LSTM(u_dr["f"], h_0["f"], c_0["f"], Wxf["f"], Whf["f"], bf["f"], Wxi["f"], Whi["f"], bi["f"], Wxg["f"], Whg["f"], bg["f"], Wxo["f"], Who["f"], bo["f"], propagation_func)
    # Reverse direction
    h_n["r"], hs["r"], c_n["r"], cs["r"], fs["r"], Is["r"], gs["r"], os["r"] = LSTM(u_dr["r"], h_0["r"], c_0["r"], Wxf["r"], Whf["r"], bf["r"], Wxi["r"], Whi["r"], bi["r"], Wxg["r"], Whg["r"], bg["r"], Wxo["r"], Who["r"], bo["r"], propagation_func)
    # h_n, c_n - (n, h)
    # hs, cs - (n, t, h)
    if merge == "concat":
        # Concatenate
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # Sum
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    elif merge == "dictionary":
        # Dictionary (keep the two directions separate)
        h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
        hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
    return h_n_merge, hs_merge, h_n, hs, c_n, cs, fs, Is, gs, os

def BiLSTM_back(dh_n_merge, dhs_merge, u, hs, h_0, cs, c_0, fs, Is, gs, os, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Input setup (the reverse direction reads the sequence with time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    h = Whf["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h*2) for concat, (n, t, h) for sum
    dh_n, dhs = {}, {}
    if merge == "concat":
        # Concatenate
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # Sum
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    elif merge == "dictionary":
        # Dictionary (keep the two directions separate)
        dh_n["f"] = dh_n_merge["f"]
        dh_n["r"] = dh_n_merge["r"]
        dhs["f"] = dhs_merge["f"]
        dhs["r"] = dhs_merge["r"][:, ::-1]
    # Initialize
    du = {}
    dWxf, dWhf, dbf = {}, {}, {}
    dWxi, dWhi, dbi = {}, {}, {}
    dWxg, dWhg, dbg = {}, {}, {}
    dWxo, dWho, dbo = {}, {}, {}
    dh_0, dc_0 = {}, {}
    # Forward direction
    du["f"], dWxf["f"], dWhf["f"], dbf["f"], dWxi["f"], dWhi["f"], dbi["f"], dWxg["f"], dWhg["f"], dbg["f"], dWxo["f"], dWho["f"], dbo["f"], dh_0["f"], dc_0["f"] = \
        LSTM_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], cs["f"], c_0["f"], fs["f"], Is["f"], gs["f"], os["f"], Wxf["f"], Whf["f"], bf["f"], Wxi["f"], Whi["f"], bi["f"], Wxg["f"], Whg["f"], bg["f"], Wxo["f"], Who["f"], bo["f"], propagation_func)
    # Reverse direction
    du["r"], dWxf["r"], dWhf["r"], dbf["r"], dWxi["r"], dWhi["r"], dbi["r"], dWxg["r"], dWhg["r"], dbg["r"], dWxo["r"], dWho["r"], dbo["r"], dh_0["r"], dc_0["r"] = \
        LSTM_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], cs["r"], c_0["r"], fs["r"], Is["r"], gs["r"], os["r"], Wxf["r"], Whf["r"], bf["r"], Wxi["r"], Whi["r"], bi["r"], Wxg["r"], Whg["r"], bg["r"], Wxo["r"], Who["r"], bo["r"], propagation_func)
    # du - (n, t, d)
    if not input_dictionary:
        du_merge = du["f"] + du["r"][:, ::-1]
    else:
        du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
    return du_merge, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0

def BiGRU(u, h_0, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Input setup (the reverse direction reads the sequence with time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    # Initialize
    h_n, hs = {}, {}
    zs, rs, ns = {}, {}, {}
    # Forward direction
    h_n["f"], hs["f"], zs["f"], rs["f"], ns["f"] = GRU(u_dr["f"], h_0["f"], Wxz["f"], Whz["f"], bz["f"], Wxr["f"], Whr["f"], br["f"], Wxn["f"], Whn["f"], bn["f"], propagation_func)
    # Reverse direction
    h_n["r"], hs["r"], zs["r"], rs["r"], ns["r"] = GRU(u_dr["r"], h_0["r"], Wxz["r"], Whz["r"], bz["r"], Wxr["r"], Whr["r"], br["r"], Wxn["r"], Whn["r"], bn["r"], propagation_func)
    # h_n - (n, h)
    # hs - (n, t, h)
    if merge == "concat":
        # Concatenate
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # Sum
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    elif merge == "dictionary":
        # Dictionary (keep the two directions separate)
        h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
        hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
    return h_n_merge, hs_merge, h_n, hs, zs, rs, ns

def BiGRU_back(dh_n_merge, dhs_merge, u, hs, h_0, zs, rs, ns, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Input setup (the reverse direction reads the sequence with time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    h = Whz["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h*2) for concat, (n, t, h) for sum
    dh_n, dhs = {}, {}
    if merge == "concat":
        # Concatenate
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # Sum
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    elif merge == "dictionary":
        # Dictionary (keep the two directions separate)
        dh_n["f"] = dh_n_merge["f"]
        dh_n["r"] = dh_n_merge["r"]
        dhs["f"] = dhs_merge["f"]
        dhs["r"] = dhs_merge["r"][:, ::-1]
    # Initialize
    du = {}
    dWxz, dWhz, dbz = {}, {}, {}
    dWxr, dWhr, dbr = {}, {}, {}
    dWxn, dWhn, dbn = {}, {}, {}
    dh_0 = {}
    # Forward direction
    du["f"], dWxz["f"], dWhz["f"], dbz["f"], dWxr["f"], dWhr["f"], dbr["f"], dWxn["f"], dWhn["f"], dbn["f"], dh_0["f"] = \
        GRU_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], zs["f"], rs["f"], ns["f"], Wxz["f"], Whz["f"], bz["f"], Wxr["f"], Whr["f"], br["f"], Wxn["f"], Whn["f"], bn["f"], propagation_func)
    # Reverse direction
    du["r"], dWxz["r"], dWhz["r"], dbz["r"], dWxr["r"], dWhr["r"], dbr["r"], dWxn["r"], dWhn["r"], dbn["r"], dh_0["r"] = \
        GRU_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], zs["r"], rs["r"], ns["r"], Wxz["r"], Whz["r"], bz["r"], Wxr["r"], Whr["r"], br["r"], Wxn["r"], Whn["r"], bn["r"], propagation_func)
    # du - (n, t, d)
    if not input_dictionary:
        du_merge = du["f"] + du["r"][:, ::-1]
    else:
        du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
    return du_merge, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0

def BiRNN_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
    t, d = d_prev
    h_next = h
    # With merge="concat", the output size is doubled
    if merge == "concat":
        h_next = h * 2
    d_next = (t, h_next)
    if not params.get("return_hs"):
        d_next = h_next
    Wx, Wh, b = {}, {}, {}
    for dr in ["f", "r"]:
        Wx[dr] = Wx_init_func(d, h, **Wx_init_params)
        Wh[dr] = Wh_init_func(h, h, **Wh_init_params)
        b[dr]  = bias_init_func(h, **bias_init_params)
    return d_next, {"Wx":Wx, "Wh":Wh, "b":b}

def BiRNN_init_optimizer():
    sWx, sWh, sb = {}, {}, {}
    for dr in ["f", "r"]:
        sWx[dr] = {}
        sWh[dr] = {}
        sb[dr] = {}
    return {"sWx":sWx, "sWh":sWh, "sb":sb}

def BiRNN_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                h_0[dr] = np.zeros((u.shape[0], weights["Wh"][dr].shape[0]))
            else:
                h_0[dr] = np.zeros((u[dr].shape[0], weights["Wh"][dr].shape[0]))
    # RNN
    h_n_merge, hs_merge, h_n, hs = func(u, h_0, weights["Wx"], weights["Wh"], weights["b"], propagation_func, merge, input_dictionary)
    # Not the last RNN layer (return_hs=True): pass all hidden states to the next layer
    if params.get("return_hs"):
        z = hs_merge
    # Last RNN layer: output only the final hidden state h_n
    else:
        z = h_n_merge
    # Weight decay
    weight_decay_r = 0
    if weight_decay is not None:
        for dr in ["f", "r"]:
            weight_decay_r += weight_decay["func"](weights["Wx"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wh"][dr], **weight_decay["params"])
    return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs}, weight_decay_r

def BiRNN_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):
    # Not the last RNN layer (return_hs=True): the gradient arrives for every hidden state
    if return_hs:
        if merge != "dictionary":
            dh_n_merge = np.zeros_like(us["h_n_merge"])
        else:
            dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
        dhs_merge = dz
    # Last RNN layer: the gradient arrives only for the final hidden state h_n
    else:
        dh_n_merge = dz
        if merge != "dictionary":
            dhs_merge = np.zeros_like(us["hs_merge"])
        else:
            dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # Compute the RNN gradients
    du, dWx, dWh, db, dh_0 = back_func(dh_n_merge, dhs_merge, us["u"], us["hs"], us["h_0"], weights["Wx"], weights["Wh"], weights["b"], propagation_func, merge, input_dictionary)
    # Weight decay
    if weight_decay is not None:
        for dr in ["f", "r"]:
            dWx[dr] += weight_decay["back_func"](weights["Wx"][dr], **weight_decay["params"])
            dWh[dr] += weight_decay["back_func"](weights["Wh"][dr], **weight_decay["params"])
    return {"du":du, "dWx":dWx, "dWh":dWh, "db":db}

def BiRNN_update_weight(func, du, weights, optimizer_stats, **params):
    for dr in ["f", "r"]:
        weights["Wx"][dr], optimizer_stats["sWx"][dr] = func(weights["Wx"][dr], du["dWx"][dr], **params, **optimizer_stats["sWx"][dr])
        weights["Wh"][dr], optimizer_stats["sWh"][dr] = func(weights["Wh"][dr], du["dWh"][dr], **params, **optimizer_stats["sWh"][dr])
        weights["b"][dr],  optimizer_stats["sb"][dr]  = func(weights["b"][dr],  du["db"][dr],  **params, **optimizer_stats["sb"][dr])
    return weights, optimizer_stats
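
# Sketch (added for illustration): BiRNN_update_weight keeps one optimizer-state
# dict per direction, so "f" and "r" never share state such as momentum.
# _sketch_momentum is a hypothetical SGD-with-momentum written to the calling
# convention used above, func(w, dw, **params, **state) -> (new_w, new_state);
# the article's own optimizers are defined elsewhere and may differ.
import numpy as np

def _sketch_momentum(w, dw, eta=0.1, mu=0.9, v=None):
    if v is None:  # first call: the state dict is still empty
        v = np.zeros_like(w)
    v = mu * v - eta * dw
    return w + v, {"v": v}

# Usage: two steps on "f" and one on "r" leave independent velocities, e.g.
#   w = {"f": np.ones(2), "r": np.ones(2)}; s = {"f": {}, "r": {}}
#   w["f"], s["f"] = _sketch_momentum(w["f"], np.ones(2), **s["f"])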

def BiLSTM_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
    t, d = d_prev
    h_next = h
    # With merge="concat", the output size is doubled
    if merge == "concat":
        h_next = h * 2
    d_next = (t, h_next)
    if not params.get("return_hs"):
        d_next = h_next
    Wxf, Whf, bf = {}, {}, {}
    Wxi, Whi, bi = {}, {}, {}
    Wxg, Whg, bg = {}, {}, {}
    Wxo, Who, bo = {}, {}, {}
    for dr in ["f", "r"]:
        Wxf[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whf[dr] = Wh_init_func(h, h, **Wh_init_params)
        bf[dr]  = bias_init_func(h, **bias_init_params)
        Wxi[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whi[dr] = Wh_init_func(h, h, **Wh_init_params)
        bi[dr]  = bias_init_func(h, **bias_init_params)
        Wxg[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whg[dr] = Wh_init_func(h, h, **Wh_init_params)
        bg[dr]  = bias_init_func(h, **bias_init_params)
        Wxo[dr] = Wx_init_func(d, h, **Wx_init_params)
        Who[dr] = Wh_init_func(h, h, **Wh_init_params)
        bo[dr]  = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxf":Wxf, "Whf":Whf, "bf":bf, "Wxi":Wxi, "Whi":Whi, "bi":bi, "Wxg":Wxg, "Whg":Whg, "bg":bg, "Wxo":Wxo, "Who":Who, "bo":bo}

def BiLSTM_init_optimizer():
    sWxf, sWhf, sbf = {}, {}, {}
    sWxi, sWhi, sbi = {}, {}, {}
    sWxg, sWhg, sbg = {}, {}, {}
    sWxo, sWho, sbo = {}, {}, {}
    for dr in ["f", "r"]:
        sWxf[dr] = {}
        sWhf[dr] = {}
        sbf[dr]  = {}
        sWxi[dr] = {}
        sWhi[dr] = {}
        sbi[dr]  = {}
        sWxg[dr] = {}
        sWhg[dr] = {}
        sbg[dr]  = {}
        sWxo[dr] = {}
        sWho[dr] = {}
        sbo[dr]  = {}
    return {"sWxf":sWxf, "sWhf":sWhf, "sbf":sbf,
             "sWxi":sWxi, "sWhi":sWhi, "sbi":sbi,
             "sWxg":sWxg, "sWhg":sWhg, "sbg":sbg,
             "sWxo":sWxo, "sWho":sWho, "sbo":sbo}

def BiLSTM_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                h_0[dr] = np.zeros((u.shape[0], weights["Whf"][dr].shape[0]))
            else:
                h_0[dr] = np.zeros((u[dr].shape[0], weights["Whf"][dr].shape[0]))
    c_0 = params.get("c_0")
    if c_0 is None:
        c_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                c_0[dr] = np.zeros((u.shape[0], weights["Whf"][dr].shape[0]))
            else:
                c_0[dr] = np.zeros((u[dr].shape[0], weights["Whf"][dr].shape[0]))
    # LSTM
    h_n_merge, hs_merge, h_n, hs, c_n, cs, fs, Is, gs, os = func(u, h_0, c_0,
                                                                 weights["Wxf"], weights["Whf"], weights["bf"],
                                                                 weights["Wxi"], weights["Whi"], weights["bi"],
                                                                 weights["Wxg"], weights["Whg"], weights["bg"],
                                                                 weights["Wxo"], weights["Who"], weights["bo"],
                                                                 propagation_func, merge, input_dictionary)
    # Not the last LSTM layer (return_hs=True): pass all hidden states to the next layer
    if params.get("return_hs"):
        z = hs_merge
    # Last LSTM layer: output only the final hidden state h_n
    else:
        z = h_n_merge
    # Weight decay
    weight_decay_r = 0
    if weight_decay is not None:
        for dr in ["f", "r"]:
            weight_decay_r += weight_decay["func"](weights["Wxf"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whf"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxi"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whi"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxg"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whg"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxo"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Who"][dr], **weight_decay["params"])

    return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs, "c_0":c_0, "c_n":c_n, "cs":cs,
             "fs":fs, "Is":Is, "gs":gs, "os":os}, weight_decay_r
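
# Sketch (added for illustration): the propagation functions add
# weight_decay["func"](W, **weight_decay["params"]) to the loss for every input and
# recurrent weight of both directions (biases are skipped), and the back-propagation
# functions add weight_decay["back_func"](W, **weight_decay["params"]) to each
# gradient. Below is a hypothetical L2 pair following that convention (the
# parameter name lam is an assumption), plus a finite-difference check that
# back_func really is the derivative of func.
import numpy as np

def _sketch_l2(W, lam=0.01):
    return 0.5 * lam * np.sum(W ** 2)

def _sketch_l2_back(W, lam=0.01):
    return lam * W

def _sketch_l2_check(shape=(3, 4), eps=1e-6, seed=0):
    rng = np.random.default_rng(seed)
    W = rng.normal(size=shape)
    numeric = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        numeric[idx] = (_sketch_l2(Wp) - _sketch_l2(Wm)) / (2 * eps)
    return np.max(np.abs(numeric - _sketch_l2_back(W)))

# Hypothetical weight_decay dict in the shape these functions expect
_sketch_weight_decay = {"func": _sketch_l2, "back_func": _sketch_l2_back, "params": {"lam": 0.01}}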

def BiLSTM_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):    
    # Not the last LSTM layer (return_hs=True): the gradient arrives for every hidden state
    if return_hs:
        if merge != "dictionary":
            dh_n_merge = np.zeros_like(us["h_n_merge"])
        else:
            dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
        dhs_merge = dz
    # Last LSTM layer: the gradient arrives only for the final hidden state h_n
    else:
        dh_n_merge = dz
        if merge != "dictionary":
            dhs_merge = np.zeros_like(us["hs_merge"])
        else:
            dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # Compute the LSTM gradients
    du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0 = back_func(dh_n_merge, dhs_merge,
                                        us["u"], us["hs"], us["h_0"], us["cs"], us["c_0"],
                                        us["fs"], us["Is"], us["gs"], us["os"],
                                        weights["Wxf"], weights["Whf"], weights["bf"],
                                        weights["Wxi"], weights["Whi"], weights["bi"],
                                        weights["Wxg"], weights["Whg"], weights["bg"],
                                        weights["Wxo"], weights["Who"], weights["bo"],
                                        propagation_func, merge, input_dictionary)

    # Weight decay
    if weight_decay is not None:
        for dr in ["f", "r"]:
            dWxf[dr] += weight_decay["back_func"](weights["Wxf"][dr], **weight_decay["params"])
            dWhf[dr] += weight_decay["back_func"](weights["Whf"][dr], **weight_decay["params"])
            dWxi[dr] += weight_decay["back_func"](weights["Wxi"][dr], **weight_decay["params"])
            dWhi[dr] += weight_decay["back_func"](weights["Whi"][dr], **weight_decay["params"])
            dWxg[dr] += weight_decay["back_func"](weights["Wxg"][dr], **weight_decay["params"])
            dWhg[dr] += weight_decay["back_func"](weights["Whg"][dr], **weight_decay["params"])
            dWxo[dr] += weight_decay["back_func"](weights["Wxo"][dr], **weight_decay["params"])
            dWho[dr] += weight_decay["back_func"](weights["Who"][dr], **weight_decay["params"])

    return {"du":du, "dWxf":dWxf, "dWhf":dWhf, "dbf":dbf, "dWxi":dWxi, "dWhi":dWhi, "dbi":dbi, "dWxg":dWxg, "dWhg":dWhg, "dbg":dbg, "dWxo":dWxo, "dWho":dWho, "dbo":dbo}

def BiLSTM_update_weight(func, du, weights, optimizer_stats, **params):
    for dr in ["f", "r"]:
        weights["Wxf"][dr], optimizer_stats["sWxf"][dr] = func(weights["Wxf"][dr], du["dWxf"][dr], **params, **optimizer_stats["sWxf"][dr])
        weights["Whf"][dr], optimizer_stats["sWhf"][dr] = func(weights["Whf"][dr], du["dWhf"][dr], **params, **optimizer_stats["sWhf"][dr])
        weights["bf"][dr],  optimizer_stats["sbf"][dr]  = func(weights["bf"][dr],  du["dbf"][dr],  **params, **optimizer_stats["sbf"][dr])
        weights["Wxi"][dr], optimizer_stats["sWxi"][dr] = func(weights["Wxi"][dr], du["dWxi"][dr], **params, **optimizer_stats["sWxi"][dr])
        weights["Whi"][dr], optimizer_stats["sWhi"][dr] = func(weights["Whi"][dr], du["dWhi"][dr], **params, **optimizer_stats["sWhi"][dr])
        weights["bi"][dr],  optimizer_stats["sbi"][dr]  = func(weights["bi"][dr],  du["dbi"][dr],  **params, **optimizer_stats["sbi"][dr])
        weights["Wxg"][dr], optimizer_stats["sWxg"][dr] = func(weights["Wxg"][dr], du["dWxg"][dr], **params, **optimizer_stats["sWxg"][dr])
        weights["Whg"][dr], optimizer_stats["sWhg"][dr] = func(weights["Whg"][dr], du["dWhg"][dr], **params, **optimizer_stats["sWhg"][dr])
        weights["bg"][dr],  optimizer_stats["sbg"][dr]  = func(weights["bg"][dr],  du["dbg"][dr],  **params, **optimizer_stats["sbg"][dr])
        weights["Wxo"][dr], optimizer_stats["sWxo"][dr] = func(weights["Wxo"][dr], du["dWxo"][dr], **params, **optimizer_stats["sWxo"][dr])
        weights["Who"][dr], optimizer_stats["sWho"][dr] = func(weights["Who"][dr], du["dWho"][dr], **params, **optimizer_stats["sWho"][dr])
        weights["bo"][dr],  optimizer_stats["sbo"][dr]  = func(weights["bo"][dr],  du["dbo"][dr],  **params, **optimizer_stats["sbo"][dr])
    return weights, optimizer_stats

def BiGRU_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
    t, d = d_prev
    h_next = h
    # With merge="concat", the output size is doubled
    if merge == "concat":
        h_next = h * 2
    d_next = (t, h_next)
    if not params.get("return_hs"):
        d_next = h_next
    Wxz, Whz, bz = {}, {}, {}
    Wxr, Whr, br = {}, {}, {}
    Wxn, Whn, bn = {}, {}, {}
    for dr in ["f", "r"]:
        Wxz[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whz[dr] = Wh_init_func(h, h, **Wh_init_params)
        bz[dr]  = bias_init_func(h, **bias_init_params)
        Wxr[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whr[dr] = Wh_init_func(h, h, **Wh_init_params)
        br[dr]  = bias_init_func(h, **bias_init_params)
        Wxn[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whn[dr] = Wh_init_func(h, h, **Wh_init_params)
        bn[dr]  = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxz":Wxz, "Whz":Whz, "bz":bz, "Wxr":Wxr, "Whr":Whr, "br":br, "Wxn":Wxn, "Whn":Whn, "bn":bn}

def BiGRU_init_optimizer():
    sWxz, sWhz, sbz = {}, {}, {}
    sWxr, sWhr, sbr = {}, {}, {}
    sWxn, sWhn, sbn = {}, {}, {}
    for dr in ["f", "r"]:
        sWxz[dr] = {}
        sWhz[dr] = {}
        sbz[dr]  = {}
        sWxr[dr] = {}
        sWhr[dr] = {}
        sbr[dr]  = {}
        sWxn[dr] = {}
        sWhn[dr] = {}
        sbn[dr]  = {}
    return {"sWxz":sWxz, "sWhz":sWhz, "sbz":sbz,
             "sWxr":sWxr, "sWhr":sWhr, "sbr":sbr,
             "sWxn":sWxn, "sWhn":sWhn, "sbn":sbn}

def BiGRU_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                h_0[dr] = np.zeros((u.shape[0], weights["Whz"][dr].shape[0]))
            else:
                h_0[dr] = np.zeros((u[dr].shape[0], weights["Whz"][dr].shape[0]))
    # GRU
    h_n_merge, hs_merge, h_n, hs, zs, rs, ns = func(u, h_0,
                                                    weights["Wxz"], weights["Whz"], weights["bz"],
                                                    weights["Wxr"], weights["Whr"], weights["br"],
                                                    weights["Wxn"], weights["Whn"], weights["bn"],
                                                    propagation_func, merge, input_dictionary)
    # Not the last GRU layer (return_hs=True): pass all hidden states to the next layer
    if params.get("return_hs"):
        z = hs_merge
    # Last GRU layer: output only the final hidden state h_n
    else:
        z = h_n_merge
    # Weight decay
    weight_decay_r = 0
    if weight_decay is not None:
        for dr in ["f", "r"]:
            weight_decay_r += weight_decay["func"](weights["Wxz"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whz"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxr"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whr"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxn"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whn"][dr], **weight_decay["params"])

    return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs, "zs":zs, "rs":rs, "ns":ns}, weight_decay_r

def BiGRU_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):    
    # Not the last GRU layer (return_hs=True): the gradient arrives for every hidden state
    if return_hs:
        if merge != "dictionary":
            dh_n_merge = np.zeros_like(us["h_n_merge"])
        else:
            dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
        dhs_merge = dz
    # Last GRU layer: the gradient arrives only for the final hidden state h_n
    else:
        dh_n_merge = dz
        if merge != "dictionary":
            dhs_merge = np.zeros_like(us["hs_merge"])
        else:
            dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # Compute the GRU gradients
    du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0 = back_func(dh_n_merge, dhs_merge,
                                        us["u"], us["hs"], us["h_0"],
                                        us["zs"], us["rs"], us["ns"],
                                        weights["Wxz"], weights["Whz"], weights["bz"],
                                        weights["Wxr"], weights["Whr"], weights["br"],
                                        weights["Wxn"], weights["Whn"], weights["bn"],
                                        propagation_func, merge, input_dictionary)

    # Weight decay: add the penalty gradient to each weight matrix in both directions
    if weight_decay is not None:
        for dr in ["f", "r"]:
            dWxz[dr] += weight_decay["back_func"](weights["Wxz"][dr], **weight_decay["params"])
            dWhz[dr] += weight_decay["back_func"](weights["Whz"][dr], **weight_decay["params"])
            dWxr[dr] += weight_decay["back_func"](weights["Wxr"][dr], **weight_decay["params"])
            dWhr[dr] += weight_decay["back_func"](weights["Whr"][dr], **weight_decay["params"])
            dWxn[dr] += weight_decay["back_func"](weights["Wxn"][dr], **weight_decay["params"])
            dWhn[dr] += weight_decay["back_func"](weights["Whn"][dr], **weight_decay["params"])

    return {"du":du, "dWxz":dWxz, "dWhz":dWhz, "dbz":dbz, "dWxr":dWxr, "dWhr":dWhr, "dbr":dbr, "dWxn":dWxn, "dWhn":dWhn, "dbn":dbn}
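The branch at the top of `BiGRU_back_propagation` decides where the upstream gradient `dz` enters. When the layer returned all hidden states, `dz` goes into `hs` and `h_n` receives zeros. Otherwise the reverse holds. The sketch below reproduces that routing for array (non-"dictionary") merges and then splits a concat-merged gradient back into its forward and reverse halves. The helper names `route_upstream_grad` and `split_concat` are hypothetical. So is the assumption that concat places the forward direction first along the last axis. The real split happens inside `back_func`, which is not shown here.

```python
import numpy as np

def route_upstream_grad(dz, h_n_merge, hs_merge, return_hs):
    """Mirror of the branch in BiGRU_back_propagation (array merges only)."""
    if return_hs:
        # Next layer consumed every time step: gradient flows into hs, none into h_n
        return np.zeros_like(h_n_merge), dz
    # Only the final hidden state was used: gradient flows into h_n, none into hs
    return dz, np.zeros_like(hs_merge)

def split_concat(d_merge, hidden):
    """Split the gradient of concat([h_f, h_r], axis=-1) into the two directions.

    Assumes the forward half comes first along the last axis (hypothetical layout).
    """
    return {"f": d_merge[..., :hidden], "r": d_merge[..., hidden:]}

batch, T, H = 2, 28, 4
rng = np.random.default_rng(0)
h_n_merge = rng.normal(size=(batch, 2 * H))
hs_merge = rng.normal(size=(batch, T, 2 * H))

# Bottom-most GRU layer: the upstream gradient has the shape of h_n_merge
dz = rng.normal(size=(batch, 2 * H))
dh_n_merge, dhs_merge = route_upstream_grad(dz, h_n_merge, hs_merge, return_hs=False)
dh_n = split_concat(dh_n_merge, H)
print(dh_n["f"].shape, dh_n["r"].shape, np.abs(dhs_merge).sum())
```

Passing explicit zeros rather than `None` for the unused output lets `back_func` run the same backward-through-time loop in both cases.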

def BiGRU_update_weight(func, du, weights, optimizer_stats, **params):
    for dr in ["f", "r"]:
        weights["Wxz"][dr], optimizer_stats["sWxz"][dr] = func(weights["Wxz"][dr], du["dWxz"][dr], **params, **optimizer_stats["sWxz"][dr])
        weights["Whz"][dr], optimizer_stats["sWhz"][dr] = func(weights["Whz"][dr], du["dWhz"][dr], **params, **optimizer_stats["sWhz"][dr])
        weights["bz"][dr],  optimizer_stats["sbz"][dr]  = func(weights["bz"][dr],  du["dbz"][dr],  **params, **optimizer_stats["sbz"][dr])
        weights["Wxr"][dr], optimizer_stats["sWxr"][dr] = func(weights["Wxr"][dr], du["dWxr"][dr], **params, **optimizer_stats["sWxr"][dr])
        weights["Whr"][dr], optimizer_stats["sWhr"][dr] = func(weights["Whr"][dr], du["dWhr"][dr], **params, **optimizer_stats["sWhr"][dr])
        weights["br"][dr],  optimizer_stats["sbr"][dr]  = func(weights["br"][dr],  du["dbr"][dr],  **params, **optimizer_stats["sbr"][dr])
        weights["Wxn"][dr], optimizer_stats["sWxn"][dr] = func(weights["Wxn"][dr], du["dWxn"][dr], **params, **optimizer_stats["sWxn"][dr])
        weights["Whn"][dr], optimizer_stats["sWhn"][dr] = func(weights["Whn"][dr], du["dWhn"][dr], **params, **optimizer_stats["sWhn"][dr])
        weights["bn"][dr],  optimizer_stats["sbn"][dr]  = func(weights["bn"][dr],  du["dbn"][dr],  **params, **optimizer_stats["sbn"][dr])
    return weights, optimizer_stats
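`BiGRU_update_weight` only assumes one contract for `func`. It receives the weight, its gradient, the shared hyperparameters and the per-parameter state as keyword arguments, and it returns the updated weight together with the new state dict. The following is a sketch under the assumption of a momentum optimizer. The function `momentum`, its parameters `eta` and `alpha`, and the state key `v` are hypothetical, not the article's optimizer. It drives the same nine updates per direction from a name list and runs two steps on the toy loss 0.5·‖W‖², whose gradient is W itself.

```python
import numpy as np

def momentum(W, dW, eta=0.1, alpha=0.9, v=None):
    # Hypothetical optimizer: returns (new weight, new state dict)
    if v is None:
        v = np.zeros_like(W)  # first call: the state dict is still empty
    v = alpha * v - eta * dW
    return W + v, {"v": v}

names = ["Wxz", "Whz", "bz", "Wxr", "Whr", "br", "Wxn", "Whn", "bn"]
directions = ["f", "r"]
rng = np.random.default_rng(0)
# Matrices are 3x3; biases (names starting with "b") are length-3 vectors
weights = {n: {dr: rng.normal(size=(3,) if n.startswith("b") else (3, 3))
               for dr in directions} for n in names}
W0 = {n: {dr: weights[n][dr].copy() for dr in directions} for n in names}
# One empty state dict per parameter and direction, as **optimizer_stats[...] expects
optimizer_stats = {"s" + n: {dr: {} for dr in directions} for n in names}

def update_weight(func, du, weights, optimizer_stats, **params):
    # Loop form of BiGRU_update_weight: same calls, driven by the name list
    for dr in directions:
        for n in names:
            weights[n][dr], optimizer_stats["s" + n][dr] = func(
                weights[n][dr], du["d" + n][dr], **params, **optimizer_stats["s" + n][dr])
    return weights, optimizer_stats

# Two steps on 0.5 * ||W||^2, whose gradient is W itself
for step in range(2):
    du = {"d" + n: {dr: weights[n][dr] for dr in directions} for n in names}
    weights, optimizer_stats = update_weight(momentum, du, weights, optimizer_stats,
                                             eta=0.1, alpha=0.9)
```

After two steps every parameter equals 0.72 times its initial value (0.9·W₀ after the first step, then 0.9·W₀ − 0.18·W₀). The loop is only a stylistic alternative to the nine explicit lines. It behaves the same as long as the `"s" + name` and `"d" + name` key convention holds.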