0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

お題は不問!Qiita Engineer Festa 2024で記事投稿!
Qiita Engineer Festa20242024年7月17日まで開催中!

6GiBVRAMで4300万パラメータのモデルをバッチサイズ64で学習する

Last updated at Posted at 2024-06-22

はじめは私も無理だろって思ってました。意味があるのかは知りません。

そもそもバッチサイズ=1にして勾配を累積させればいいんじゃね?

駄目です。バッチサイズ=1でも余裕でOOMです。

GradientCheckpointingを使う。

上の記事に「Transformerでやるなら、gradient_checkpointing=Trueと書いてやるだけでいい」っぽいことが書いてある。

え、生tensorflowでできないっすかねそれ。
gradient-checkpointingでは、tensorflow1.xでの実装が書いてある。これを流用しようと思ったがダメだった。

自分で実装するしかないかぁ。

方針

Functionalモデルを使っているので、モデルを切り分けるのが少々面倒だが、切り分けてしまえばチェックポイントごとに逆伝播していけばいいだけなので、切り分けたものをchunksとしておいておく。

実装

xs = ...
ys = ...

# 順伝播
chunkOuts = []
x = xs
for chunk in chunks[:-1]:
    chunkOuts.append(x)
    x = chunk(x)

# 変数初期化
total_grads = [tf.zeros_like(var) for var in model.trainable_variables]
# 逆伝播(一番上のレイヤーだけ損失関数を計算する。)
f = x
with tf.GradientTape(persistent=True) as tape:
    tape.watch(f)
    x = chunks[-1](x)
    x = tf.keras.losses.sparse_categorical_crossentropy(ys, x)
grads = tape.gradinet(x, model.trainable_variables, None, "zero")
next_grads = tape.gradient(x, f)
loss = tf.reduce_mean(x).numpy()
total_grads = [tg + g for tg, g in zip(total_grads, grads)]

# 逆伝播(残りのレイヤー)
for i, chunk in enumerate(reversed(chunks[:-1])):
    f = chunkOuts[~i]
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(f)
        x = chunk(f)
    grads = tape.gradient(x, model.trainable_variables, next_grads)
    next_grads = tape.gradient(x, f, next_grads)
    total_grads = [tg + g for tg, g in zip(total_grads, grads)]

optimizer.apply_gradients(zip(total_grads, models["trainer"].trainable_variables))
print(loss)

こんな感じで実装できた。

まとめ

  • gradient-checkpointingをtensorflowでするのは案外簡単
  • tapeはpersistent=Trueにしないと2回以上勾配を計算できないのに躓いた
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?