trainingの後で、モデルをテストするパートに書かれるアレです。
with torch.no_grad():
for data, target in loaders['test']:
...
こんな感じ.
この記事では、
- torch.no_grad()の役割
- withとは何か(pythonの文法)
について説明します。
1. torch.no_grad() について
- そもそもTensor型の変数とは
ndarrayのように行列やベクトルを扱えることに加えて、GPUを使え、勾配情報を保持することができる変数のこと。requires_grad=True
で勾配保持、=False
で保持しない選択ができる。
x = torch.ones(2,2,requires_grad=True)
print(x)
---出力---
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
さて、勾配情報はモデルのトレーニング時に重みの更新に使われるが、テスト時には必要ない。そこでwith torch.no_grad()
ブロックで定義されたテンソルは全て、=False
とされる。これはメモリ消費量減に貢献する。
ここで少し寄り道。よく似たものに、optimizer.zero_grad()
というものがある。PyTorchでは、次のバッチの勾配を計算するときも前の勾配を保持している。即ち、
$$今回のgrad = 前計算したgrad + 今計算したgrad$$ となっている。
RNNを除いて、これをバッチ毎に勾配を0に初期化することで、正しい計算ができる。
話を戻そう。
ところで、with
ってなんだ???
2. withとはなにか
この記事を参考に書きます。
with
文は、煩雑なコードになりがちなtry/finally文をwith一つでスッキリさせるために生まれた。また、context managerによって定義されたメソッドでラップするために使うことができる。ファイル操作や通信関係のプログラムでよく見るのは、close()といった後処理を書く必要がないためにエラーやバグを減らすことができるからなのだ。
ここでcontext managerは、__enter__()
と__exit__()
メソッドを持っていて、それぞれwithブロックに入るときと、そこから抜けるときに実行される。
では、それらのメソッドはどんな処理をするのか???
今後追記します。