torch.no_gradはテンソルの勾配の計算を不可にするContext-managerだ。
テンソルの勾配の計算を不可にすることでメモリの消費を減らす事が出来る。
このモデルでは、計算の結果毎にrequires_grad = Falseを持っている。
インプットがrequires_grad=Trueであろうとも。
このContext managerは、ローカルスレッドだ。他のスレッドの計算には影響を及ぼさない。
関数も@torch.no_grad()デコレーターを使用して返り値requires_grad=Falseに出来る。
つまり
with torch.no_grad():
のネストの中で定義した変数は、自動的にrequires_grad=Falseとなる。
以下のようにwith torch.no_grad()か、@torch.no_grad()を使用すると
import torch
x = torch.tensor([1.0], requires_grad=True)
y = None
with torch.no_grad():
y = x * 2
# y.requires_grad = False
@torch.no_grad()
def doubler(x):
return x * 2
z = doubler(x)
# z.requires_grad = False
import torch
x = torch.tensor([1.0], requires_grad=True)
y = x * 2
# y.requires_grad = True
def doubler(x):
return x * 2
z = doubler(x)
# z.requires_grad = True
参考
https://pytorch.org/docs/stable/generated/torch.no_grad.html