ありがたいことに、世にTensorFlow2.xが公開されてずいぶん経ちました。
これまでv1.xで随分もがいていましたが、v2.xではデコレータで簡単にtf的なグラフに変更できるようになるなど大変素晴らしい機能がたくさん実装され、我々開発者もスピーディにモデルの学習、評価等を行えるようになりました。
しかし一部にv2.xとの互換性の問題からそのままv1.xを使用しなければならない人もいますよね(何を隠そう私のことです)。
今や更新されるドキュメントはみんな2.xについて言及していて1.xのドキュメントは増えにくく、適切な情報にたどり着けなくなってきています。
私は今(2020/11/4 17:10現在)、実際v1.xでパラメータを部分的に読み込む場合の処理についてヒットするまで時間がかかってしまいました。
単にResNetとかMobileNetとか、公開されているモデルをそのまま使用する場合なら計算グラフ全部で読み込んでしまえるのでそこまで苦労しないでしょうが、事前に学習したResNetを画像エンコーダとして後続の自作ネットワークに使用したい...という場合には部分的にパラメータの読み込みをする必要があります。
今後v1.xで部分的にパラメータの読み込みをする方へむけて(何を隠そう私のことです)、タイトルに記載の通り部分的なパラメータ読み込みの方法について記録しておきます。
やること
原則、以下のコードで可能です。
...
with tf.Session() as sess:
saver = tf.train.Saver({'読み込みたいモデルのノード名': そのノード名をつけたtf.Variableの変数, ...})
saver.restore(sess, 'path/to/checkpoint')
ただし、「tf.train.Saverに渡すノード名と変数の辞書どうやって作ればいいんだよ!」となるので、その場合は
variables = tf.trainable_variables()
restore_variables = {}
for v in variables:
if 'モデルの名前空間' in v.name:
restore_variables[v.name] = v
とすることで、現在使用しているノードから特定のノードだけ取り出して辞書に入れることができます。
学習する時、勝手に:0
とかが付与されることがあるので、その場合は
fixed_name = v.name[:-2]
restore_variables[fixed_name] = v
みたいに対応することで、読み込み可能になります。
補足
事前学習したときのcheckpointにそれぞれ変数名がどのように保存されているのか確認する場合は、tensorflow.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file
が便利です。
import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(file_name='path/to/checkpoint', tensor_name='', all_tensors=False)
# beta1_power (DT_FLOAT) []
# beta2_power (DT_FLOAT) []
# cae/conv0/convolution2d/biases (DT_FLOAT) [64]
# cae/conv0/convolution2d/biases/Adam (DT_FLOAT) [64]
# cae/conv0/convolution2d/biases/Adam_1 (DT_FLOAT) [64]
# cae/conv0/convolution2d/weights (DT_FLOAT) [7,7,3,64]
# cae/conv0/convolution2d/weights/Adam (DT_FLOAT) [7,7,3,64]
# cae/conv0/convolution2d/weights/Adam_1 (DT_FLOAT) [7,7,3,64]
# cae/conv1/convolution2d/biases (DT_FLOAT) [32]
# cae/conv1/convolution2d/biases/Adam (DT_FLOAT) [32]
# cae/conv1/convolution2d/biases/Adam_1 (DT_FLOAT) [32]
# cae/conv1/convolution2d/weights (DT_FLOAT) [5,5,64,32]
# cae/conv1/convolution2d/weights/Adam (DT_FLOAT) [5,5,64,32]
# cae/conv1/convolution2d/weights/Adam_1 (DT_FLOAT) [5,5,64,32]
参考
以下の記事を参考にしました。
https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125