タイトルの通り、step関数などの非線形関数を噛ませた場合にもbackpropagationをするにはどうしたらいいかという記事です。
答えは結構簡単で、
import torch
class RoundNoGradient(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.round()
@staticmethod
def backward(ctx, g):
return g
などと定義してあげればいいです。
具体的には、下のように使えます。この例では、非線形関数として、四捨五入するroundを用いています。
2つ目の関数とロスを定義
m = torch.nn.Conv2d(2, 3, 3, stride=2)
l = torch.nn.L1Loss()
普通にroundをした場合
input = torch.autograd.Variable(torch.randn(3, 2, 6, 6), requires_grad=True)
x = input.round()
y = m(x)
output = l(x, torch.autograd.Variable(torch.randn(3, 2, 6, 6)))
output.backward()
print(input.grad[:2,:2,:2,:2])
'''
tensor([[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]],
[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]]])
'''
この場合は途中でroundが挟まっているためにbackpropが出来ていない事がわかります。
次に、RoundNoGradientを使った場合を見てみましょう。
RoundNoGradientを使った場合
input = torch.autograd.Variable(torch.randn(3, 2, 6, 6), requires_grad=True)
x = RoundNoGradient.apply(input)
y = m(x)
output = l(x, torch.autograd.Variable(torch.randn(3, 2, 6, 6)))
output.backward()
print(input.grad[:2,:2,:2,:2])
'''
tensor([[[[ 0.0046, -0.0046],
[-0.0046, -0.0046]],
[[ 0.0046, -0.0046],
[-0.0046, -0.0046]]],
[[[ 0.0046, 0.0046],
[ 0.0046, 0.0046]],
[[ 0.0046, -0.0046],
[-0.0046, 0.0046]]]])
'''
backprop出来ているのがわかります。
ちなみに、tensorflowの場合は
with graph.gradient_override_map({'Round': 'Identity'}):
とするだけで実現できます。
参考