はじめは私も無理だろって思ってました。意味があるのかは知りません。
そもそもバッチサイズ=1にして勾配を累積させればいいんじゃね?
駄目です。バッチサイズ=1でも余裕でOOM
です。
GradientCheckpointingを使う。
上の記事に「Transformerでやるなら、gradient_checkpointing=True
と書いてやるだけでいい」っぽいことが書いてある。
え、生tensorflowでできないっすかねそれ。
gradient-checkpointingでは、tensorflow1.xでの実装が書いてある。これを流用しようと思ったがダメだった。
自分で実装するしかないかぁ。
方針
Functionalモデルを使っているので、モデルを切り分けるのが少々面倒だが、切り分けてしまえばチェックポイントごとに逆伝播していけばいいだけなので、切り分けたものをchunksとしておいておく。
実装
xs = ...
ys = ...
# 順伝播
chunkOuts = []
x = xs
for chunk in chunks[:-1]:
chunkOuts.append(x)
x = chunk(x)
# 変数初期化
total_grads = [tf.zeros_like(var) for var in model.trainable_variables]
# 逆伝播(一番上のレイヤーだけ損失関数を計算する。)
f = x
with tf.GradientTape(persistent=True) as tape:
tape.watch(f)
x = chunks[-1](x)
x = tf.keras.losses.sparse_categorical_crossentropy(ys, x)
grads = tape.gradinet(x, model.trainable_variables, None, "zero")
next_grads = tape.gradient(x, f)
loss = tf.reduce_mean(x).numpy()
total_grads = [tg + g for tg, g in zip(total_grads, grads)]
# 逆伝播(残りのレイヤー)
for i, chunk in enumerate(reversed(chunks[:-1])):
f = chunkOuts[~i]
with tf.GradientTape(persistent=True) as tape:
tape.watch(f)
x = chunk(f)
grads = tape.gradient(x, model.trainable_variables, next_grads)
next_grads = tape.gradient(x, f, next_grads)
total_grads = [tg + g for tg, g in zip(total_grads, grads)]
optimizer.apply_gradients(zip(total_grads, models["trainer"].trainable_variables))
print(loss)
こんな感じで実装できた。
まとめ
- gradient-checkpointingをtensorflowでするのは案外簡単
- tapeは
persistent=True
にしないと2回以上勾配を計算できないのに躓いた