はじめに
PyTorchで自作の損失関数の書き方、使い方を説明します。私が使っているPython, PyTorchの環境は以下の通りです。
動作環境
Python 3.7.9
torch 1.6.0+cu101
PyTorch標準の損失関数に倣った書き方
PyTorchに元々ある__torch.nn.MSELoss__や__torch.nn.CrossEntropyLoss__等に倣った書き方です。クラスとして損失関数を定義します。
class CustomLoss(nn.Module):
def __init__(self):
super().__init__()
# 初期化処理
# self.param = ...
def forward(self, outputs, targets):
'''
outputs: 予測結果(ネットワークの出力)
targets: 正解
'''
# 損失の計算
# loss = ...
return loss
使い方
PyTorch標準の損失関数と同じ使い方です。
'''
model: ネットワーク
inputs: ネットワークの入力
targets: 正解
'''
criterion = CustomLoss()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
単純な書き方
単純に関数として定義しても大丈夫です。
def CustomLoss(outputs, targets):
'''
outputs: 予測結果
targets: 正解
'''
# 損失の計算
# loss = ...
return loss
使い方
'''
model: ネットワーク
inputs: ネットワークの入力
targets: 正解
'''
outputs = model(inputs)
loss = CustomLoss(outputs, targets)
loss.backward()
backward(誤差逆伝播)に関する疑問
誤差逆伝播に使う勾配はどうなってるの?自作の損失関数の勾配は計算しなくていいの?という疑問のある人向けの説明です。
結論から言うと、自分で勾配を計算する必要はありません。ライブラリが自動的に計算してくれます。
PyTorchのforwardとbackwardについて詳しく紹介しているサイトがありましたので、掲載しておきます。
Pytorchの基礎 forwardとbackwardを理解する
https://zenn.dev/hirayuki/articles/bbc0eec8cd816c183408
具体例を見ていきましょう。
入力を二乗する関数を作ってtensorを入力してみます。
import torch
def squring(x):
return x**2
x = torch.tensor(1.0, requires_grad=True)
out = squring(x)
print(out)
# tensor(1., grad_fn=<PowBackward0>)
入力するtensorの引数__requires_grad=True__としておくことで、backwardをした際に勾配が自動的に計算されます。
それでは、backwardをして勾配を自動的に計算させましょう。自動的に計算された勾配は、入力したtensorの__x.grad__に格納されています。
out.backward()
print(x.grad)
# tensor(2.)
$x^2$の微分$2x$の結果が__x.grad__に計算されていることがわかります。
このようにPyTorchでは、勾配は自動的に計算されるので、自分で計算する必要はないわけです。
参考資料
Pytorchの基礎 forwardとbackwardを理解する
https://zenn.dev/hirayuki/articles/bbc0eec8cd816c183408