42
25

More than 3 years have passed since last update.

PyTorchのtorch.no_grad()とは何か(超個人的メモ)

Last updated at Posted at 2021-02-20

torch.no_gradはテンソルの勾配の計算を不可にするContext-managerだ。

テンソルの勾配の計算を不可にすることでメモリの消費を減らす事が出来る。

このモデルでは、計算の結果毎にrequires_grad = Falseを持っている。
インプットがrequires_grad=Trueであろうとも。

このContext managerは、ローカルスレッドだ。他のスレッドの計算には影響を及ぼさない。

関数も@torch.no_grad()デコレーターを使用して返り値requires_grad=Falseに出来る。

つまり
with torch.no_grad():
のネストの中で定義した変数は、自動的にrequires_grad=Falseとなる。

以下のようにwith torch.no_grad()か、@torch.no_grad()を使用すると

import torch

x = torch.tensor([1.0], requires_grad=True)
y = None
with torch.no_grad():
    y = x * 2
# y.requires_grad = False

@torch.no_grad()
def doubler(x):
    return x * 2

z = doubler(x)
# z.requires_grad = False
import torch

x = torch.tensor([1.0], requires_grad=True)

y = x * 2
# y.requires_grad = True

def doubler(x):
    return x * 2

z = doubler(x)
# z.requires_grad = True

参考
https://pytorch.org/docs/stable/generated/torch.no_grad.html

42
25
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
42
25