- 2020/6/9 デシリアライズについて追記
概要
非常に簡単。tf.distribute.MirroredStrategy
のスコープ内でネットワークを構築するだけ。変更は数行ですむ。
Keras APIを用いた実装例
ここでは簡単な例として隠れ層が1層のみからなるシンプルなネットワークを構築しています。
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
with tf.distribute.MirroredStrategy().scope():
# このブロックでネットワークを構築する
x = Input(32, dtype=tf.int32, name='x')
hidden = Dense(20, activation='relu', name='hidden')(x)
y = Dense(5, activation='sigmoid', name='y')(hidden)
model = Model(inputs=x, outputs=y)
model.compile(
optimizer=Adam(lr=0.001),
loss='binary_crossentropy',
)
model.fit(
x_train, y_train,
epochs=10,
batch_size=16,
)
デシリアライズ時の注意点
シリアライズして保存されたモデルをロードする場合も、同じようにtf.distribute.MirroredStrategy
のスコープ内でモデルを構築すれば良い。
ただし一点注意が必要。公式ドキュメントには記載がないが、下記のstackoverflowの記事にあるように、おそらくtf.keras.models.load_model()
を使うことはできない制約がある。
そのため、記事の回答にあるように、save_weight()
やload_weight()
を使う必要がある。
with tf.distribute.MirroredStrategy().scope():
model = build_model(...)
model.compile(...)
model.load_weights('model.h5')
model.compile(...)
model.fit(...)
関連
公式なドキュメントは次の通り。
Keras APIをそのまま利用する場合はこの記事で紹介した通りだが、custom training loopを実装している場合などは、さらに考慮すべき点がいくつかある。その場合は上記のドキュメントを参照していただきたい。
multi_gpu_model()
は2020年4月以降に廃止予定とのこと。