Help us understand the problem. What is going on with this article?

[PyTorch]TRANSFER LEARNING FOR COMPUTER VISION TUTORIALで気になったところ

はじめに

TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL(1)のコードでこれ何の処理だろう?って思ったところをちょっとまとめたくて記事を書きました.間違っていたらコメントを頂けると嬉しいです.

optimizer.zero_grad()

train_model関数の中にひっそりとある一行ですが,勾配の蓄積を初期化するため結構大事な関数です.初期化を行わないで勾配を蓄積させると収束しなくなると考えられます.重み$W$の更新を行う際,最急降下法ベースのだと

W = W - \eta \frac{\partial L}{\partial W}

この式の

\frac{\partial L}{\partial W}

この部分が勾配です.$\eta$は学習率(Learning Rate)です.
なので,学習を行うときはoptimizer.zero_grad()が必要です.

set_grad_enabled()

train_model関数のでwith句として呼ばれてる関数です.なくても計算問題ないのではと思い調べた結果,計算グラフの作成(2)において,学習時は順伝播,逆伝播が必要ですが,評価時は逆伝播を使用しないため,計算量を減らす目的があると考えられます.
with句使ってるからメモリの確保とかも絡んでる?

running_loss += loss.item() * inputs.size(0)

train_model関数内にある初めて見るとん?ってなる一行です.そもそも損失関数の定義で

criterion = nn.CrossEntropyLoss()
loss = criterion(outputs, labels)

となっているのが原因ですが,CrossEntropyLossを見てみると引数にreduction='mean'とあります.つまり,デフォルトで平均損失値が返ってくるため,バッチ学習であれば問題ないのですがミニバッチ学習となると,平均値(mean)から合計(sum)に直す必要があります.したがって,

running_loss += loss.item() * inputs.size(0)

と小難しいミニバッチの平均損失値にミニバッチのサンプル数を掛けてもとに戻してるんですね.
ちなみにこんなことをしなくても

criterion = nn.CrossEntropyLoss(reduction='sum')

loss.item() * inputs.size(0)と同じ結果が返ってきます.
なので,

#running_loss += loss.item() * inputs.size(0)
running_loss += loss.item()

と直して,不思議なことをせずに素直に書けますね.
一応ここ(3)でも言及されてます.

追記(2020/3/11)

criterion = nn.CrossEntropyLoss(reduction='sum') で 損失値を求めた後,逆伝播するとnan になるバグ(4)があるそうです.
なのでおとなしくrunning_loss += loss.item() * inputs.size(0)こっちを使ったほうがいいかも

さいごに

そのうち気になるところが増えたら続き書きます.

参考資料

(1) TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL
(2) Why we need torch.set_grad_enabled(False) here?
(3) Issue about updating training loss #2
(4) torch.nn.CrossEntropyLoss with "reduction" sum/mean is not deterministic on segmentation outputs / labels #17350

Haaamaaaaa
琴葉茜と暮らすことを夢見て技術力を欲してる人です. 機械学習のモデル構築とUMLでモデリングするの楽しいです.
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away