9
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Tensorflow2のMirroredStrategyを使って複数GPUで計算する

Last updated at Posted at 2020-01-15
  • 2020/6/9 デシリアライズについて追記

概要

非常に簡単。tf.distribute.MirroredStrategyのスコープ内でネットワークを構築するだけ。変更は数行ですむ。

Keras APIを用いた実装例

ここでは簡単な例として隠れ層が1層のみからなるシンプルなネットワークを構築しています。

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam


with tf.distribute.MirroredStrategy().scope():

    # このブロックでネットワークを構築する
    x = Input(32, dtype=tf.int32, name='x')
    hidden = Dense(20, activation='relu', name='hidden')(x)
    y = Dense(5, activation='sigmoid', name='y')(hidden)
    model = Model(inputs=x, outputs=y)
    model.compile(
        optimizer=Adam(lr=0.001),
        loss='binary_crossentropy',
    )

model.fit(
    x_train, y_train,
    epochs=10,
    batch_size=16,
)

デシリアライズ時の注意点

シリアライズして保存されたモデルをロードする場合も、同じようにtf.distribute.MirroredStrategyのスコープ内でモデルを構築すれば良い。

ただし一点注意が必要。公式ドキュメントには記載がないが、下記のstackoverflowの記事にあるように、おそらくtf.keras.models.load_model()を使うことはできない制約がある。

そのため、記事の回答にあるように、save_weight()load_weight()を使う必要がある。

with tf.distribute.MirroredStrategy().scope():
    model = build_model(...)
    model.compile(...)
    model.load_weights('model.h5')
    model.compile(...)
model.fit(...)

関連

公式なドキュメントは次の通り。

Keras APIをそのまま利用する場合はこの記事で紹介した通りだが、custom training loopを実装している場合などは、さらに考慮すべき点がいくつかある。その場合は上記のドキュメントを参照していただきたい。

multi_gpu_model()は2020年4月以降に廃止予定とのこと。

9
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?