21
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchで自作の損失関数を使う

Last updated at Posted at 2020-11-26

はじめに

PyTorchで自作の損失関数の書き方、使い方を説明します。私が使っているPython, PyTorchの環境は以下の通りです。

動作環境

Python 3.7.9
torch 1.6.0+cu101

PyTorch標準の損失関数に倣った書き方

PyTorchに元々ある__torch.nn.MSELoss__や__torch.nn.CrossEntropyLoss__等に倣った書き方です。クラスとして損失関数を定義します。

class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # 初期化処理
        # self.param = ... 

    def forward(self, outputs, targets):
        '''
        outputs: 予測結果(ネットワークの出力)
     targets: 正解
        '''
        # 損失の計算
        # loss = ...
        return loss

使い方

PyTorch標準の損失関数と同じ使い方です。

'''
model: ネットワーク
inputs: ネットワークの入力
targets: 正解
'''
criterion = CustomLoss()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()

単純な書き方

単純に関数として定義しても大丈夫です。

def CustomLoss(outputs, targets):
    '''
    outputs: 予測結果
   targets: 正解
    '''
    # 損失の計算
    # loss = ...
    return loss

使い方

'''
model: ネットワーク
inputs: ネットワークの入力
targets: 正解
'''
outputs = model(inputs)
loss = CustomLoss(outputs, targets)
loss.backward()

backward(誤差逆伝播)に関する疑問

誤差逆伝播に使う勾配はどうなってるの?自作の損失関数の勾配は計算しなくていいの?という疑問のある人向けの説明です。

結論から言うと、自分で勾配を計算する必要はありません。ライブラリが自動的に計算してくれます。

PyTorchのforwardとbackwardについて詳しく紹介しているサイトがありましたので、掲載しておきます。
Pytorchの基礎 forwardとbackwardを理解する
https://zenn.dev/hirayuki/articles/bbc0eec8cd816c183408

具体例を見ていきましょう。
入力を二乗する関数を作ってtensorを入力してみます。

import torch

def squring(x):
    return x**2

x = torch.tensor(1.0, requires_grad=True)
out = squring(x)
print(out)
# tensor(1., grad_fn=<PowBackward0>)

入力するtensorの引数__requires_grad=True__としておくことで、backwardをした際に勾配が自動的に計算されます。

それでは、backwardをして勾配を自動的に計算させましょう。自動的に計算された勾配は、入力したtensorの__x.grad__に格納されています。

out.backward()
print(x.grad)
# tensor(2.)

$x^2$の微分$2x$の結果が__x.grad__に計算されていることがわかります。
このようにPyTorchでは、勾配は自動的に計算されるので、自分で計算する必要はないわけです。

参考資料

Pytorchの基礎 forwardとbackwardを理解する
https://zenn.dev/hirayuki/articles/bbc0eec8cd816c183408

21
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?