8
4

More than 1 year has passed since last update.

with torch.no_grad(): とは何か

Last updated at Posted at 2022-04-09

trainingの後で、モデルをテストするパートに書かれるアレです。

with torch.no_grad():
    for data, target in loaders['test']:
        ...

こんな感じ.

この記事では、

  1. torch.no_grad()の役割
  2. withとは何か(pythonの文法)

について説明します。

1. torch.no_grad() について

  • そもそもTensor型の変数とは
    ndarrayのように行列やベクトルを扱えることに加えて、GPUを使え、勾配情報を保持することができる変数のこと。requires_grad=Trueで勾配保持、=Falseで保持しない選択ができる。
x = torch.ones(2,2,requires_grad=True)
print(x)

---出力---
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)

さて、勾配情報はモデルのトレーニング時に重みの更新に使われるが、テスト時には必要ない。そこでwith torch.no_grad()ブロックで定義されたテンソルは全て、=Falseとされる。これはメモリ消費量減に貢献する。

ここで少し寄り道。よく似たものに、optimizer.zero_grad()というものがある。PyTorchでは、次のバッチの勾配を計算するときも前の勾配を保持している。即ち、
$$今回のgrad = 前計算したgrad + 今計算したgrad$$ となっている。
RNNを除いて、これをバッチ毎に勾配を0に初期化することで、正しい計算ができる。

話を戻そう。
ところで、withってなんだ???

2. withとはなにか

この記事を参考に書きます。

with文は、煩雑なコードになりがちなtry/finally文をwith一つでスッキリさせるために生まれた。また、context managerによって定義されたメソッドでラップするために使うことができる。ファイル操作や通信関係のプログラムでよく見るのは、close()といった後処理を書く必要がないためにエラーやバグを減らすことができるからなのだ。
 ここでcontext managerは、__enter__()__exit__()メソッドを持っていて、それぞれwithブロックに入るときと、そこから抜けるときに実行される。

では、それらのメソッドはどんな処理をするのか???

今後追記します。

8
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
4