The plan was to examine hyperparameters for RNN, LSTM, and GRU, but before that, this article covers bidirectional RNNs and orthogonal weight initialization.
This is a continuation of ディープラーニングを実装から学ぶ(10-1)RNNの実装(RNN,LSTM,GRU).
As usual, we verify everything on MNIST.
orthogonal (weight initialization)
Since the activation function is $ \mathrm{tanh} $, we have been using glorot_normal to initialize the weights. In an RNN, however, the hidden-state weight enters a matrix product once per time step.
Multiplying by the same matrix over and over is a cause of exploding and vanishing gradients.
Using an orthogonal matrix as the initial value mitigates exploding and vanishing gradients.
Orthogonal matrices
An orthogonal matrix is a matrix whose transpose equals its inverse.
That is, $ UU^T=I $ holds (where $ I $ is the identity matrix).
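To see why this helps in an RNN: an orthogonal matrix preserves the norm of any vector, so repeatedly multiplying by it can neither blow up nor shrink the signal.

$$
\|Ux\|^2 = (Ux)^T(Ux) = x^TU^TUx = x^Tx = \|x\|^2
$$

The same holds after any number of steps, so $ \|U^tx\| = \|x\| $.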
An orthogonal matrix can be obtained via singular value decomposition (SVD).
SVD factors a matrix A into orthogonal matrices U and V and a diagonal matrix D:

$$ A = UDV^T $$

numpy's linalg.svd performs the singular value decomposition.
Strictly speaking, linalg.svd returns U, the singular values, and $ V^T $; we take the first return value (U). The input matrix is generated from random numbers.
Here is a function that builds an orthogonal matrix for weight initialization.
def orthogonal(d_1, d):
    # Draw a random matrix and take U from its SVD; U is orthogonal.
    # Note: np.linalg.svd returns U with shape (d_1, d_1), so this is
    # intended for the square hidden-to-hidden weight Wh (d_1 == d).
    rand = np.random.normal(0., 1., (d_1, d))
    return np.linalg.svd(rand)[0]
Checking the matrix
Let's generate a 3$\times$3 matrix.
orth = orthogonal(3,3)
orth
The following matrix was generated.
array([[-0.60393137, -0.69761452, 0.38548786],
[ 0.78180171, -0.42438299, 0.45682071],
[-0.15509027, 0.57726342, 0.80169442]])
To verify that it really is orthogonal, we take its product with its transpose.
np.dot(orth, orth.T)
Here is the result. It is indeed the identity matrix (up to floating-point error).
array([[ 1.00000000e+00, -3.53711475e-17, -2.99764250e-16],
[-3.53711475e-17, 1.00000000e+00, -8.39352672e-17],
[-2.99764250e-16, -8.39352672e-17, 1.00000000e+00]])
Checking for exploding and vanishing gradients
We generate a 3$ \times $3 orthogonal matrix and take its product 28 times, corresponding to MNIST's 28 time steps.
We repeat this five times.
for i in range(5):
orth = orthogonal(3,3)
x = orth
print("init ", i+1)
print(x)
for t in range(28):
x = np.dot(orth, x)
print("end ", i+1)
print(x)
Even after 28 products, the values remain well-behaved.
init 1
[[-0.80372635 0.51087605 -0.30500757]
[-0.36638453 -0.02103647 0.93022569]
[ 0.46881375 0.85939695 0.20408465]]
end 1
[[-0.76245687 -0.40800029 0.50219048]
[ 0.54918583 0.00232497 0.83569703]
[-0.34213221 0.91297884 0.22229529]]
init 2
[[-0.89425394 -0.41274405 0.17306717]
[ 0.44739469 -0.83489191 0.32061424]
[ 0.01216076 0.36413988 0.93126487]]
end 2
[[-0.82989029 -0.5243296 0.19068451]
[ 0.55780287 -0.77254528 0.30336405]
[-0.01175033 0.35812325 0.93360038]]
init 3
[[-0.80110215 0.35867688 -0.47915157]
[ 0.21217589 0.91874468 0.33300092]
[ 0.55965769 0.16510335 -0.81211092]]
end 3
[[ 0.86466244 -0.048639 0.49999311]
[ 0.09153374 0.99389435 -0.06160838]
[-0.49394375 0.09903669 0.86383523]]
init 4
[[-0.55059322 0.73411648 0.39739161]
[-0.81460012 -0.36847745 -0.44794086]
[-0.18241092 -0.57034846 0.80089256]]
end 4
[[ 0.99731877 -0.06860368 -0.02547166]
[ 0.06846451 0.99763368 -0.0062972 ]
[ 0.0258434 0.00453641 0.99965571]]
init 5
[[ 0.0411211 0.88763312 -0.45871178]
[-0.7350467 -0.28409247 -0.61562799]
[-0.67676835 0.36248988 0.6407696 ]]
end 5
[[-0.1254555 -0.6334718 -0.7635276 ]
[ 0.81256556 -0.50716522 0.28726407]
[-0.56920833 -0.58437737 0.57836404]]
Next, let's check glorot_normal.
We run it five times.
for i in range(5):
glorot = glorot_normal(3,3)
x = glorot
print("init ", i+1)
print(x)
for t in range(28):
x = np.dot(glorot, x)
print("end ", i+1)
print(x)
Run 1 produced very large values.
Runs 2 and 3 produced very small values.
Runs 4 and 5 stayed in a reasonable range.
So depending on the random initial values, exploding or vanishing gradients can occur.
init 1
[[ 0.36042833 0.04654217 -0.27553851]
[-0.1003273 0.99703076 -1.2884802 ]
[-0.48491302 -0.78957295 1.07022627]]
end 1
[[ 4.27113723e+07 1.31897155e+08 -1.79150247e+08]
[ 2.60558339e+08 8.04631220e+08 -1.09289607e+09]
[-2.23244480e+08 -6.89402146e+08 9.36385362e+08]]
init 2
[[-0.54241806 0.13369207 -0.32437354]
[ 0.50866409 -0.82871844 -0.16365881]
[ 0.83189728 -0.44774997 -0.67062721]]
end 2
[[-9.18237793e-04 8.41680925e-04 -5.67926628e-04]
[-4.33933433e-04 7.32992609e-04 -8.52003416e-04]
[-4.22035207e-05 6.39512261e-04 -1.06920879e-03]]
init 3
[[-0.23330056 0.48466714 0.91377187]
[-1.65115943 0.01144839 -0.56982934]
[ 0.71582823 0.12634886 0.10495938]]
end 3
[[-3.59039953e-05 -2.56355152e-05 -5.89428986e-05]
[ 8.69562293e-05 -6.92291595e-05 -1.05409878e-05]
[-4.59737230e-05 2.65928349e-06 -3.61461589e-05]]
init 4
[[ 0.74781561 -0.14005785 -0.51149799]
[-0.34702255 -0.63936539 0.45400338]
[-0.14042456 -0.64410887 -0.76022184]]
end 4
[[ 0.00129885 0.00275162 0.00203506]
[ 0.00096465 0.00442726 -0.0051897 ]
[ 0.00216138 0.00916843 0.00627747]]
init 5
[[ 0.03147282 -0.77472248 0.06427335]
[-0.38427955 0.57295038 0.24094615]
[-0.31341085 0.73580383 0.05397258]]
end 5
[[ 1.25692346 -2.88591353 -0.60767576]
[-1.8194483 4.17747828 0.87963559]
[-1.71327251 3.93369721 0.82830355]]
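To make the contrast quantitative, here is a minimal sketch (assuming numpy is imported as np and the orthogonal and glorot_normal functions used above are defined) that prints the largest singular value of the 28-step product. For an orthogonal matrix it is exactly 1.

def power_spectral_norm(init_func, steps=28, trials=5):
    # Largest singular value of W^steps:
    # 1.0 means the product neither explodes nor vanishes.
    for i in range(trials):
        W = init_func(3, 3)
        P = np.linalg.matrix_power(W, steps)
        print(i + 1, np.linalg.svd(P, compute_uv=False)[0])

power_spectral_norm(orthogonal)      # always 1.0
power_spectral_norm(glorot_normal)   # typically explodes or vanishes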
Verification with an RNN
orthogonal is applied to $ Wh $ only.
We take the simple-RNN example program from ディープラーニングを実装から学ぶ(10-1)RNNの実装(RNN,LSTM,GRU), set its initializer function to orthogonal, and try it.
model = create_model((28,28))
model = add_layer(model, "RNN1", RNN, 100, Wh_init_func=orthogonal)
model = add_layer(model, "affine", affine, 10)
model = set_output(model, softmax)
model = set_error(model, cross_entropy_error)
optimizer = create_optimizer(SGD, lr=0.01)
epoch = 30
batch_size = 100
np.random.seed(10)
model, optimizer, learn_info = learn(model, nx_train, t_train, nx_test, t_test, batch_size=batch_size, epoch=epoch, optimizer=optimizer)
Let's compare the results.
$ Wh $ initializer | Train acc. (%) | Test acc. (%) | Best test acc. (%) |
---|---|---|---|
glorot_normal | 97.86 | 97.35 | 97.35 |
orthogonal | 98.27 | 97.34 | 97.54 |
After 30 epochs the test accuracy was nearly the same, but judging from the accuracy and loss along the way, orthogonal looks better, so from now on the default initializer for $ Wh $ will be orthogonal.
Bidirectional RNN
So far we have treated an MNIST image as a time series of pixel rows, read from the top.
Depending on the series, reading in reverse order (from the newest end) can work better.
For digits, which captures the features better, reading from the top or from the bottom? It is hard to say.
A bidirectional RNN therefore reads the sequence from both directions.
Bidirectional simple RNN
Forward propagation
We run the RNN in the forward direction and in the reverse direction, then combine the two results with some operation.
Possible operations include concatenation, summation, averaging, and multiplication.
The RNN itself is the simple RNN we already have. $ u $ has dimensions (n: batch size, t: time steps, d: number of features). The forward pass receives $ u $, and the reverse pass receives $ u[:, ::-1] $, which reverses the second (time) axis. Finally the results of the two directions are merged.
Here we implement two merge operations: concatenation and summation.
- Concatenation
  The results of the two directions, $ fh_n $ and $ rh_n $, are concatenated using numpy's concatenate. Because of the concatenation, the final output dimension doubles.
- Summation
  The results of the two directions, $ fh_n $ and $ rh_n $, are added together elementwise.
Here is the bidirectional simple RNN program.
A merge parameter is added: "concat" concatenates, "sum" sums.
Weights and related variables are dictionaries, where "f" denotes the forward direction and "r" the reverse direction.
For example, Wx["f"] is the forward weight and Wx["r"] the reverse weight; likewise h_n["f"] and hs["f"] are forward, and h_n["r"] and hs["r"] are reverse.
The RNN is invoked for each direction, and finally h_n and hs are concatenated or summed across the two directions.
One caveat: the reverse direction runs on time-reversed input, so its outputs are flipped back along the time axis before concatenating or summing.
def BiRNN(u, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat"):
    # Set up the input (reverse direction: flip the time axis)
u_dr = {}
u_dr["f"] = u
u_dr["r"] = u[:, ::-1]
    # Initialization
h_n, hs = {}, {}
    # Forward direction
h_n["f"], hs["f"] = RNN(u_dr["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
h_n["r"], hs["r"] = RNN(u_dr["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
# h_n - (n, h)
# hs - (n, t, h)
if merge == "concat":
        # Concatenate
h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
elif merge == "sum":
        # Sum
h_n_merge = h_n["f"] + h_n["r"]
hs_merge = hs["f"] + hs["r"][:, ::-1]
return h_n_merge, hs_merge, h_n, hs
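For a quick shape check, here is a minimal sketch of calling BiRNN directly (assuming the RNN and tanh functions from part 10-1 are defined; the sizes below are made up for illustration):

n, t, d, h = 2, 5, 3, 4
u = np.random.randn(n, t, d)
h_0 = {dr: np.zeros((n, h)) for dr in ["f", "r"]}
Wx = {dr: glorot_normal(d, h) for dr in ["f", "r"]}
Wh = {dr: orthogonal(h, h) for dr in ["f", "r"]}
b = {dr: np.zeros(h) for dr in ["f", "r"]}
h_n_merge, hs_merge, h_n, hs = BiRNN(u, h_0, Wx, Wh, b, merge="concat")
print(h_n_merge.shape, hs_merge.shape)  # (2, 8) (2, 5, 8) - concat doubles h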
Backpropagation
Backpropagation also works like the simple RNN's. A note on the concatenation and summation parts:
- Concatenation
  Each direction's output flows downstream unchanged, so the upstream gradient is split and passed back to each direction as-is.
- Summation
  The gradient of addition is 1, so the upstream gradient is passed back to both directions as-is.
As in the forward pass, the reverse direction's time axis must be flipped.
For du: since the input branches into the forward and the reverse direction, the gradients of both directions are added. Here too, don't forget to flip the time axis.
def BiRNN_back(dh_n_merge, dhs_merge, u, hs, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat"):
    # Set up the input (reverse direction: flip the time axis)
u_dr = {}
u_dr["f"] = u
u_dr["r"] = u[:, ::-1]
h = Wh["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
# dhs - (n, t, h)
dh_n, dhs = {}, {}
if merge == "concat":
        # Concatenate
dh_n["f"] = dh_n_merge[:, :h]
dh_n["r"] = dh_n_merge[:, h:]
dhs["f"] = dhs_merge[:, :, :h]
dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
elif merge == "sum":
        # Sum
dh_n["f"] = dh_n_merge
dh_n["r"] = dh_n_merge
dhs["f"] = dhs_merge
dhs["r"] = dhs_merge[:, ::-1]
    # Initialization
du = {}
dWx, dWh, db = {}, {}, {}
dh_0 = {}
    # Forward direction
du["f"], dWx["f"], dWh["f"], db["f"], dh_0["f"] = RNN_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
du["r"], dWx["r"], dWh["r"], db["r"], dh_0["r"] = RNN_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
# du - (n, t, d)
du_merge = du["f"] + du["r"][:, ::-1]
return du_merge, dWx, dWh, db, dh_0
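As a sanity check on the backward pass, here is a numerical-gradient sketch, reusing the variables from the shape-check snippet above with merge="sum" and the scalar loss np.sum(h_n_merge). It assumes the tanh_back function from part 10-1 is defined, since RNN_back looks it up by name.

h_n_merge, hs_merge, h_n, hs = BiRNN(u, h_0, Wx, Wh, b, merge="sum")
dh_n_merge = np.ones_like(h_n_merge)   # d(sum)/d(h_n_merge) = 1
dhs_merge = np.zeros_like(hs_merge)    # hs does not enter this loss
du, dWx, dWh, db, dh_0 = BiRNN_back(dh_n_merge, dhs_merge, u, hs, h_0, Wx, Wh, b, merge="sum")
# Numerical gradient for a single input element
eps = 1e-6
u_p, u_m = u.copy(), u.copy()
u_p[0, 0, 0] += eps
u_m[0, 0, 0] -= eps
num = (np.sum(BiRNN(u_p, h_0, Wx, Wh, b, merge="sum")[0])
       - np.sum(BiRNN(u_m, h_0, Wx, Wh, b, merge="sum")[0])) / (2 * eps)
print(du[0, 0, 0], num)  # the two values should agree closely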
Multi-layer
Consider a two-layer bidirectional RNN.
As in the figure below, one option is to merge the outputs (concatenate or sum) at every layer. (The figure collapses the time axis for brevity.) That approach is possible, but here we instead run the forward and the reverse direction independently through all layers and merge only at the end.
We now change the implementation to support the following structure.
In this figure, the first layer returns the forward and reverse outputs separately and performs no merge on output.
We add "dictionary" as a merge option: with "dictionary", the forward and reverse outputs are stored in a dictionary.
The second layer's input is then a dictionary, so we add an input_dictionary parameter; when True, the layer accepts a dictionary as input.
The backward pass receives the same changes. A usage sketch of this two-layer wiring follows the code below.
def BiRNN(u, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the input (reverse direction: flip the time axis)
u_dr = {}
if not input_dictionary:
u_dr["f"] = u
u_dr["r"] = u[:, ::-1]
else:
u_dr["f"] = u["f"]
u_dr["r"] = u["r"][:, ::-1]
    # Initialization
h_n, hs = {}, {}
    # Forward direction
h_n["f"], hs["f"] = RNN(u_dr["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
h_n["r"], hs["r"] = RNN(u_dr["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
# h_n - (n, h)
# hs - (n, t, h)
if merge == "concat":
        # Concatenate
h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
elif merge == "sum":
        # Sum
h_n_merge = h_n["f"] + h_n["r"]
hs_merge = hs["f"] + hs["r"][:, ::-1]
elif merge == "dictionary":
        # Dictionary
h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
return h_n_merge, hs_merge, h_n, hs
def BiRNN_back(dh_n_merge, dhs_merge, u, hs, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the input (reverse direction: flip the time axis)
u_dr = {}
if not input_dictionary:
u_dr["f"] = u
u_dr["r"] = u[:, ::-1]
else:
u_dr["f"] = u["f"]
u_dr["r"] = u["r"][:, ::-1]
h = Wh["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
# dhs - (n, t, h)
dh_n, dhs = {}, {}
if merge == "concat":
        # Concatenate
dh_n["f"] = dh_n_merge[:, :h]
dh_n["r"] = dh_n_merge[:, h:]
dhs["f"] = dhs_merge[:, :, :h]
dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
elif merge == "sum":
        # Sum
dh_n["f"] = dh_n_merge
dh_n["r"] = dh_n_merge
dhs["f"] = dhs_merge
dhs["r"] = dhs_merge[:, ::-1]
elif merge == "dictionary":
        # Dictionary
dh_n["f"] = dh_n_merge["f"]
dh_n["r"] = dh_n_merge["r"]
dhs["f"] = dhs_merge["f"]
dhs["r"] = dhs_merge["r"][:, ::-1]
    # Initialization
du = {}
dWx, dWh, db = {}, {}, {}
dh_0 = {}
    # Forward direction
du["f"], dWx["f"], dWh["f"], db["f"], dh_0["f"] = RNN_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # Reverse direction
du["r"], dWx["r"], dWh["r"], db["r"], dh_0["r"] = RNN_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
# du - (n, t, d)
if not input_dictionary:
du_merge = du["f"] + du["r"][:, ::-1]
else:
du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
return du_merge, dWx, dWh, db, dh_0
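Putting the pieces together, here is a minimal sketch of the two-layer wiring described above, again reusing the variables from the earlier shape-check snippet (layer 1 returns a dictionary, layer 2 consumes it):

# Layer 1: return both directions separately as a dictionary
_, hs1, _, _ = BiRNN(u, h_0, Wx, Wh, b, merge="dictionary")
# Layer 2: take the dictionary as input and merge at the end
Wx2 = {dr: glorot_normal(h, h) for dr in ["f", "r"]}
Wh2 = {dr: orthogonal(h, h) for dr in ["f", "r"]}
b2 = {dr: np.zeros(h) for dr in ["f", "r"]}
h_n2, _, _, _ = BiRNN(hs1, h_0, Wx2, Wh2, b2, merge="concat", input_dictionary=True)
print(h_n2.shape)  # (n, h*2)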
Framework integration
We add the bidirectional RNN to 「ディープラーニングを実装から学ぶ(8)実装変更」.
Initialization
Basically the same as the simple RNN. The difference is that the forward and reverse weights and biases are each stored in a dictionary.
One more point: when the merge is "concat", the outputs are concatenated, so the output dimension doubles.
The default initializer for $ Wh $ is changed to orthogonal.
def BiRNN_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
t, d = d_prev
h_next = h
    # With concat, the output dimension doubles
if merge == "concat":
h_next = h * 2
d_next = (t, h_next)
if not params.get("return_hs"):
d_next = h_next
Wx, Wh, b = {}, {}, {}
for dr in ["f", "r"]:
Wx[dr] = Wx_init_func(d, h, **Wx_init_params)
Wh[dr] = Wh_init_func(h, h, **Wh_init_params)
b[dr] = bias_init_func(h, **bias_init_params)
return d_next, {"Wx":Wx, "Wh":Wh, "b":b}
def BiRNN_init_optimizer():
sWx, sWh, sb = {}, {}, {}
for dr in ["f", "r"]:
sWx[dr] = {}
sWh[dr] = {}
sb[dr] = {}
return {"sWx":sWx, "sWh":sWh, "sb":sb}
Forward propagation
Runs the forward pass. Essentially each variable is just adapted to the forward/reverse dictionaries.
def BiRNN_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
h_0 = params.get("h_0")
if h_0 is None:
h_0 = {}
for dr in ["f", "r"]:
if not input_dictionary:
h_0[dr] = np.zeros((u.shape[0], weights["Wh"][dr].shape[0]))
else:
h_0[dr] = np.zeros((u[dr].shape[0], weights["Wh"][dr].shape[0]))
# RNN
h_n_merge, hs_merge, h_n, hs = func(u, h_0, weights["Wx"], weights["Wh"], weights["b"], propagation_func, merge, input_dictionary)
    # Not the last RNN layer
if params.get("return_hs"):
z = hs_merge
    # The last RNN layer
else:
z = h_n_merge
    # Weight decay support
weight_decay_r = 0
if weight_decay is not None:
for dr in ["f", "r"]:
weight_decay_r += weight_decay["func"](weights["Wx"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wh"][dr], **weight_decay["params"])
return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs}, weight_decay_r
Backpropagation
Runs the backward pass. Essentially each variable is just adapted to the forward/reverse dictionaries.
def BiRNN_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):
    # Not the last RNN layer
if return_hs:
if merge != "dictionary":
dh_n_merge = np.zeros_like(us["h_n_merge"])
else:
dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
dhs_merge = dz
    # The last RNN layer
else:
dh_n_merge = dz
if merge != "dictionary":
dhs_merge = np.zeros_like(us["hs_merge"])
else:
dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # Compute the RNN gradients
du, dWx, dWh, db, dh_0 = back_func(dh_n_merge, dhs_merge, us["u"], us["hs"], us["h_0"], weights["Wx"], weights["Wh"], weights["b"], propagation_func, merge, input_dictionary)
    # Weight decay support
if weight_decay is not None:
for dr in ["f", "r"]:
dWx[dr] += weight_decay["back_func"](weights["Wx"][dr], **weight_decay["params"])
dWh[dr] += weight_decay["back_func"](weights["Wh"][dr], **weight_decay["params"])
return {"du":du, "dWx":dWx, "dWh":dWh, "db":db}
Weight update
Updates the weights and biases of the forward and the reverse direction.
def BiRNN_update_weight(func, du, weights, optimizer_stats, **params):
for dr in ["f", "r"]:
weights["Wx"][dr], optimizer_stats["sWx"][dr] = func(weights["Wx"][dr], du["dWx"][dr], **params, **optimizer_stats["sWx"][dr])
weights["Wh"][dr], optimizer_stats["sWh"][dr] = func(weights["Wh"][dr], du["dWh"][dr], **params, **optimizer_stats["sWh"][dr])
weights["b"][dr], optimizer_stats["sb"][dr] = func(weights["b"][dr], du["db"][dr], **params, **optimizer_stats["sb"][dr])
return weights, optimizer_stats
Example run
Run the program from 「参考」-「プログラム全体」 of 「ディープラーニングを実装から学ぶ(8)実装変更」 beforehand.
Load the MNIST data. (This example assumes the MNIST data files are stored in c:\mnist.)
x_train, t_train, x_test, t_test = load_mnist('c:\\mnist\\')
data_normalizer = create_data_normalizer(min_max)
nx_train, data_normalizer_stats = train_data_normalize(data_normalizer, x_train)
nx_test = test_data_normalize(data_normalizer, x_test)
Reshape the data into time-series form.
nx_train = nx_train.reshape((nx_train.shape[0], 28, 28))
nx_test = nx_test.reshape((nx_test.shape[0], 28, 28))
Define the model. The output dimension is 100.
A single bidirectional RNN layer; the merge operation is summation ("sum").
model = create_model((28,28))
model = add_layer(model, "BiRNN1", BiRNN, 100, merge="sum")
model = add_layer(model, "affine", affine, 10)
model = set_output(model, softmax)
model = set_error(model, cross_entropy_error)
optimizer = create_optimizer(SGD, lr=0.01)
Run 30 epochs.
epoch = 30
batch_size = 100
np.random.seed(10)
model, optimizer, learn_info = learn(model, nx_train, t_train, nx_test, t_test, batch_size=batch_size, epoch=epoch, optimizer=optimizer)
Here are the results. (Each row shows: epoch, training accuracy, training error, test accuracy, test error.)
input - 0 (28, 28)
BiRNN1 BiRNN (28, 28) 100
affine affine 100 10
output softmax 10
error cross_entropy_error
0 0.08276666666666667 2.4694932161948393 0.0807 2.467711539524275
1 0.7976333333333333 0.7224299245059749 0.9011 0.36007855521236326
2 0.9129 0.3071644615683738 0.9297 0.24448251744407254
3 0.933 0.23339284779792896 0.9418 0.20460460055846733
4 0.9428833333333333 0.19330278050141367 0.9452 0.18235088618707465
5 0.9502833333333334 0.16817265239467252 0.9556 0.14892486442113187
6 0.9550666666666666 0.1496513755834529 0.9563 0.14608596831733958
7 0.9593833333333334 0.1357794529971745 0.9584 0.13615785606170594
8 0.9623666666666667 0.12634386000411943 0.961 0.13028779104323354
9 0.96525 0.11460733561373888 0.9657 0.11348966105739382
10 0.9682333333333333 0.1079471845883365 0.9668 0.10858853202882413
11 0.9692333333333333 0.10160679157369244 0.9691 0.10127266935092118
12 0.97065 0.09646661669902329 0.9707 0.09741699781003078
13 0.9738833333333333 0.08874902808825438 0.9709 0.0974751791352071
14 0.97465 0.08513319550104065 0.9736 0.08820080097152737
15 0.9753 0.08127818526156802 0.9723 0.08705366351988436
16 0.9768333333333333 0.07730725580142364 0.9758 0.08434469601550207
17 0.9783 0.07244720897536845 0.9762 0.08135833169989594
18 0.9787333333333333 0.07012015726493016 0.9693 0.10346379775865691
19 0.9800333333333333 0.06777993291484356 0.9765 0.07907273898044402
20 0.98015 0.0648363240542574 0.9758 0.07887806162399832
21 0.98125 0.06262620142131756 0.9773 0.07744372451315322
22 0.9823 0.05932344435748826 0.9678 0.10134595089422335
23 0.9833 0.057475974001274625 0.9749 0.08343724319233303
24 0.9832333333333333 0.05571428411124262 0.9761 0.07660213665776289
25 0.9844 0.05324665768972804 0.9781 0.0747334290192525
26 0.9855166666666667 0.05124096084976474 0.9789 0.06781155342892754
27 0.9848833333333333 0.050314958583714584 0.978 0.07235791581025655
28 0.9863166666666666 0.04735377729986719 0.9787 0.06785586351774157
29 0.9861333333333333 0.04710562378222973 0.9779 0.07280146686467145
30 0.9862833333333333 0.045595901126329566 0.9793 0.06660518314597848
Elapsed time = 10 min 27 s
The elapsed time is, unsurprisingly, roughly double that of the plain RNN.
The table below compares the bidirectional RNN with the ordinary RNN. The bidirectional RNN improves the accuracy.
Type | Train acc. (%) | Test acc. (%) | Best test acc. (%) |
---|---|---|---|
Bidirectional RNN (sum) | 98.63 | 97.93 | 97.93 |
Bidirectional RNN (concat) | 98.65 | 97.80 | 97.80 |
RNN (forward) | 98.27 | 97.34 | 97.54 |
RNN (reverse) | 98.30 | 97.46 | 97.80 |
Bidirectional LSTM
Like the bidirectional simple RNN, the bidirectional LSTM keeps its weights and other variables in forward/reverse dictionaries. The changes are the same as before.
def BiLSTM(u, h_0, c_0, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the input (reverse direction: flip the time axis)
u_dr = {}
if not input_dictionary:
u_dr["f"] = u
u_dr["r"] = u[:, ::-1]
else:
u_dr["f"] = u["f"]
u_dr["r"] = u["r"][:, ::-1]
    # Initialization
h_n, hs, c_n, cs = {}, {}, {}, {}
fs, Is, gs, os = {}, {}, {}, {}
    # Forward direction
h_n["f"], hs["f"], c_n["f"], cs["f"], fs["f"], Is["f"], gs["f"], os["f"] = LSTM(u_dr["f"], h_0["f"], c_0["f"], Wxf["f"], Whf["f"], bf["f"], Wxi["f"], Whi["f"], bi["f"], Wxg["f"], Whg["f"], bg["f"], Wxo["f"], Who["f"], bo["f"], propagation_func)
    # Reverse direction
h_n["r"], hs["r"], c_n["r"], cs["r"], fs["r"], Is["r"], gs["r"], os["r"] = LSTM(u_dr["r"], h_0["r"], c_0["r"], Wxf["r"], Whf["r"], bf["r"], Wxi["r"], Whi["r"], bi["r"], Wxg["r"], Whg["r"], bg["r"], Wxo["r"], Who["r"], bo["r"], propagation_func)
# h_n, c_n - (n, h)
# hs, cs - (n, t, h)
if merge == "concat":
        # Concatenate
h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
elif merge == "sum":
        # Sum
h_n_merge = h_n["f"] + h_n["r"]
hs_merge = hs["f"] + hs["r"][:, ::-1]
elif merge == "dictionary":
        # Dictionary
h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
return h_n_merge, hs_merge, h_n, hs, c_n, cs, fs, Is, gs, os
def BiLSTM_back(dh_n_merge, dhs_merge, u, hs, h_0, cs, c_0, fs, Is, gs, os, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the input (reverse direction: flip the time axis)
u_dr = {}
if not input_dictionary:
u_dr["f"] = u
u_dr["r"] = u[:, ::-1]
else:
u_dr["f"] = u["f"]
u_dr["r"] = u["r"][:, ::-1]
h = Whf["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
# dhs - (n, t, h)
dh_n, dhs = {}, {}
if merge == "concat":
        # Concatenate
dh_n["f"] = dh_n_merge[:, :h]
dh_n["r"] = dh_n_merge[:, h:]
dhs["f"] = dhs_merge[:, :, :h]
dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
elif merge == "sum":
        # Sum
dh_n["f"] = dh_n_merge
dh_n["r"] = dh_n_merge
dhs["f"] = dhs_merge
dhs["r"] = dhs_merge[:, ::-1]
elif merge == "dictionary":
        # Dictionary
dh_n["f"] = dh_n_merge["f"]
dh_n["r"] = dh_n_merge["r"]
dhs["f"] = dhs_merge["f"]
dhs["r"] = dhs_merge["r"][:, ::-1]
    # Initialization
du = {}
dWxf, dWhf, dbf = {}, {}, {}
dWxi, dWhi, dbi = {}, {}, {}
dWxg, dWhg, dbg = {}, {}, {}
dWxo, dWho, dbo = {}, {}, {}
dh_0, dc_0 = {}, {}
    # Forward direction
du["f"], dWxf["f"], dWhf["f"], dbf["f"], dWxi["f"], dWhi["f"], dbi["f"], dWxg["f"], dWhg["f"], dbg["f"], dWxo["f"], dWho["f"], dbo["f"], dh_0["f"], dc_0["f"] = \
LSTM_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], cs["f"], c_0["f"], fs["f"], Is["f"], gs["f"], os["f"], Wxf["f"], Whf["f"], bf["f"], Wxi["f"], Whi["f"], bi["f"], Wxg["f"], Whg["f"], bg["f"], Wxo["f"], Who["f"], bo["f"], propagation_func)
    # Reverse direction
du["r"], dWxf["r"], dWhf["r"], dbf["r"], dWxi["r"], dWhi["r"], dbi["r"], dWxg["r"], dWhg["r"], dbg["r"], dWxo["r"], dWho["r"], dbo["r"], dh_0["r"], dc_0["r"] = \
LSTM_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], cs["r"], c_0["r"], fs["r"], Is["r"], gs["r"], os["r"], Wxf["r"], Whf["r"], bf["r"], Wxi["r"], Whi["r"], bi["r"], Wxg["r"], Whg["r"], bg["r"], Wxo["r"], Who["r"], bo["r"], propagation_func)
# du - (n, t, d)
if not input_dictionary:
du_merge = du["f"] + du["r"][:, ::-1]
else:
du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
return du_merge, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0
def BiLSTM_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
t, d = d_prev
h_next = h
    # With concat, the output dimension doubles
if merge == "concat":
h_next = h * 2
d_next = (t, h_next)
if not params.get("return_hs"):
d_next = h_next
Wxf, Whf, bf = {}, {}, {}
Wxi, Whi, bi = {}, {}, {}
Wxg, Whg, bg = {}, {}, {}
Wxo, Who, bo = {}, {}, {}
for dr in ["f", "r"]:
Wxf[dr] = Wx_init_func(d, h, **Wx_init_params)
Whf[dr] = Wh_init_func(h, h, **Wh_init_params)
bf[dr] = bias_init_func(h, **bias_init_params)
Wxi[dr] = Wx_init_func(d, h, **Wx_init_params)
Whi[dr] = Wh_init_func(h, h, **Wh_init_params)
bi[dr] = bias_init_func(h, **bias_init_params)
Wxg[dr] = Wx_init_func(d, h, **Wx_init_params)
Whg[dr] = Wh_init_func(h, h, **Wh_init_params)
bg[dr] = bias_init_func(h, **bias_init_params)
Wxo[dr] = Wx_init_func(d, h, **Wx_init_params)
Who[dr] = Wh_init_func(h, h, **Wh_init_params)
bo[dr] = bias_init_func(h, **bias_init_params)
return d_next, {"Wxf":Wxf, "Whf":Whf, "bf":bf, "Wxi":Wxi, "Whi":Whi, "bi":bi, "Wxg":Wxg, "Whg":Whg, "bg":bg, "Wxo":Wxo, "Who":Who, "bo":bo}
def BiLSTM_init_optimizer():
sWxf, sWhf, sbf = {}, {}, {}
sWxi, sWhi, sbi = {}, {}, {}
sWxg, sWhg, sbg = {}, {}, {}
sWxo, sWho, sbo = {}, {}, {}
for dr in ["f", "r"]:
sWxf[dr] = {}
sWhf[dr] = {}
sbf[dr] = {}
sWxi[dr] = {}
sWhi[dr] = {}
sbi[dr] = {}
sWxg[dr] = {}
sWhg[dr] = {}
sbg[dr] = {}
sWxo[dr] = {}
sWho[dr] = {}
sbo[dr] = {}
return {"sWxf":sWxf, "sWhf":sWhf, "sbf":sbf,
"sWxi":sWxi, "sWhi":sWhi, "sbi":sbi,
"sWxg":sWxg, "sWhg":sWhg, "sbg":sbg,
"sWxo":sWxo, "sWho":sWho, "sbo":sbo}
def BiLSTM_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
h_0 = params.get("h_0")
if h_0 is None:
h_0 = {}
for dr in ["f", "r"]:
if not input_dictionary:
h_0[dr] = np.zeros((u.shape[0], weights["Whf"][dr].shape[0]))
else:
h_0[dr] = np.zeros((u[dr].shape[0], weights["Whf"][dr].shape[0]))
c_0 = params.get("c_0")
if c_0 is None:
c_0 = {}
for dr in ["f", "r"]:
if not input_dictionary:
c_0[dr] = np.zeros((u.shape[0], weights["Whf"][dr].shape[0]))
else:
c_0[dr] = np.zeros((u[dr].shape[0], weights["Whf"][dr].shape[0]))
# LSTM
h_n_merge, hs_merge, h_n, hs, c_n, cs, fs, Is, gs, os = func(u, h_0, c_0,
weights["Wxf"], weights["Whf"], weights["bf"],
weights["Wxi"], weights["Whi"], weights["bi"],
weights["Wxg"], weights["Whg"], weights["bg"],
weights["Wxo"], weights["Who"], weights["bo"],
propagation_func, merge, input_dictionary)
    # Not the last LSTM layer
if params.get("return_hs"):
z = hs_merge
    # The last LSTM layer
else:
z = h_n_merge
    # Weight decay support
weight_decay_r = 0
if weight_decay is not None:
for dr in ["f", "r"]:
weight_decay_r += weight_decay["func"](weights["Wxf"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whf"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxi"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whi"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxg"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whg"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxo"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Who"][dr], **weight_decay["params"])
return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs, "c_0":c_0, "c_n":c_n, "cs":cs,
"fs":fs, "Is":Is, "gs":gs, "os":os}, weight_decay_r
def BiLSTM_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):
    # Not the last LSTM layer
if return_hs:
if merge != "dictionary":
dh_n_merge = np.zeros_like(us["h_n_merge"])
else:
dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
dhs_merge = dz
    # The last LSTM layer
else:
dh_n_merge = dz
if merge != "dictionary":
dhs_merge = np.zeros_like(us["hs_merge"])
else:
dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # Compute the LSTM gradients
du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0 = back_func(dh_n_merge, dhs_merge,
us["u"], us["hs"], us["h_0"], us["cs"], us["c_0"],
us["fs"], us["Is"], us["gs"], us["os"],
weights["Wxf"], weights["Whf"], weights["bf"],
weights["Wxi"], weights["Whi"], weights["bi"],
weights["Wxg"], weights["Whg"], weights["bg"],
weights["Wxo"], weights["Who"], weights["bo"],
propagation_func, merge, input_dictionary)
    # Weight decay support
if weight_decay is not None:
for dr in ["f", "r"]:
dWxf[dr] += weight_decay["back_func"](weights["Wxf"][dr], **weight_decay["params"])
dWhf[dr] += weight_decay["back_func"](weights["Whf"][dr], **weight_decay["params"])
dWxi[dr] += weight_decay["back_func"](weights["Wxi"][dr], **weight_decay["params"])
dWhi[dr] += weight_decay["back_func"](weights["Whi"][dr], **weight_decay["params"])
dWxg[dr] += weight_decay["back_func"](weights["Wxg"][dr], **weight_decay["params"])
dWhg[dr] += weight_decay["back_func"](weights["Whg"][dr], **weight_decay["params"])
dWxo[dr] += weight_decay["back_func"](weights["Wxo"][dr], **weight_decay["params"])
dWho[dr] += weight_decay["back_func"](weights["Who"][dr], **weight_decay["params"])
return {"du":du, "dWxf":dWxf, "dWhf":dWhf, "dbf":dbf, "dWxi":dWxi, "dWhi":dWhi, "dbi":dbi, "dWxg":dWxg, "dWhg":dWhg, "dbg":dbg, "dWxo":dWxo, "dWho":dWho, "dbo":dbo}
def BiLSTM_update_weight(func, du, weights, optimizer_stats, **params):
for dr in ["f", "r"]:
weights["Wxf"][dr], optimizer_stats["sWxf"][dr] = func(weights["Wxf"][dr], du["dWxf"][dr], **params, **optimizer_stats["sWxf"][dr])
weights["Whf"][dr], optimizer_stats["sWhf"][dr] = func(weights["Whf"][dr], du["dWhf"][dr], **params, **optimizer_stats["sWhf"][dr])
weights["bf"][dr], optimizer_stats["sbf"][dr] = func(weights["bf"][dr], du["dbf"][dr], **params, **optimizer_stats["sbf"][dr])
weights["Wxi"][dr], optimizer_stats["sWxi"][dr] = func(weights["Wxi"][dr], du["dWxi"][dr], **params, **optimizer_stats["sWxi"][dr])
weights["Whi"][dr], optimizer_stats["sWhi"][dr] = func(weights["Whi"][dr], du["dWhi"][dr], **params, **optimizer_stats["sWhi"][dr])
weights["bi"][dr], optimizer_stats["sbi"][dr] = func(weights["bi"][dr], du["dbi"][dr], **params, **optimizer_stats["sbi"][dr])
weights["Wxg"][dr], optimizer_stats["sWxg"][dr] = func(weights["Wxg"][dr], du["dWxg"][dr], **params, **optimizer_stats["sWxg"][dr])
weights["Whg"][dr], optimizer_stats["sWhg"][dr] = func(weights["Whg"][dr], du["dWhg"][dr], **params, **optimizer_stats["sWhg"][dr])
weights["bg"][dr], optimizer_stats["sbg"][dr] = func(weights["bg"][dr], du["dbg"][dr], **params, **optimizer_stats["sbg"][dr])
weights["Wxo"][dr], optimizer_stats["sWxo"][dr] = func(weights["Wxo"][dr], du["dWxo"][dr], **params, **optimizer_stats["sWxo"][dr])
weights["Who"][dr], optimizer_stats["sWho"][dr] = func(weights["Who"][dr], du["dWho"][dr], **params, **optimizer_stats["sWho"][dr])
weights["bo"][dr], optimizer_stats["sbo"][dr] = func(weights["bo"][dr], du["dbo"][dr], **params, **optimizer_stats["sbo"][dr])
return weights, optimizer_stats
Bidirectional GRU
The bidirectional GRU follows the bidirectional simple RNN and bidirectional LSTM: weights and other variables are kept in forward/reverse dictionaries, with the same changes as before.
def BiGRU(u, h_0, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the input (reverse direction: flip the time axis)
u_dr = {}
if not input_dictionary:
u_dr["f"] = u
u_dr["r"] = u[:, ::-1]
else:
u_dr["f"] = u["f"]
u_dr["r"] = u["r"][:, ::-1]
    # Initialization
h_n, hs = {}, {}
zs, rs, ns = {}, {}, {}
    # Forward direction
h_n["f"], hs["f"], zs["f"], rs["f"], ns["f"] = GRU(u_dr["f"], h_0["f"], Wxz["f"], Whz["f"], bz["f"], Wxr["f"], Whr["f"], br["f"], Wxn["f"], Whn["f"], bn["f"], propagation_func)
    # Reverse direction
h_n["r"], hs["r"], zs["r"], rs["r"], ns["r"] = GRU(u_dr["r"], h_0["r"], Wxz["r"], Whz["r"], bz["r"], Wxr["r"], Whr["r"], br["r"], Wxn["r"], Whn["r"], bn["r"], propagation_func)
# h_n - (n, h)
# hs - (n, t, h)
if merge == "concat":
        # Concatenate
h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
elif merge == "sum":
        # Sum
h_n_merge = h_n["f"] + h_n["r"]
hs_merge = hs["f"] + hs["r"][:, ::-1]
elif merge == "dictionary":
        # Dictionary
h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
return h_n_merge, hs_merge, h_n, hs, zs, rs, ns
def BiGRU_back(dh_n_merge, dhs_merge, u, hs, h_0, zs, rs, ns, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh, merge="concat", input_dictionary=False):
    # Set up the input (reverse direction: flip the time axis)
u_dr = {}
if not input_dictionary:
u_dr["f"] = u
u_dr["r"] = u[:, ::-1]
else:
u_dr["f"] = u["f"]
u_dr["r"] = u["r"][:, ::-1]
h = Whz["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
# dhs - (n, t, h)
dh_n, dhs = {}, {}
if merge == "concat":
        # Concatenate
dh_n["f"] = dh_n_merge[:, :h]
dh_n["r"] = dh_n_merge[:, h:]
dhs["f"] = dhs_merge[:, :, :h]
dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
elif merge == "sum":
        # Sum
dh_n["f"] = dh_n_merge
dh_n["r"] = dh_n_merge
dhs["f"] = dhs_merge
dhs["r"] = dhs_merge[:, ::-1]
elif merge == "dictionary":
        # Dictionary
dh_n["f"] = dh_n_merge["f"]
dh_n["r"] = dh_n_merge["r"]
dhs["f"] = dhs_merge["f"]
dhs["r"] = dhs_merge["r"][:, ::-1]
    # Initialization
du = {}
dWxz, dWhz, dbz = {}, {}, {}
dWxr, dWhr, dbr = {}, {}, {}
dWxn, dWhn, dbn = {}, {}, {}
dh_0 = {}
    # Forward direction
du["f"], dWxz["f"], dWhz["f"], dbz["f"], dWxr["f"], dWhr["f"], dbr["f"], dWxn["f"], dWhn["f"], dbn["f"], dh_0["f"] = \
GRU_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], zs["f"], rs["f"], ns["f"], Wxz["f"], Whz["f"], bz["f"], Wxr["f"], Whr["f"], br["f"], Wxn["f"], Whn["f"], bn["f"], propagation_func)
    # Reverse direction
du["r"], dWxz["r"], dWhz["r"], dbz["r"], dWxr["r"], dWhr["r"], dbr["r"], dWxn["r"], dWhn["r"], dbn["r"], dh_0["r"] = \
GRU_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], zs["r"], rs["r"], ns["r"], Wxz["r"], Whz["r"], bz["r"], Wxr["r"], Whr["r"], br["r"], Wxn["r"], Whn["r"], bn["r"], propagation_func)
# du - (n, t, d)
if not input_dictionary:
du_merge = du["f"] + du["r"][:, ::-1]
else:
du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
return du_merge, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0
def BiGRU_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
t, d = d_prev
h_next = h
    # With concat, the output dimension doubles
if merge == "concat":
h_next = h * 2
d_next = (t, h_next)
if not params.get("return_hs"):
d_next = h_next
Wxz, Whz, bz = {}, {}, {}
Wxr, Whr, br = {}, {}, {}
Wxn, Whn, bn = {}, {}, {}
for dr in ["f", "r"]:
Wxz[dr] = Wx_init_func(d, h, **Wx_init_params)
Whz[dr] = Wh_init_func(h, h, **Wh_init_params)
bz[dr] = bias_init_func(h, **bias_init_params)
Wxr[dr] = Wx_init_func(d, h, **Wx_init_params)
Whr[dr] = Wh_init_func(h, h, **Wh_init_params)
br[dr] = bias_init_func(h, **bias_init_params)
Wxn[dr] = Wx_init_func(d, h, **Wx_init_params)
Whn[dr] = Wh_init_func(h, h, **Wh_init_params)
bn[dr] = bias_init_func(h, **bias_init_params)
return d_next, {"Wxz":Wxz, "Whz":Whz, "bz":bz, "Wxr":Wxr, "Whr":Whr, "br":br, "Wxn":Wxn, "Whn":Whn, "bn":bn}
def BiGRU_init_optimizer():
sWxz, sWhz, sbz = {}, {}, {}
sWxr, sWhr, sbr = {}, {}, {}
sWxn, sWhn, sbn = {}, {}, {}
for dr in ["f", "r"]:
sWxz[dr] = {}
sWhz[dr] = {}
sbz[dr] = {}
sWxr[dr] = {}
sWhr[dr] = {}
sbr[dr] = {}
sWxn[dr] = {}
sWhn[dr] = {}
sbn[dr] = {}
return {"sWxz":sWxz, "sWhz":sWhz, "sbz":sbz,
"sWxr":sWxr, "sWhr":sWhr, "sbr":sbr,
"sWxn":sWxn, "sWhn":sWhn, "sbn":sbn}
def BiGRU_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
h_0 = params.get("h_0")
if h_0 is None:
h_0 = {}
for dr in ["f", "r"]:
if not input_dictionary:
h_0[dr] = np.zeros((u.shape[0], weights["Whz"][dr].shape[0]))
else:
h_0[dr] = np.zeros((u[dr].shape[0], weights["Whz"][dr].shape[0]))
# GRU
h_n_merge, hs_merge, h_n, hs, zs, rs, ns = func(u, h_0,
weights["Wxz"], weights["Whz"], weights["bz"],
weights["Wxr"], weights["Whr"], weights["br"],
weights["Wxn"], weights["Whn"], weights["bn"],
propagation_func, merge, input_dictionary)
    # Not the last GRU layer
if params.get("return_hs"):
z = hs_merge
    # The last GRU layer
else:
z = h_n_merge
    # Weight decay support
weight_decay_r = 0
if weight_decay is not None:
for dr in ["f", "r"]:
weight_decay_r += weight_decay["func"](weights["Wxz"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whz"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxr"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whr"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxn"][dr], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whn"][dr], **weight_decay["params"])
return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs, "zs":zs, "rs":rs, "ns":ns}, weight_decay_r
def BiGRU_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):
    # Not the last GRU layer
if return_hs:
if merge != "dictionary":
dh_n_merge = np.zeros_like(us["h_n_merge"])
else:
dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
dhs_merge = dz
    # The last GRU layer
else:
dh_n_merge = dz
if merge != "dictionary":
dhs_merge = np.zeros_like(us["hs_merge"])
else:
dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # Compute the GRU gradients
du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0 = back_func(dh_n_merge, dhs_merge,
us["u"], us["hs"], us["h_0"],
us["zs"], us["rs"], us["ns"],
weights["Wxz"], weights["Whz"], weights["bz"],
weights["Wxr"], weights["Whr"], weights["br"],
weights["Wxn"], weights["Whn"], weights["bn"],
propagation_func, merge, input_dictionary)
    # Weight decay support
if weight_decay is not None:
for dr in ["f", "r"]:
dWxz[dr] += weight_decay["back_func"](weights["Wxz"][dr], **weight_decay["params"])
dWhz[dr] += weight_decay["back_func"](weights["Whz"][dr], **weight_decay["params"])
dWxr[dr] += weight_decay["back_func"](weights["Wxr"][dr], **weight_decay["params"])
dWhr[dr] += weight_decay["back_func"](weights["Whr"][dr], **weight_decay["params"])
dWxn[dr] += weight_decay["back_func"](weights["Wxn"][dr], **weight_decay["params"])
dWhn[dr] += weight_decay["back_func"](weights["Whn"][dr], **weight_decay["params"])
return {"du":du, "dWxz":dWxz, "dWhz":dWhz, "dbz":dbz, "dWxr":dWxr, "dWhr":dWhr, "dbr":dbr, "dWxn":dWxn, "dWhn":dWhn, "dbn":dbn}
def BiGRU_update_weight(func, du, weights, optimizer_stats, **params):
for dr in ["f", "r"]:
weights["Wxz"][dr], optimizer_stats["sWxz"][dr] = func(weights["Wxz"][dr], du["dWxz"][dr], **params, **optimizer_stats["sWxz"][dr])
weights["Whz"][dr], optimizer_stats["sWhz"][dr] = func(weights["Whz"][dr], du["dWhz"][dr], **params, **optimizer_stats["sWhz"][dr])
weights["bz"][dr], optimizer_stats["sbz"][dr] = func(weights["bz"][dr], du["dbz"][dr], **params, **optimizer_stats["sbz"][dr])
weights["Wxr"][dr], optimizer_stats["sWxr"][dr] = func(weights["Wxr"][dr], du["dWxr"][dr], **params, **optimizer_stats["sWxr"][dr])
weights["Whr"][dr], optimizer_stats["sWhr"][dr] = func(weights["Whr"][dr], du["dWhr"][dr], **params, **optimizer_stats["sWhr"][dr])
weights["br"][dr], optimizer_stats["sbr"][dr] = func(weights["br"][dr], du["dbr"][dr], **params, **optimizer_stats["sbr"][dr])
weights["Wxn"][dr], optimizer_stats["sWxn"][dr] = func(weights["Wxn"][dr], du["dWxn"][dr], **params, **optimizer_stats["sWxn"][dr])
weights["Whn"][dr], optimizer_stats["sWhn"][dr] = func(weights["Whn"][dr], du["dWhn"][dr], **params, **optimizer_stats["sWhn"][dr])
weights["bn"][dr], optimizer_stats["sbn"][dr] = func(weights["bn"][dr], du["dbn"][dr], **params, **optimizer_stats["sbn"][dr])
return weights, optimizer_stats
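The bidirectional LSTM and GRU layers plug into a model exactly like BiRNN1 in the earlier example. For instance (a sketch using the same hyperparameters as the earlier run; swap BiLSTM1/BiLSTM for BiGRU1/BiGRU to get the GRU version):

model = create_model((28,28))
model = add_layer(model, "BiLSTM1", BiLSTM, 100, merge="sum")
model = add_layer(model, "affine", affine, 10)
model = set_output(model, softmax)
model = set_error(model, cross_entropy_error)
optimizer = create_optimizer(SGD, lr=0.01)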
Reference
Here is the full program for the bidirectional RNNs and orthogonal weight initialization.
Since the default initializer for the RNN weights was changed to orthogonal, the RNN code is included as well.
Function specification
# Layer-adding function
model = add_layer(model, name, func, d=None, **kwargs)
# Arguments
#  model : the model
#  name : layer name
#  func : layer function
# affine,sigmoid,tanh,relu,leaky_relu,prelu,rrelu,relun,srelu,elu,maxout,
# identity,softplus,softsign,step,dropout,batch_normalization,
# convolution2d,max_pooling2d,average_pooling2d,flatten2d,
# RNN,LSTM,GRU,BiRNN, BiLSTM,BiGRU
#  d : number of nodes
#      specify for affine,maxout,convolution2d,RNN,LSTM,GRU,BiRNN,BiLSTM,BiGRU
#      for convolution2d, the number of filters
#      for RNN,LSTM,GRU,BiRNN,BiLSTM,BiGRU, the output dimension
#  kwargs: parameters of the layer function
# affine - weight_init_func=he_normal, weight_init_params={}, bias_init_func=zeros_b, bias_init_params={}
# weight_init_func - lecun_normal,lecun_uniform,glorot_normal,glorot_uniform,he_normal,he_uniform,normal_w,uniform_w,zeros_w,ones_w
# weight_init_params
# normal_w - mean=0, var=1
# uniform_w - min=0, max=1
# bias_init_func - normal_b,uniform_b,zeros_b,ones_b
# bias_init_params
# normal_b - mean=0, var=1
# uniform_b - min=0, max=1
# leaky_relu - alpha
# rrelu - min=0.0, max=0.1
# relun - n
# srelu - alpha
# elu - alpha
# maxout - unit=1, weight_init_func=he_normal, weight_init_params={}, bias_init_func=zeros_b, bias_init_params={}
# dropout - dropout_ratio=0.9
# batch_normalization - batch_norm_node=True, use_gamma_beta=True, use_train_stats=True, alpha=0.1
# convolution2d - padding=0, strides=1, weight_init_func=he_normal, weight_init_params={}, bias_init_func=zeros_b, bias_init_params={}
# max_pooling2d - pool_size=2, padding=0, strides=None
# average_pooling2d - pool_size=2, padding=0, strides=None
# RNN, LSTM, GRU - propagation_func=tanh, return_hs=False, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}
# propagation_func - tanh,sigmoid,relu
# Wx_init_func and Wh_init_func accept the same functions as affine's weight_init_func; Wh_init_func additionally accepts orthogonal
# BiRNN, BiLSTM, BiGRU - propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}
# propagation_func - tanh,sigmoid,relu
# merge - "concat", "sum", "dictionary"
# Wx_init_func and Wh_init_func accept the same functions as affine's weight_init_func; Wh_init_func additionally accepts orthogonal
# Returns
#  the model
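Using this spec, the two-layer bidirectional structure from the multi-layer section could be declared as follows. This is a sketch: that return_hs=True combines with merge="dictionary" and input_dictionary=True in this way is an assumption based on the parameters above.

model = create_model((28,28))
model = add_layer(model, "BiRNN1", BiRNN, 100, return_hs=True, merge="dictionary")
model = add_layer(model, "BiRNN2", BiRNN, 100, input_dictionary=True, merge="sum")
model = add_layer(model, "affine", affine, 10)
model = set_output(model, softmax)
model = set_error(model, cross_entropy_error)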
Program
def orthogonal(d_1, d):
rand = np.random.normal(0., 1., (d_1, d))
return np.linalg.svd(rand)[0]
def RNN(u, h_0, Wx, Wh, b, propagation_func=tanh):
# u - (n, t, d)
un, ut, ud = u.shape
u = u.transpose(1,0,2)
# h_0 - (n, h)
hn, hh = h_0.shape
# Wx - (d, h)
# Wh - (h, h)
# b - h
    # Initial setup
h_t_1 = h_0
hs = np.zeros((ut, hn, hh))
for t in range(ut):
# RNN
u_t = u[t]
h_t = propagation_func(np.dot(h_t_1, Wh) + np.dot(u_t, Wx) + b)
h_t_1 = h_t
        # Store h
hs[t] = h_t
    # Final setup
h_n = h_t
hs = hs.transpose(1,0,2)
return h_n, hs
def RNN_back(dh_n, dhs, u, hs, h_0, Wx, Wh, b, propagation_func=tanh):
propagation_back_func = eval(propagation_func.__name__ + "_back")
# dh_n - (n, h)
# dhs - (n, t, h)
dhs = dhs.transpose(1,0,2)
# u - (n, t, d)
un, ut, ud = u.shape
u = u.transpose(1,0,2)
# hs - (n, t, h)
hs = hs.transpose(1,0,2)
# h_0 - (n, h)
# Wx - (d, h)
# Wh - (h, h)
# b - h
    # Initial setup
dh_t = dh_n
du = np.zeros_like(u)
dWx = np.zeros_like(Wx)
dWh = np.zeros_like(Wh)
db = np.zeros_like(b)
for t in reversed(range(ut)):
        # Compute the RNN gradients
u_t = u[t]
h_t = hs[t]
if t == 0:
h_t_1 = h_0
else:
h_t_1 = hs[t-1]
        # Add the gradients from both paths
dh_t = dh_t + dhs[t]
        # Gradient of tanh
dp = propagation_back_func(dh_t, None, h_t)
        # Individual gradients
# h => (n, h)
db += np.sum(dp, axis=0)
# (h, h) => (h, n) @ (n, h)
dWh += np.dot(h_t_1.T, dp)
# (d, h) => (d, n) @ (n, h)
dWx += np.dot(u_t.T, dp)
# (n, h) => (n, h) @ (h, h)
dh_t_1 = np.dot(dp, Wh.T)
# (n, d) => (n, h) @ (h, d)
du_t = np.dot(dp, Wx.T)
dh_t = dh_t_1
        # Store the gradient of u
du[t] = du_t
    # Final setup
dh_0 = dh_t_1
du = du.transpose(1,0,2)
return du, dWx, dWh, db, dh_0
def LSTM(u, h_0, c_0, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh):
# u - (n, t, d)
un, ut, ud = u.shape
u = u.transpose(1,0,2)
# h_0 - (n, h)
hn, hh = h_0.shape
# c_0 - (n, h)
cn, ch = c_0.shape
    # Wxf, Wxi, Wxg, Wxo - (d, h)
# Whf, Whi, Whg, Who - (h, h)
# bf, bi, bg, bo - h
    # Initial setup
h_t_1 = h_0
c_t_1 = c_0
hs = np.zeros((ut, hn, hh))
cs = np.zeros((ut, cn, ch))
fs = np.zeros((ut, hn, hh))
Is = np.zeros((ut, hn, hh))
gs = np.zeros((ut, hn, hh))
os = np.zeros((ut, hn, hh))
for t in range(ut):
# LSTM
u_t = u[t]
f_t = sigmoid(np.dot(h_t_1, Whf) + np.dot(u_t, Wxf) + bf)
i_t = sigmoid(np.dot(h_t_1, Whi) + np.dot(u_t, Wxi) + bi)
g_t = propagation_func(np.dot(h_t_1, Whg) + np.dot(u_t, Wxg) + bg)
o_t = sigmoid(np.dot(h_t_1, Who) + np.dot(u_t, Wxo) + bo)
c_t = f_t * c_t_1 + i_t * g_t
h_t = o_t * propagation_func(c_t)
        # Store the values of h, c, f, i, g, o
hs[t] = h_t
cs[t] = c_t
fs[t] = f_t
Is[t] = i_t
gs[t] = g_t
os[t] = o_t
h_t_1 = h_t
c_t_1 = c_t
    # Final setup
h_n = h_t
c_n = c_t
hs = hs.transpose(1,0,2)
cs = cs.transpose(1,0,2)
return h_n, hs, c_n, cs, fs, Is, gs, os
def LSTM_back(dh_n, dhs, u, hs, h_0, cs, c_0, fs, Is, gs, os, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh):
propagation_back_func = eval(propagation_func.__name__ + "_back")
# dh_n - (n, h)
# dhs - (n, t, h)
dhs = dhs.transpose(1,0,2)
# u - (n, t, d)
un, ut, ud = u.shape
u = u.transpose(1,0,2)
# hs - (n, t, h)
hs = hs.transpose(1,0,2)
# h_0 - (n, h)
# cs - (n, t, h)
cs = cs.transpose(1,0,2)
# c_0 - (n, h)
# fs, Is, gs, os - (n, t, h)
    # Wxf, Wxi, Wxg, Wxo - (d, h)
# Whf, Whi, Whg, Who - (h, h)
# bf, bi, bg, bo - h
    # Initial setup
dh_t = dh_n
dc_t = np.zeros_like(c_0)
du = np.zeros_like(u)
dWxf = np.zeros_like(Wxf)
dWhf = np.zeros_like(Whf)
dbf = np.zeros_like(bf)
dWxi = np.zeros_like(Wxi)
dWhi = np.zeros_like(Whi)
dbi = np.zeros_like(bi)
dWxg = np.zeros_like(Wxg)
dWhg = np.zeros_like(Whg)
dbg = np.zeros_like(bg)
dWxo = np.zeros_like(Wxo)
dWho = np.zeros_like(Who)
dbo = np.zeros_like(bo)
for t in reversed(range(ut)):
        # Compute the LSTM gradients
u_t = u[t]
h_t = hs[t]
c_t = cs[t]
if t == 0:
h_t_1 = h_0
c_t_1 = c_0
else:
h_t_1 = hs[t-1]
c_t_1 = cs[t-1]
o_t = os[t]
g_t = gs[t]
i_t = Is[t]
f_t = fs[t]
        # Add the gradients from both paths
dh_t = dh_t + dhs[t]
        # Gradient of tanh
dp = dh_t * o_t
p_t = propagation_func(c_t)
dc_t = dc_t + propagation_back_func(dp, None, p_t)
        # Individual gradients
# o
do = dh_t * p_t
dpo = sigmoid_back(do, None, o_t)
dbo += np.sum(dpo, axis=0)
dWho += np.dot(h_t_1.T, dpo)
dWxo += np.dot(u_t.T, dpo)
# g
dg = dc_t * i_t
dpg = propagation_back_func(dg, None, g_t)
dbg += np.sum(dpg, axis=0)
dWhg += np.dot(h_t_1.T, dpg)
dWxg += np.dot(u_t.T, dpg)
# i
di = dc_t * g_t
dpi = sigmoid_back(di, None, i_t)
dbi += np.sum(dpi, axis=0)
dWhi += np.dot(h_t_1.T, dpi)
dWxi += np.dot(u_t.T, dpi)
# f
df = dc_t * c_t_1
dpf = sigmoid_back(df, None, f_t)
dbf += np.sum(dpf, axis=0)
dWhf += np.dot(h_t_1.T, dpf)
dWxf += np.dot(u_t.T, dpf)
# c
dc_t_1 = dc_t * f_t
# h
dh_t_1 = np.dot(dpf, Whf.T) + np.dot(dpi, Whi.T) + np.dot(dpg, Whg.T) + np.dot(dpo, Who.T)
# u
du_t = np.dot(dpf, Wxf.T) + np.dot(dpi, Wxi.T) + np.dot(dpg, Wxg.T) + np.dot(dpo, Wxo.T)
dh_t = dh_t_1
dc_t = dc_t_1
        # Store the gradient of u
du[t] = du_t
    # Final setup
dh_0 = dh_t_1
dc_0 = dc_t_1
du = du.transpose(1,0,2)
return du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0
def GRU(u, h_0, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh):
# u - (n, t, d)
un, ut, ud = u.shape
u = u.transpose(1,0,2)
# h_0 - (n, h)
hn, hh = h_0.shape
# Wxz, Wxr, Wxn - (d, h)
# Whz, Whr, Whn - (h, h)
# bz, br, bn - h
    # Initial setup
h_t_1 = h_0
hs = np.zeros((ut, hn, hh))
zs = np.zeros((ut, hn, hh))
rs = np.zeros((ut, hn, hh))
ns = np.zeros((ut, hn, hh))
for t in range(ut):
# GRU
u_t = u[t]
z_t = sigmoid(np.dot(h_t_1, Whz) + np.dot(u_t, Wxz) + bz)
r_t = sigmoid(np.dot(h_t_1, Whr) + np.dot(u_t, Wxr) + br)
n_t = propagation_func(r_t * np.dot(h_t_1, Whn) + np.dot(u_t, Wxn) + bn)
h_t = (1-z_t) * h_t_1 + z_t * n_t
        # Store the values of h, z, r, n
hs[t] = h_t
zs[t] = z_t
rs[t] = r_t
ns[t] = n_t
h_t_1 = h_t
    # Final setup
h_n = h_t
hs = hs.transpose(1,0,2)
return h_n, hs, zs, rs, ns
def GRU_back(dh_n, dhs, u, hs, h_0, zs, rs, ns, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh):
propagation_back_func = eval(propagation_func.__name__ + "_back")
# dh_n - (n, h)
# dhs - (n, t, h)
dhs = dhs.transpose(1,0,2)
# u - (n, t, d)
un, ut, ud = u.shape
u = u.transpose(1,0,2)
# hs - (n, t, h)
hs = hs.transpose(1,0,2)
# h_0 - (n, h)
# zs, rs, ns - (n, t, h)
# Wxz, Wxr, Wxn - (d, h)
# Whz, Whr, Whn - (h, h)
# bz, br, bn - h
    # Initial setup
dh_t = dh_n
du = np.zeros_like(u)
dWxz = np.zeros_like(Wxz)
dWhz = np.zeros_like(Whz)
dbz = np.zeros_like(bz)
dWxr = np.zeros_like(Wxr)
dWhr = np.zeros_like(Whr)
dbr = np.zeros_like(br)
dWxn = np.zeros_like(Wxn)
dWhn = np.zeros_like(Whn)
dbn = np.zeros_like(bn)
for t in reversed(range(ut)):
        # Compute the GRU gradients
u_t = u[t]
h_t = hs[t]
if t == 0:
h_t_1 = h_0
else:
h_t_1 = hs[t-1]
z_t = zs[t]
r_t = rs[t]
n_t = ns[t]
        # Add the gradients from both paths
dh_t = dh_t + dhs[t]
        # Individual gradients
# n
dn = dh_t * z_t
dpn = propagation_back_func(dn, None, n_t)
dbn += np.sum(dpn, axis=0)
dWhn += np.dot(h_t_1.T, dpn * r_t)
dWxn += np.dot(u_t.T, dpn)
# r
dr = dpn * r_t * np.dot(h_t_1, Whn)
dpr = sigmoid_back(dr, None, r_t)
dbr += np.sum(dpr, axis=0)
dWhr += np.dot(h_t_1.T, dpr)
dWxr += np.dot(u_t.T, dpr)
# z
dz = dh_t * n_t - dh_t * h_t_1
dpz = sigmoid_back(dz, None, z_t)
dbz += np.sum(dpz, axis=0)
dWhz += np.dot(h_t_1.T, dpz)
dWxz += np.dot(u_t.T, dpz)
# h
dh_t_1 = np.dot(dpz, Whz.T) + np.dot(dpr, Whr.T) + np.dot(dpn * r_t, Whn.T) + dh_t * (1 - z_t)
# u
du_t = np.dot(dpz, Wxz.T) + np.dot(dpr, Wxr.T) + np.dot(dpn, Wxn.T)
dh_t = dh_t_1
        # Store the gradient of u
du[t] = du_t
    # Final setup
dh_0 = dh_t_1
du = du.transpose(1,0,2)
return du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0
def RNN_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
t, d = d_prev
d_next = (t, h)
if not params.get("return_hs"):
d_next = h
Wx = Wx_init_func(d, h, **Wx_init_params)
Wh = Wh_init_func(h, h, **Wh_init_params)
b = bias_init_func(h, **bias_init_params)
return d_next, {"Wx":Wx, "Wh":Wh, "b":b}
def RNN_init_optimizer():
sWx = {}
sWh = {}
sb = {}
return {"sWx":sWx, "sWh":sWh, "sb":sb}
def RNN_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
h_0 = params.get("h_0")
if h_0 is None:
h_0 = np.zeros((u.shape[0], weights["Wh"].shape[0]))
# RNN
h_n, hs = func(u, h_0, weights["Wx"], weights["Wh"], weights["b"], propagation_func)
    # Not the last RNN layer
if params.get("return_hs"):
z = hs
    # The last RNN layer
else:
z = h_n
    # Weight decay support
weight_decay_r = 0
if weight_decay is not None:
weight_decay_r += weight_decay["func"](weights["Wx"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wh"], **weight_decay["params"])
return {"u":u, "z":z, "h_0":h_0, "hs":hs}, weight_decay_r
def RNN_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # Not the last RNN layer
if return_hs:
dh_n = np.zeros_like(us["h_0"])
dhs = dz
    # The last RNN layer
else:
dh_n = dz
dhs = np.zeros_like(us["hs"])
    # Compute the RNN gradients
du, dWx, dWh, db, dh_0 = back_func(dh_n, dhs, us["u"], us["hs"], us["h_0"], weights["Wx"], weights["Wh"], weights["b"], propagation_func)
    # Weight decay support
if weight_decay is not None:
dWx += weight_decay["back_func"](weights["Wx"], **weight_decay["params"])
dWh += weight_decay["back_func"](weights["Wh"], **weight_decay["params"])
return {"du":du, "dWx":dWx, "dWh":dWh, "db":db}
def RNN_update_weight(func, du, weights, optimizer_stats, **params):
weights["Wx"], optimizer_stats["sWx"] = func(weights["Wx"], du["dWx"], **params, **optimizer_stats["sWx"])
weights["Wh"], optimizer_stats["sWh"] = func(weights["Wh"], du["dWh"], **params, **optimizer_stats["sWh"])
weights["b"], optimizer_stats["sb"] = func(weights["b"], du["db"], **params, **optimizer_stats["sb"])
return weights, optimizer_stats
def LSTM_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
t, d = d_prev
d_next = (t, h)
if not params.get("return_hs"):
d_next = h
Wxf = Wx_init_func(d, h, **Wx_init_params)
Whf = Wh_init_func(h, h, **Wh_init_params)
bf = bias_init_func(h, **bias_init_params)
Wxi = Wx_init_func(d, h, **Wx_init_params)
Whi = Wh_init_func(h, h, **Wh_init_params)
bi = bias_init_func(h, **bias_init_params)
Wxg = Wx_init_func(d, h, **Wx_init_params)
Whg = Wh_init_func(h, h, **Wh_init_params)
bg = bias_init_func(h, **bias_init_params)
Wxo = Wx_init_func(d, h, **Wx_init_params)
Who = Wh_init_func(h, h, **Wh_init_params)
bo = bias_init_func(h, **bias_init_params)
return d_next, {"Wxf":Wxf, "Whf":Whf, "bf":bf, "Wxi":Wxi, "Whi":Whi, "bi":bi, "Wxg":Wxg, "Whg":Whg, "bg":bg, "Wxo":Wxo, "Who":Who, "bo":bo}
def LSTM_init_optimizer():
sWxf = {}
sWhf = {}
sbf = {}
sWxi = {}
sWhi = {}
sbi = {}
sWxg = {}
sWhg = {}
sbg = {}
sWxo = {}
sWho = {}
sbo = {}
return {"sWxf":sWxf, "sWhf":sWhf, "sbf":sbf,
"sWxi":sWxi, "sWhi":sWhi, "sbi":sbi,
"sWxg":sWxg, "sWhg":sWhg, "sbg":sbg,
"sWxo":sWxo, "sWho":sWho, "sbo":sbo}
def LSTM_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
h_0 = params.get("h_0")
if h_0 is None:
h_0 = np.zeros((u.shape[0], weights["Whf"].shape[0]))
c_0 = params.get("c_0")
if c_0 is None:
c_0 = np.zeros((u.shape[0], weights["Whf"].shape[0]))
# LSTM
h_n, hs, c_n, cs, fs, Is, gs, os = func(u, h_0, c_0,
weights["Wxf"], weights["Whf"], weights["bf"],
weights["Wxi"], weights["Whi"], weights["bi"],
weights["Wxg"], weights["Whg"], weights["bg"],
weights["Wxo"], weights["Who"], weights["bo"], propagation_func)
    # Not the last LSTM layer
if params.get("return_hs"):
z = hs
    # The last LSTM layer
else:
z = h_n
    # Weight decay support
weight_decay_r = 0
if weight_decay is not None:
weight_decay_r += weight_decay["func"](weights["Wxf"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whf"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxi"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whi"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxg"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whg"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxo"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Who"], **weight_decay["params"])
return {"u":u, "z":z, "h_0":h_0, "hs":hs, "c_0":c_0, "cs":cs,
"fs":fs, "Is":Is, "gs":gs, "os":os}, weight_decay_r
def LSTM_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # Not the last LSTM layer
if return_hs:
dh_n = np.zeros_like(us["h_0"])
dhs = dz
    # The last LSTM layer
else:
dh_n = dz
dhs = np.zeros_like(us["hs"])
    # Compute the LSTM gradients
du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0 = back_func(dh_n, dhs,
us["u"], us["hs"], us["h_0"], us["cs"], us["c_0"],
us["fs"], us["Is"], us["gs"], us["os"],
weights["Wxf"], weights["Whf"], weights["bf"],
weights["Wxi"], weights["Whi"], weights["bi"],
weights["Wxg"], weights["Whg"], weights["bg"],
weights["Wxo"], weights["Who"], weights["bo"],
propagation_func)
    # Weight decay support
if weight_decay is not None:
dWxf += weight_decay["back_func"](weights["Wxf"], **weight_decay["params"])
dWhf += weight_decay["back_func"](weights["Whf"], **weight_decay["params"])
dWxi += weight_decay["back_func"](weights["Wxi"], **weight_decay["params"])
dWhi += weight_decay["back_func"](weights["Whi"], **weight_decay["params"])
dWxg += weight_decay["back_func"](weights["Wxg"], **weight_decay["params"])
dWhg += weight_decay["back_func"](weights["Whg"], **weight_decay["params"])
dWxo += weight_decay["back_func"](weights["Wxo"], **weight_decay["params"])
dWho += weight_decay["back_func"](weights["Who"], **weight_decay["params"])
return {"du":du, "dWxf":dWxf, "dWhf":dWhf, "dbf":dbf, "dWxi":dWxi, "dWhi":dWhi, "dbi":dbi, "dWxg":dWxg, "dWhg":dWhg, "dbg":dbg, "dWxo":dWxo, "dWho":dWho, "dbo":dbo}
def LSTM_update_weight(func, du, weights, optimizer_stats, **params):
weights["Wxf"], optimizer_stats["sWxf"] = func(weights["Wxf"], du["dWxf"], **params, **optimizer_stats["sWxf"])
weights["Whf"], optimizer_stats["sWhf"] = func(weights["Whf"], du["dWhf"], **params, **optimizer_stats["sWhf"])
weights["bf"], optimizer_stats["sbf"] = func(weights["bf"], du["dbf"], **params, **optimizer_stats["sbf"])
weights["Wxi"], optimizer_stats["sWxi"] = func(weights["Wxi"], du["dWxi"], **params, **optimizer_stats["sWxi"])
weights["Whi"], optimizer_stats["sWhi"] = func(weights["Whi"], du["dWhi"], **params, **optimizer_stats["sWhi"])
weights["bi"], optimizer_stats["sbi"] = func(weights["bi"], du["dbi"], **params, **optimizer_stats["sbi"])
weights["Wxg"], optimizer_stats["sWxg"] = func(weights["Wxg"], du["dWxg"], **params, **optimizer_stats["sWxg"])
weights["Whg"], optimizer_stats["sWhg"] = func(weights["Whg"], du["dWhg"], **params, **optimizer_stats["sWhg"])
weights["bg"], optimizer_stats["sbg"] = func(weights["bg"], du["dbg"], **params, **optimizer_stats["sbg"])
weights["Wxo"], optimizer_stats["sWxo"] = func(weights["Wxo"], du["dWxo"], **params, **optimizer_stats["sWxo"])
weights["Who"], optimizer_stats["sWho"] = func(weights["Who"], du["dWho"], **params, **optimizer_stats["sWho"])
weights["bo"], optimizer_stats["sbo"] = func(weights["bo"], du["dbo"], **params, **optimizer_stats["sbo"])
return weights, optimizer_stats
def GRU_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, **params):
t, d = d_prev
d_next = (t, h)
if not params.get("return_hs"):
d_next = h
Wxz = Wx_init_func(d, h, **Wx_init_params)
Whz = Wh_init_func(h, h, **Wh_init_params)
bz = bias_init_func(h, **bias_init_params)
Wxr = Wx_init_func(d, h, **Wx_init_params)
Whr = Wh_init_func(h, h, **Wh_init_params)
br = bias_init_func(h, **bias_init_params)
Wxn = Wx_init_func(d, h, **Wx_init_params)
Whn = Wh_init_func(h, h, **Wh_init_params)
bn = bias_init_func(h, **bias_init_params)
return d_next, {"Wxz":Wxz, "Whz":Whz, "bz":bz, "Wxr":Wxr, "Whr":Whr, "br":br, "Wxn":Wxn, "Whn":Whn, "bn":bn}
def GRU_init_optimizer():
sWxz = {}
sWhz = {}
sbz = {}
sWxr = {}
sWhr = {}
sbr = {}
sWxn = {}
sWhn = {}
sbn = {}
return {"sWxz":sWxz, "sWhz":sWhz, "sbz":sbz,
"sWxr":sWxr, "sWhr":sWhr, "sbr":sbr,
"sWxn":sWxn, "sWhn":sWhn, "sbn":sbn}
def GRU_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, **params):
h_0 = params.get("h_0")
if h_0 is None:
h_0 = np.zeros((u.shape[0], weights["Whz"].shape[0]))
# GRU
h_n, hs, zs, rs, ns = func(u, h_0,
weights["Wxz"], weights["Whz"], weights["bz"],
weights["Wxr"], weights["Whr"], weights["br"],
weights["Wxn"], weights["Whn"], weights["bn"],
propagation_func)
    # Not the last GRU layer
if params.get("return_hs"):
z = hs
    # The last GRU layer
else:
z = h_n
    # Weight decay support
weight_decay_r = 0
if weight_decay is not None:
weight_decay_r += weight_decay["func"](weights["Wxz"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whz"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxr"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whr"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Wxn"], **weight_decay["params"])
weight_decay_r += weight_decay["func"](weights["Whn"], **weight_decay["params"])
return {"u":u, "z":z, "h_0":h_0, "hs":hs, "zs":zs, "rs":rs, "ns":ns}, weight_decay_r
def GRU_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, **params):
    # Not the last GRU layer
if return_hs:
dh_n = np.zeros_like(us["h_0"])
dhs = dz
    # The last GRU layer
else:
dh_n = dz
dhs = np.zeros_like(us["hs"])
    # Compute the GRU gradients
du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0 = back_func(dh_n, dhs,
us["u"], us["hs"], us["h_0"],
us["zs"], us["rs"], us["ns"],
weights["Wxz"], weights["Whz"], weights["bz"],
weights["Wxr"], weights["Whr"], weights["br"],
weights["Wxn"], weights["Whn"], weights["bn"],
propagation_func)
# 重み減衰対応
if weight_decay is not None:
dWxz += weight_decay["back_func"](weights["Wxz"], **weight_decay["params"])
dWhz += weight_decay["back_func"](weights["Whz"], **weight_decay["params"])
dWxr += weight_decay["back_func"](weights["Wxr"], **weight_decay["params"])
dWhr += weight_decay["back_func"](weights["Whr"], **weight_decay["params"])
dWxn += weight_decay["back_func"](weights["Wxn"], **weight_decay["params"])
dWhn += weight_decay["back_func"](weights["Whn"], **weight_decay["params"])
return {"du":du, "dWxz":dWxz, "dWhz":dWhz, "dbz":dbz, "dWxr":dWxr, "dWhr":dWhr, "dbr":dbr, "dWxn":dWxn, "dWhn":dWhn, "dbn":dbn}
def GRU_update_weight(func, du, weights, optimizer_stats, **params):
    weights["Wxz"], optimizer_stats["sWxz"] = func(weights["Wxz"], du["dWxz"], **params, **optimizer_stats["sWxz"])
    weights["Whz"], optimizer_stats["sWhz"] = func(weights["Whz"], du["dWhz"], **params, **optimizer_stats["sWhz"])
    weights["bz"], optimizer_stats["sbz"] = func(weights["bz"], du["dbz"], **params, **optimizer_stats["sbz"])
    weights["Wxr"], optimizer_stats["sWxr"] = func(weights["Wxr"], du["dWxr"], **params, **optimizer_stats["sWxr"])
    weights["Whr"], optimizer_stats["sWhr"] = func(weights["Whr"], du["dWhr"], **params, **optimizer_stats["sWhr"])
    weights["br"], optimizer_stats["sbr"] = func(weights["br"], du["dbr"], **params, **optimizer_stats["sbr"])
    weights["Wxn"], optimizer_stats["sWxn"] = func(weights["Wxn"], du["dWxn"], **params, **optimizer_stats["sWxn"])
    weights["Whn"], optimizer_stats["sWhn"] = func(weights["Whn"], du["dWhn"], **params, **optimizer_stats["sWhn"])
    weights["bn"], optimizer_stats["sbn"] = func(weights["bn"], du["dbn"], **params, **optimizer_stats["sbn"])
    return weights, optimizer_stats
def BiRNN(u, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat", input_dictionary=False):
    # set up inputs (the backward direction sees time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    # initialization
    h_n, hs = {}, {}
    # forward direction
    h_n["f"], hs["f"] = RNN(u_dr["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # backward direction
    h_n["r"], hs["r"] = RNN(u_dr["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
    # h_n - (n, h)
    # hs - (n, t, h)
    if merge == "concat":
        # concatenate
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # sum
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    elif merge == "dictionary":
        # dictionary
        h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
        hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
    return h_n_merge, hs_merge, h_n, hs
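A small shape check may help here (a hypothetical sketch, assuming the RNN function from the previous article is defined): merge="concat" doubles the hidden dimension, while merge="sum" keeps it.
# Hypothetical shape check: batch 2, 5 time steps, 3 features, 4 hidden units
n, t, d, h = 2, 5, 3, 4
u = np.random.normal(0., 1., (n, t, d))
h_0 = {dr: np.zeros((n, h)) for dr in ["f", "r"]}
Wx = {dr: glorot_normal(d, h) for dr in ["f", "r"]}
Wh = {dr: orthogonal(h, h) for dr in ["f", "r"]}
b = {dr: zeros_b(h) for dr in ["f", "r"]}
h_n_merge, hs_merge, h_n, hs = BiRNN(u, h_0, Wx, Wh, b, merge="sum")
print(h_n_merge.shape, hs_merge.shape)    # (2, 4) (2, 5, 4)
h_n_merge, hs_merge, h_n, hs = BiRNN(u, h_0, Wx, Wh, b, merge="concat")
print(h_n_merge.shape, hs_merge.shape)    # (2, 8) (2, 5, 8)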
def BiRNN_back(dh_n_merge, dhs_merge, u, hs, h_0, Wx, Wh, b, propagation_func=tanh, merge="concat", input_dictionary=False):
    # set up inputs (the backward direction sees time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    h = Wh["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h)
    dh_n, dhs = {}, {}
    if merge == "concat":
        # concatenate
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # sum
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    elif merge == "dictionary":
        # dictionary
        dh_n["f"] = dh_n_merge["f"]
        dh_n["r"] = dh_n_merge["r"]
        dhs["f"] = dhs_merge["f"]
        dhs["r"] = dhs_merge["r"][:, ::-1]
    # initialization
    du = {}
    dWx, dWh, db = {}, {}, {}
    dh_0 = {}
    # forward direction
    du["f"], dWx["f"], dWh["f"], db["f"], dh_0["f"] = RNN_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], Wx["f"], Wh["f"], b["f"], propagation_func)
    # backward direction
    du["r"], dWx["r"], dWh["r"], db["r"], dh_0["r"] = RNN_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], Wx["r"], Wh["r"], b["r"], propagation_func)
    # du - (n, t, d)
    if not input_dictionary:
        du_merge = du["f"] + du["r"][:, ::-1]
    else:
        du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
    return du_merge, dWx, dWh, db, dh_0
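The backward pass can be checked the same way, continuing the hypothetical example above: each gradient keeps the shape of its counterpart.
# Continuing the hypothetical example (the concat merge was run last)
dh_n_merge = np.ones((n, h * 2))       # upstream gradient w.r.t. the merged final state
dhs_merge = np.zeros((n, t, h * 2))    # no gradient flows in via the per-step outputs here
du, dWx, dWh, db, dh_0 = BiRNN_back(dh_n_merge, dhs_merge, u, hs, h_0, Wx, Wh, b, merge="concat")
print(du.shape)          # (2, 5, 3) - same as u
print(dWx["f"].shape)    # (3, 4)    - same as Wx["f"]
print(dWh["r"].shape)    # (4, 4)    - same as Wh["r"]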
def BiLSTM(u, h_0, c_0, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh, merge="concat", input_dictionary=False):
    # set up inputs (the backward direction sees time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    # initialization
    h_n, hs, c_n, cs = {}, {}, {}, {}
    fs, Is, gs, os = {}, {}, {}, {}
    # forward direction
    h_n["f"], hs["f"], c_n["f"], cs["f"], fs["f"], Is["f"], gs["f"], os["f"] = LSTM(u_dr["f"], h_0["f"], c_0["f"], Wxf["f"], Whf["f"], bf["f"], Wxi["f"], Whi["f"], bi["f"], Wxg["f"], Whg["f"], bg["f"], Wxo["f"], Who["f"], bo["f"], propagation_func)
    # backward direction
    h_n["r"], hs["r"], c_n["r"], cs["r"], fs["r"], Is["r"], gs["r"], os["r"] = LSTM(u_dr["r"], h_0["r"], c_0["r"], Wxf["r"], Whf["r"], bf["r"], Wxi["r"], Whi["r"], bi["r"], Wxg["r"], Whg["r"], bg["r"], Wxo["r"], Who["r"], bo["r"], propagation_func)
    # h_n, c_n - (n, h)
    # hs, cs - (n, t, h)
    if merge == "concat":
        # concatenate
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # sum
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    elif merge == "dictionary":
        # dictionary
        h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
        hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
    return h_n_merge, hs_merge, h_n, hs, c_n, cs, fs, Is, gs, os
def BiLSTM_back(dh_n_merge, dhs_merge, u, hs, h_0, cs, c_0, fs, Is, gs, os, Wxf, Whf, bf, Wxi, Whi, bi, Wxg, Whg, bg, Wxo, Who, bo, propagation_func=tanh, merge="concat", input_dictionary=False):
    # set up inputs (the backward direction sees time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    h = Whf["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h)
    dh_n, dhs = {}, {}
    if merge == "concat":
        # concatenate
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # sum
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    elif merge == "dictionary":
        # dictionary
        dh_n["f"] = dh_n_merge["f"]
        dh_n["r"] = dh_n_merge["r"]
        dhs["f"] = dhs_merge["f"]
        dhs["r"] = dhs_merge["r"][:, ::-1]
    # initialization
    du = {}
    dWxf, dWhf, dbf = {}, {}, {}
    dWxi, dWhi, dbi = {}, {}, {}
    dWxg, dWhg, dbg = {}, {}, {}
    dWxo, dWho, dbo = {}, {}, {}
    dh_0, dc_0 = {}, {}
    # forward direction
    du["f"], dWxf["f"], dWhf["f"], dbf["f"], dWxi["f"], dWhi["f"], dbi["f"], dWxg["f"], dWhg["f"], dbg["f"], dWxo["f"], dWho["f"], dbo["f"], dh_0["f"], dc_0["f"] = \
        LSTM_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], cs["f"], c_0["f"], fs["f"], Is["f"], gs["f"], os["f"], Wxf["f"], Whf["f"], bf["f"], Wxi["f"], Whi["f"], bi["f"], Wxg["f"], Whg["f"], bg["f"], Wxo["f"], Who["f"], bo["f"], propagation_func)
    # backward direction
    du["r"], dWxf["r"], dWhf["r"], dbf["r"], dWxi["r"], dWhi["r"], dbi["r"], dWxg["r"], dWhg["r"], dbg["r"], dWxo["r"], dWho["r"], dbo["r"], dh_0["r"], dc_0["r"] = \
        LSTM_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], cs["r"], c_0["r"], fs["r"], Is["r"], gs["r"], os["r"], Wxf["r"], Whf["r"], bf["r"], Wxi["r"], Whi["r"], bi["r"], Wxg["r"], Whg["r"], bg["r"], Wxo["r"], Who["r"], bo["r"], propagation_func)
    # du - (n, t, d)
    if not input_dictionary:
        du_merge = du["f"] + du["r"][:, ::-1]
    else:
        du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
    return du_merge, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0
def BiGRU(u, h_0, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh, merge="concat", input_dictionary=False):
    # set up inputs (the backward direction sees time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    # initialization
    h_n, hs = {}, {}
    zs, rs, ns = {}, {}, {}
    # forward direction
    h_n["f"], hs["f"], zs["f"], rs["f"], ns["f"] = GRU(u_dr["f"], h_0["f"], Wxz["f"], Whz["f"], bz["f"], Wxr["f"], Whr["f"], br["f"], Wxn["f"], Whn["f"], bn["f"], propagation_func)
    # backward direction
    h_n["r"], hs["r"], zs["r"], rs["r"], ns["r"] = GRU(u_dr["r"], h_0["r"], Wxz["r"], Whz["r"], bz["r"], Wxr["r"], Whr["r"], br["r"], Wxn["r"], Whn["r"], bn["r"], propagation_func)
    # h_n - (n, h)
    # hs - (n, t, h)
    if merge == "concat":
        # concatenate
        h_n_merge = np.concatenate([h_n["f"], h_n["r"]], axis=1)
        hs_merge = np.concatenate([hs["f"], hs["r"][:, ::-1]], axis=2)
    elif merge == "sum":
        # sum
        h_n_merge = h_n["f"] + h_n["r"]
        hs_merge = hs["f"] + hs["r"][:, ::-1]
    elif merge == "dictionary":
        # dictionary
        h_n_merge = {"f":h_n["f"], "r":h_n["r"]}
        hs_merge = {"f":hs["f"], "r":hs["r"][:, ::-1]}
    return h_n_merge, hs_merge, h_n, hs, zs, rs, ns
def BiGRU_back(dh_n_merge, dhs_merge, u, hs, h_0, zs, rs, ns, Wxz, Whz, bz, Wxr, Whr, br, Wxn, Whn, bn, propagation_func=tanh, merge="concat", input_dictionary=False):
    # set up inputs (the backward direction sees time reversed)
    u_dr = {}
    if not input_dictionary:
        u_dr["f"] = u
        u_dr["r"] = u[:, ::-1]
    else:
        u_dr["f"] = u["f"]
        u_dr["r"] = u["r"][:, ::-1]
    h = Whz["f"].shape[0]
    # dh_n - (n, h*2) for concat, (n, h) for sum
    # dhs - (n, t, h)
    dh_n, dhs = {}, {}
    if merge == "concat":
        # concatenate
        dh_n["f"] = dh_n_merge[:, :h]
        dh_n["r"] = dh_n_merge[:, h:]
        dhs["f"] = dhs_merge[:, :, :h]
        dhs["r"] = dhs_merge[:, :, h:][:, ::-1]
    elif merge == "sum":
        # sum
        dh_n["f"] = dh_n_merge
        dh_n["r"] = dh_n_merge
        dhs["f"] = dhs_merge
        dhs["r"] = dhs_merge[:, ::-1]
    elif merge == "dictionary":
        # dictionary
        dh_n["f"] = dh_n_merge["f"]
        dh_n["r"] = dh_n_merge["r"]
        dhs["f"] = dhs_merge["f"]
        dhs["r"] = dhs_merge["r"][:, ::-1]
    # initialization
    du = {}
    dWxz, dWhz, dbz = {}, {}, {}
    dWxr, dWhr, dbr = {}, {}, {}
    dWxn, dWhn, dbn = {}, {}, {}
    dh_0 = {}
    # forward direction
    du["f"], dWxz["f"], dWhz["f"], dbz["f"], dWxr["f"], dWhr["f"], dbr["f"], dWxn["f"], dWhn["f"], dbn["f"], dh_0["f"] = \
        GRU_back(dh_n["f"], dhs["f"], u_dr["f"], hs["f"], h_0["f"], zs["f"], rs["f"], ns["f"], Wxz["f"], Whz["f"], bz["f"], Wxr["f"], Whr["f"], br["f"], Wxn["f"], Whn["f"], bn["f"], propagation_func)
    # backward direction
    du["r"], dWxz["r"], dWhz["r"], dbz["r"], dWxr["r"], dWhr["r"], dbr["r"], dWxn["r"], dWhn["r"], dbn["r"], dh_0["r"] = \
        GRU_back(dh_n["r"], dhs["r"], u_dr["r"], hs["r"], h_0["r"], zs["r"], rs["r"], ns["r"], Wxz["r"], Whz["r"], bz["r"], Wxr["r"], Whr["r"], br["r"], Wxn["r"], Whn["r"], bn["r"], propagation_func)
    # du - (n, t, d)
    if not input_dictionary:
        du_merge = du["f"] + du["r"][:, ::-1]
    else:
        du_merge = {"f":du["f"], "r":du["r"][:, ::-1]}
    return du_merge, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0
def BiRNN_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
    t, d = d_prev
    h_next = h
    # with concat, the output dimension doubles
    if merge == "concat":
        h_next = h * 2
    d_next = (t, h_next)
    if not params.get("return_hs"):
        d_next = h_next
    Wx, Wh, b = {}, {}, {}
    for dr in ["f", "r"]:
        Wx[dr] = Wx_init_func(d, h, **Wx_init_params)
        Wh[dr] = Wh_init_func(h, h, **Wh_init_params)
        b[dr] = bias_init_func(h, **bias_init_params)
    return d_next, {"Wx":Wx, "Wh":Wh, "b":b}
def BiRNN_init_optimizer():
    sWx, sWh, sb = {}, {}, {}
    for dr in ["f", "r"]:
        sWx[dr] = {}
        sWh[dr] = {}
        sb[dr] = {}
    return {"sWx":sWx, "sWh":sWh, "sb":sb}
def BiRNN_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                h_0[dr] = np.zeros((u.shape[0], weights["Wh"][dr].shape[0]))
            else:
                h_0[dr] = np.zeros((u[dr].shape[0], weights["Wh"][dr].shape[0]))
    # RNN
    h_n_merge, hs_merge, h_n, hs = func(u, h_0, weights["Wx"], weights["Wh"], weights["b"], propagation_func, merge, input_dictionary)
    # RNN layers other than the bottom one
    if params.get("return_hs"):
        z = hs_merge
    # bottom RNN layer
    else:
        z = h_n_merge
    # weight decay support
    weight_decay_r = 0
    if weight_decay is not None:
        for dr in ["f", "r"]:
            weight_decay_r += weight_decay["func"](weights["Wx"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wh"][dr], **weight_decay["params"])
    return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs}, weight_decay_r
def BiRNN_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):
    # RNN layers other than the bottom one
    if return_hs:
        if merge != "dictionary":
            dh_n_merge = np.zeros_like(us["h_n_merge"])
        else:
            dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
        dhs_merge = dz
    # bottom RNN layer
    else:
        dh_n_merge = dz
        if merge != "dictionary":
            dhs_merge = np.zeros_like(us["hs_merge"])
        else:
            dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # gradient computation for the RNN
    du, dWx, dWh, db, dh_0 = back_func(dh_n_merge, dhs_merge, us["u"], us["hs"], us["h_0"], weights["Wx"], weights["Wh"], weights["b"], propagation_func, merge, input_dictionary)
    # weight decay support
    if weight_decay is not None:
        for dr in ["f", "r"]:
            dWx[dr] += weight_decay["back_func"](weights["Wx"][dr], **weight_decay["params"])
            dWh[dr] += weight_decay["back_func"](weights["Wh"][dr], **weight_decay["params"])
    return {"du":du, "dWx":dWx, "dWh":dWh, "db":db}
def BiRNN_update_weight(func, du, weights, optimizer_stats, **params):
    for dr in ["f", "r"]:
        weights["Wx"][dr], optimizer_stats["sWx"][dr] = func(weights["Wx"][dr], du["dWx"][dr], **params, **optimizer_stats["sWx"][dr])
        weights["Wh"][dr], optimizer_stats["sWh"][dr] = func(weights["Wh"][dr], du["dWh"][dr], **params, **optimizer_stats["sWh"][dr])
        weights["b"][dr], optimizer_stats["sb"][dr] = func(weights["b"][dr], du["db"][dr], **params, **optimizer_stats["sb"][dr])
    return weights, optimizer_stats
def BiLSTM_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
    t, d = d_prev
    h_next = h
    # with concat, the output dimension doubles
    if merge == "concat":
        h_next = h * 2
    d_next = (t, h_next)
    if not params.get("return_hs"):
        d_next = h_next
    Wxf, Whf, bf = {}, {}, {}
    Wxi, Whi, bi = {}, {}, {}
    Wxg, Whg, bg = {}, {}, {}
    Wxo, Who, bo = {}, {}, {}
    for dr in ["f", "r"]:
        Wxf[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whf[dr] = Wh_init_func(h, h, **Wh_init_params)
        bf[dr] = bias_init_func(h, **bias_init_params)
        Wxi[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whi[dr] = Wh_init_func(h, h, **Wh_init_params)
        bi[dr] = bias_init_func(h, **bias_init_params)
        Wxg[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whg[dr] = Wh_init_func(h, h, **Wh_init_params)
        bg[dr] = bias_init_func(h, **bias_init_params)
        Wxo[dr] = Wx_init_func(d, h, **Wx_init_params)
        Who[dr] = Wh_init_func(h, h, **Wh_init_params)
        bo[dr] = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxf":Wxf, "Whf":Whf, "bf":bf, "Wxi":Wxi, "Whi":Whi, "bi":bi, "Wxg":Wxg, "Whg":Whg, "bg":bg, "Wxo":Wxo, "Who":Who, "bo":bo}
def BiLSTM_init_optimizer():
    sWxf, sWhf, sbf = {}, {}, {}
    sWxi, sWhi, sbi = {}, {}, {}
    sWxg, sWhg, sbg = {}, {}, {}
    sWxo, sWho, sbo = {}, {}, {}
    for dr in ["f", "r"]:
        sWxf[dr] = {}
        sWhf[dr] = {}
        sbf[dr] = {}
        sWxi[dr] = {}
        sWhi[dr] = {}
        sbi[dr] = {}
        sWxg[dr] = {}
        sWhg[dr] = {}
        sbg[dr] = {}
        sWxo[dr] = {}
        sWho[dr] = {}
        sbo[dr] = {}
    return {"sWxf":sWxf, "sWhf":sWhf, "sbf":sbf,
            "sWxi":sWxi, "sWhi":sWhi, "sbi":sbi,
            "sWxg":sWxg, "sWhg":sWhg, "sbg":sbg,
            "sWxo":sWxo, "sWho":sWho, "sbo":sbo}
def BiLSTM_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                h_0[dr] = np.zeros((u.shape[0], weights["Whf"][dr].shape[0]))
            else:
                h_0[dr] = np.zeros((u[dr].shape[0], weights["Whf"][dr].shape[0]))
    c_0 = params.get("c_0")
    if c_0 is None:
        c_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                c_0[dr] = np.zeros((u.shape[0], weights["Whf"][dr].shape[0]))
            else:
                c_0[dr] = np.zeros((u[dr].shape[0], weights["Whf"][dr].shape[0]))
    # LSTM
    h_n_merge, hs_merge, h_n, hs, c_n, cs, fs, Is, gs, os = func(u, h_0, c_0,
                               weights["Wxf"], weights["Whf"], weights["bf"],
                               weights["Wxi"], weights["Whi"], weights["bi"],
                               weights["Wxg"], weights["Whg"], weights["bg"],
                               weights["Wxo"], weights["Who"], weights["bo"],
                               propagation_func, merge, input_dictionary)
    # LSTM layers other than the bottom one
    if params.get("return_hs"):
        z = hs_merge
    # bottom LSTM layer
    else:
        z = h_n_merge
    # weight decay support
    weight_decay_r = 0
    if weight_decay is not None:
        for dr in ["f", "r"]:
            weight_decay_r += weight_decay["func"](weights["Wxf"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whf"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxi"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whi"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxg"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whg"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxo"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Who"][dr], **weight_decay["params"])
    return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs, "c_0":c_0, "c_n":c_n, "cs":cs,
            "fs":fs, "Is":Is, "gs":gs, "os":os}, weight_decay_r
def BiLSTM_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):
    # LSTM layers other than the bottom one
    if return_hs:
        if merge != "dictionary":
            dh_n_merge = np.zeros_like(us["h_n_merge"])
        else:
            dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
        dhs_merge = dz
    # bottom LSTM layer
    else:
        dh_n_merge = dz
        if merge != "dictionary":
            dhs_merge = np.zeros_like(us["hs_merge"])
        else:
            dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # gradient computation for the LSTM
    du, dWxf, dWhf, dbf, dWxi, dWhi, dbi, dWxg, dWhg, dbg, dWxo, dWho, dbo, dh_0, dc_0 = back_func(dh_n_merge, dhs_merge,
                               us["u"], us["hs"], us["h_0"], us["cs"], us["c_0"],
                               us["fs"], us["Is"], us["gs"], us["os"],
                               weights["Wxf"], weights["Whf"], weights["bf"],
                               weights["Wxi"], weights["Whi"], weights["bi"],
                               weights["Wxg"], weights["Whg"], weights["bg"],
                               weights["Wxo"], weights["Who"], weights["bo"],
                               propagation_func, merge, input_dictionary)
    # weight decay support
    if weight_decay is not None:
        for dr in ["f", "r"]:
            dWxf[dr] += weight_decay["back_func"](weights["Wxf"][dr], **weight_decay["params"])
            dWhf[dr] += weight_decay["back_func"](weights["Whf"][dr], **weight_decay["params"])
            dWxi[dr] += weight_decay["back_func"](weights["Wxi"][dr], **weight_decay["params"])
            dWhi[dr] += weight_decay["back_func"](weights["Whi"][dr], **weight_decay["params"])
            dWxg[dr] += weight_decay["back_func"](weights["Wxg"][dr], **weight_decay["params"])
            dWhg[dr] += weight_decay["back_func"](weights["Whg"][dr], **weight_decay["params"])
            dWxo[dr] += weight_decay["back_func"](weights["Wxo"][dr], **weight_decay["params"])
            dWho[dr] += weight_decay["back_func"](weights["Who"][dr], **weight_decay["params"])
    return {"du":du, "dWxf":dWxf, "dWhf":dWhf, "dbf":dbf, "dWxi":dWxi, "dWhi":dWhi, "dbi":dbi, "dWxg":dWxg, "dWhg":dWhg, "dbg":dbg, "dWxo":dWxo, "dWho":dWho, "dbo":dbo}
def BiLSTM_update_weight(func, du, weights, optimizer_stats, **params):
    for dr in ["f", "r"]:
        weights["Wxf"][dr], optimizer_stats["sWxf"][dr] = func(weights["Wxf"][dr], du["dWxf"][dr], **params, **optimizer_stats["sWxf"][dr])
        weights["Whf"][dr], optimizer_stats["sWhf"][dr] = func(weights["Whf"][dr], du["dWhf"][dr], **params, **optimizer_stats["sWhf"][dr])
        weights["bf"][dr], optimizer_stats["sbf"][dr] = func(weights["bf"][dr], du["dbf"][dr], **params, **optimizer_stats["sbf"][dr])
        weights["Wxi"][dr], optimizer_stats["sWxi"][dr] = func(weights["Wxi"][dr], du["dWxi"][dr], **params, **optimizer_stats["sWxi"][dr])
        weights["Whi"][dr], optimizer_stats["sWhi"][dr] = func(weights["Whi"][dr], du["dWhi"][dr], **params, **optimizer_stats["sWhi"][dr])
        weights["bi"][dr], optimizer_stats["sbi"][dr] = func(weights["bi"][dr], du["dbi"][dr], **params, **optimizer_stats["sbi"][dr])
        weights["Wxg"][dr], optimizer_stats["sWxg"][dr] = func(weights["Wxg"][dr], du["dWxg"][dr], **params, **optimizer_stats["sWxg"][dr])
        weights["Whg"][dr], optimizer_stats["sWhg"][dr] = func(weights["Whg"][dr], du["dWhg"][dr], **params, **optimizer_stats["sWhg"][dr])
        weights["bg"][dr], optimizer_stats["sbg"][dr] = func(weights["bg"][dr], du["dbg"][dr], **params, **optimizer_stats["sbg"][dr])
        weights["Wxo"][dr], optimizer_stats["sWxo"][dr] = func(weights["Wxo"][dr], du["dWxo"][dr], **params, **optimizer_stats["sWxo"][dr])
        weights["Who"][dr], optimizer_stats["sWho"][dr] = func(weights["Who"][dr], du["dWho"][dr], **params, **optimizer_stats["sWho"][dr])
        weights["bo"][dr], optimizer_stats["sbo"][dr] = func(weights["bo"][dr], du["dbo"][dr], **params, **optimizer_stats["sbo"][dr])
    return weights, optimizer_stats
def BiGRU_init_layer(d_prev, h, Wx_init_func=glorot_normal, Wx_init_params={}, Wh_init_func=orthogonal, Wh_init_params={}, bias_init_func=zeros_b, bias_init_params={}, merge="concat", **params):
    t, d = d_prev
    h_next = h
    # with concat, the output dimension doubles
    if merge == "concat":
        h_next = h * 2
    d_next = (t, h_next)
    if not params.get("return_hs"):
        d_next = h_next
    Wxz, Whz, bz = {}, {}, {}
    Wxr, Whr, br = {}, {}, {}
    Wxn, Whn, bn = {}, {}, {}
    for dr in ["f", "r"]:
        Wxz[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whz[dr] = Wh_init_func(h, h, **Wh_init_params)
        bz[dr] = bias_init_func(h, **bias_init_params)
        Wxr[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whr[dr] = Wh_init_func(h, h, **Wh_init_params)
        br[dr] = bias_init_func(h, **bias_init_params)
        Wxn[dr] = Wx_init_func(d, h, **Wx_init_params)
        Whn[dr] = Wh_init_func(h, h, **Wh_init_params)
        bn[dr] = bias_init_func(h, **bias_init_params)
    return d_next, {"Wxz":Wxz, "Whz":Whz, "bz":bz, "Wxr":Wxr, "Whr":Whr, "br":br, "Wxn":Wxn, "Whn":Whn, "bn":bn}
def BiGRU_init_optimizer():
    sWxz, sWhz, sbz = {}, {}, {}
    sWxr, sWhr, sbr = {}, {}, {}
    sWxn, sWhn, sbn = {}, {}, {}
    for dr in ["f", "r"]:
        sWxz[dr] = {}
        sWhz[dr] = {}
        sbz[dr] = {}
        sWxr[dr] = {}
        sWhr[dr] = {}
        sbr[dr] = {}
        sWxn[dr] = {}
        sWhn[dr] = {}
        sbn[dr] = {}
    return {"sWxz":sWxz, "sWhz":sWhz, "sbz":sbz,
            "sWxr":sWxr, "sWhr":sWhr, "sbr":sbr,
            "sWxn":sWxn, "sWhn":sWhn, "sbn":sbn}
def BiGRU_propagation(func, u, weights, weight_decay, learn_flag, propagation_func=tanh, merge="concat", input_dictionary=False, **params):
    h_0 = params.get("h_0")
    if h_0 is None:
        h_0 = {}
        for dr in ["f", "r"]:
            if not input_dictionary:
                h_0[dr] = np.zeros((u.shape[0], weights["Whz"][dr].shape[0]))
            else:
                h_0[dr] = np.zeros((u[dr].shape[0], weights["Whz"][dr].shape[0]))
    # GRU (unlike the LSTM version, there is no cell state, so no c_0 is set up)
    h_n_merge, hs_merge, h_n, hs, zs, rs, ns = func(u, h_0,
                               weights["Wxz"], weights["Whz"], weights["bz"],
                               weights["Wxr"], weights["Whr"], weights["br"],
                               weights["Wxn"], weights["Whn"], weights["bn"],
                               propagation_func, merge, input_dictionary)
    # GRU layers other than the bottom one
    if params.get("return_hs"):
        z = hs_merge
    # bottom GRU layer
    else:
        z = h_n_merge
    # weight decay support
    weight_decay_r = 0
    if weight_decay is not None:
        for dr in ["f", "r"]:
            weight_decay_r += weight_decay["func"](weights["Wxz"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whz"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxr"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whr"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Wxn"][dr], **weight_decay["params"])
            weight_decay_r += weight_decay["func"](weights["Whn"][dr], **weight_decay["params"])
    return {"u":u, "z":z, "h_n_merge":h_n_merge, "hs_merge":hs_merge, "h_0":h_0, "h_n":h_n, "hs":hs, "zs":zs, "rs":rs, "ns":ns}, weight_decay_r
def BiGRU_back_propagation(back_func, dz, us, weights, weight_decay, calc_du_flag, propagation_func=tanh, return_hs=False, merge="concat", input_dictionary=False, **params):
    # GRU layers other than the bottom one
    if return_hs:
        if merge != "dictionary":
            dh_n_merge = np.zeros_like(us["h_n_merge"])
        else:
            dh_n_merge = {"f":np.zeros_like(us["h_n_merge"]["f"]), "r":np.zeros_like(us["h_n_merge"]["r"])}
        dhs_merge = dz
    # bottom GRU layer
    else:
        dh_n_merge = dz
        if merge != "dictionary":
            dhs_merge = np.zeros_like(us["hs_merge"])
        else:
            dhs_merge = {"f":np.zeros_like(us["hs_merge"]["f"]), "r":np.zeros_like(us["hs_merge"]["r"])}
    # gradient computation for the GRU
    du, dWxz, dWhz, dbz, dWxr, dWhr, dbr, dWxn, dWhn, dbn, dh_0 = back_func(dh_n_merge, dhs_merge,
                               us["u"], us["hs"], us["h_0"],
                               us["zs"], us["rs"], us["ns"],
                               weights["Wxz"], weights["Whz"], weights["bz"],
                               weights["Wxr"], weights["Whr"], weights["br"],
                               weights["Wxn"], weights["Whn"], weights["bn"],
                               propagation_func, merge, input_dictionary)
    # weight decay support
    if weight_decay is not None:
        for dr in ["f", "r"]:
            dWxz[dr] += weight_decay["back_func"](weights["Wxz"][dr], **weight_decay["params"])
            dWhz[dr] += weight_decay["back_func"](weights["Whz"][dr], **weight_decay["params"])
            dWxr[dr] += weight_decay["back_func"](weights["Wxr"][dr], **weight_decay["params"])
            dWhr[dr] += weight_decay["back_func"](weights["Whr"][dr], **weight_decay["params"])
            dWxn[dr] += weight_decay["back_func"](weights["Wxn"][dr], **weight_decay["params"])
            dWhn[dr] += weight_decay["back_func"](weights["Whn"][dr], **weight_decay["params"])
    return {"du":du, "dWxz":dWxz, "dWhz":dWhz, "dbz":dbz, "dWxr":dWxr, "dWhr":dWhr, "dbr":dbr, "dWxn":dWxn, "dWhn":dWhn, "dbn":dbn}
def BiGRU_update_weight(func, du, weights, optimizer_stats, **params):
    for dr in ["f", "r"]:
        weights["Wxz"][dr], optimizer_stats["sWxz"][dr] = func(weights["Wxz"][dr], du["dWxz"][dr], **params, **optimizer_stats["sWxz"][dr])
        weights["Whz"][dr], optimizer_stats["sWhz"][dr] = func(weights["Whz"][dr], du["dWhz"][dr], **params, **optimizer_stats["sWhz"][dr])
        weights["bz"][dr], optimizer_stats["sbz"][dr] = func(weights["bz"][dr], du["dbz"][dr], **params, **optimizer_stats["sbz"][dr])
        weights["Wxr"][dr], optimizer_stats["sWxr"][dr] = func(weights["Wxr"][dr], du["dWxr"][dr], **params, **optimizer_stats["sWxr"][dr])
        weights["Whr"][dr], optimizer_stats["sWhr"][dr] = func(weights["Whr"][dr], du["dWhr"][dr], **params, **optimizer_stats["sWhr"][dr])
        weights["br"][dr], optimizer_stats["sbr"][dr] = func(weights["br"][dr], du["dbr"][dr], **params, **optimizer_stats["sbr"][dr])
        weights["Wxn"][dr], optimizer_stats["sWxn"][dr] = func(weights["Wxn"][dr], du["dWxn"][dr], **params, **optimizer_stats["sWxn"][dr])
        weights["Whn"][dr], optimizer_stats["sWhn"][dr] = func(weights["Whn"][dr], du["dWhn"][dr], **params, **optimizer_stats["sWhn"][dr])
        weights["bn"][dr], optimizer_stats["sbn"][dr] = func(weights["bn"][dr], du["dbn"][dr], **params, **optimizer_stats["sbn"][dr])
    return weights, optimizer_stats
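Putting the pieces together, one forward/backward/update step through a single BiGRU layer can be sketched as follows (a hypothetical example assuming the GRU and GRU_back functions from the previous article and the sgd sketch shown earlier; the actual network wiring follows the framework used throughout this series):
# Hypothetical single training step through one BiGRU layer
n, t, d, h = 2, 5, 3, 4
u = np.random.normal(0., 1., (n, t, d))
d_next, weights = BiGRU_init_layer((t, d), h)    # d_next = 8 with the default concat merge
optimizer_stats = BiGRU_init_optimizer()
# forward
us, weight_decay_r = BiGRU_propagation(BiGRU, u, weights, None, True)
# backward, pretending the upstream gradient is all ones
dz = np.ones_like(us["z"])
du = BiGRU_back_propagation(BiGRU_back, dz, us, weights, None, True)
# update with the SGD sketch
weights, optimizer_stats = BiGRU_update_weight(sgd, du, weights, optimizer_stats, lr=0.01)
print(us["z"].shape, du["du"].shape)             # (2, 8) (2, 5, 3)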