pytorchのregister_forward_hook()の設定はモデルをdeepcopyしても引き継がれない
気づくのに丸1日費やしたため,躓く人を減らすために書く.
何をしていたか
とある深層学習モデルの研究のために,一度学習し終わったモデルをcopyライブラリのdeepcopy関数を使って複製し,モデルを再び学習させようとしていた(ファインチューニング).しかしモデルの一回目の学習に問題は起きないものの,モデル複製後の学習は損失関数が発散して学習ができなかった.
原因
どうやらモデル複製前に設定していたregister_forward_hookが解除されているらしい.register_forward_hookとはモデルの中間層の入出力を引数にした任意の関数を設定できるというものである.私のプログラムではこれを利用してモデルの中間層の出力を抽出し,損失関数の計算に使っていた.
しかし,モデルをdeepcopyするとregister_forward_hookの設定が消えるようで,保存されるはずの中間層の出力が保存されず,誤差逆伝搬によって下手な勾配が入ってしまい損失関数が発散したのだと思われる.
対処法
deepcopy後にregister_forward_hookを再び設定した.deepcopyではなくcloneならこの問題は起きなかったりするのかも.試してないので分からない.
結論
deepcopyを闇雲に信頼してはいけない.とりあえずdeepcopyしときゃ大体いい感じにコピーされるやろ,という考えが甘かった.