11
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[PyTorch] CrossEntropyLoss()のインスタンスをなぜ関数のように扱えるのか

Last updated at Posted at 2020-10-30

インスタンス=関数????

つくりながら学ぶ!PyTorchによる発展ディープラーニングという本を読んでいると1-3の転移学習のところにこのような記述がありました。(筆者GitHubで全てのコードが見られます)

1-3_transfer_learning.ipynb
# パッケージのimport
import glob
import os.path as osp
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

(中略)

# 損失関数の設定
criterion = nn.CrossEntropyLoss()

(中略)

def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):

    # epochのループ
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-------------')

        # epochごとの学習と検証のループ
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()  # モデルを訓練モードに
            else:
                net.eval()   # モデルを検証モードに

            epoch_loss = 0.0  # epochの損失和
            epoch_corrects = 0  # epochの正解数

            # 未学習時の検証性能を確かめるため、epoch=0の訓練は省略
            if (epoch == 0) and (phase == 'train'):
                continue

            # データローダーからミニバッチを取り出すループ
            for inputs, labels in tqdm(dataloaders_dict[phase]):

                # optimizerを初期化
                optimizer.zero_grad()

                # 順伝搬(forward)計算
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)  # 損失を計算
                    _, preds = torch.max(outputs, 1)  # ラベルを予測
                    
  
                    # 訓練時はバックプロパゲーション
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    # イタレーション結果の計算
                    # lossの合計を更新
                    epoch_loss += loss.item() * inputs.size(0)  
                    # 正解数の合計を更新
                    epoch_corrects += torch.sum(preds == labels.data)

            # epochごとのlossと正解率を表示
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double(
            ) / len(dataloaders_dict[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

ここで注目していただきたいのが、criterionです。これはnn.CrossEntropyLoss()のインスタンスとして以下のように定義されています。

1-3_transfer_learning.ipynb
criterion = nn.CrossEntropyLoss()

そして筆者は関数のようにcriterionを扱っています。

1-3_transfer_learning.ipynb
loss = criterion(outputs, labels)

しかしながら、torch.nn.CrossEntropyLossのソースコードを確認してみると、__call__メソッドの記述はないのです!
では、なぜCrossEntropyLoss()のインスタンスを関数のように扱えるのでしょうか?
この謎を解決したいというのが本記事の趣旨です。
なぜ__call__メソッドの有無が重要なのかはこちらを参照してください。

クラスの継承について

CrossEntropyLossクラスソースコードの冒頭は以下のように書かれています。

torch.nn.modules.loss
class CrossEntropyLoss(_WeightedLoss):

まずPythonでclassを定義するときに括弧の中に何かを入れているのはどういう意味なのでしょうか。
これはクラスの継承というもので、別のクラスで定義した関数やメソッドをそのまま呼び出すときに使います。(以下の具体例はこちらを引用しました)

#継承
class MyClass:
    def hello(self):
        print("Hello")

class MyClass2(MyClass):
    def world(self):
        print("World")

a = MyClass2()
a.hello() # Hello
a.world() # World

ここで注意点なのですが、親クラスと子クラスに同じ名前のメソッドが定義されていた場合には、子クラスのメソッドが上書きされます。これをオーバーライドと言います。

#オーバーライド
class MyClass:
    def hello(self):
        print("Hello")

class MyClass2(MyClass):
    def hello(self):        # 親クラスの hello() メソッドを上書き
        print("HELLO")

a = MyClass2()
a.hello()                   # HELLO

そして、子クラスのメソッドに親クラスで定義したものを使いたい!と思ったときに使えるのがsuper()関数です。

class MyClass1:
    def __init__(self):
       self.val1 = 123

class MyClass2(MyClass1):
    def __init__(self):
        super().__init__()
        self.val2 = 456

a = MyClass2()
print(a.val1) # 123
print(a.val2) # 456

話を戻すと、CrossEntropyLossクラス_WeightedLossクラスを継承しているということです。ちなみに、CrossEntropyLossのコードをもう少し確認してみると、

torch.nn.modules.loss
class CrossEntropyLoss(_WeightedLoss):

__constants__ = ['ignore_index', 'reduction']
    ignore_index: int

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

super(CrossEntropyLoss, self)とあって上の例とは少し記述が異なっていますが、Python公式を参照すると両者の意味が全く同じであることがわかります。

公式より
class C(B):
    def method(self, arg):
        super().method(arg)    # This does the same thing as:
                               # super(C, self).method(arg)

では、_WeitedLossクラスの記述を見てみます。

torch.nn.modules.loss
class _WeightedLoss(_Loss):

これより、_WeitedLoss_Lossを継承していることがわかります。
では、_WeitedLossクラスの記述を見てみます。

torch.nn.modules.loss
class _Loss(Module):

これより、_LossModuleを継承していることがわかります。
では、Moduleクラスの記述を見てみます。

torch.nn.modules.module
class Module:

Moduleは何も継承していません!というわけでModuleの内容から確認していきます。

torch.nn.Module

_LossクラスModuleクラス__init__メソッドが継承されているのでこれだけ確認してみます。

torch.nn.modules.module
#注:全てのコードは載せていません
from collections import OrderedDict, namedtuple

class Module:
    _version: int = 1

    training: bool

    dump_patches: bool = False
    
    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

ここではたくさんのOrderedDict()が定義されているのがわかります。OrederedDict()について詳しくはこちらを参照していただきたいのですが、簡単に言うと名前の通り、**「順番が保持された(Ordered)、空の辞書(dict)」**のことです。つまり、このクラスでは空の辞書がたくさん定義されているというだけです。

そしてですね、実はここで問題の__call__メソッドが定義されているのです!

torch.nn.modules.module
def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in itertools.chain(
                        _global_backward_hooks.values(),
                        self._backward_hooks.values()):
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

    __call__ : Callable[..., Any] = _call_impl

最終行の__call__ : Callable[..., Any] = _call_impl__call__の内容を_call_implとしているのでインスタンスを関数のように呼び出すと上記の関数が実行されます。Callable[..., Any]のところの意味が分からない方はこちらが参考になります。またこのコロンは関数アノテーションというもので詳しくはこちらを参照してください。簡単に言うと、「関数の引数や返り値にアノテーション(注釈)となる式を記述」しているだけです。

このコードの意味についてはこちらの記事で追っていきたいと思います。

Moduleクラスには上記以外にもいくつかメソッドが定義されているので必要があれば確認してください。

以下は流し読みで結構です。

torch.nn._Loss

_WeightedLossクラス_Lossクラス__init__メソッドが継承されているのでこれを確認してみます。

torch.nn.modules.loss
class _Loss(Module):
reduction: str

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction

ここでは新たにself.reductionを導入しているのがわかります。そしてその値はsize_average,reduceの値に左右されるようですね。

torch.nn.__WeightedLoss

CrossEntropyLossクラス_WeightedLossクラス__init__メソッドが継承されているのでこれを確認してみます。

torch.nn.modules.loss
class _WeightedLoss(_Loss):
    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)

ここでweightの関数アノテーションにOptional[Tensor]というものが指定されています。こちらの解説がわかりやすいです。簡単に言うと、weightTensor型None型のどちらが入ってもいいという意味です。

本題に戻ります。ここでは新たにself.register_bufferという関数がありますが、これはModuleクラスで定義されている関数です。以下、ソースコードです。

torch.nn.modules.module
forward: Callable[..., Any] = _forward_unimplemented

    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Adds a buffer to the module.

        This is typically used to register a buffer that should not to be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the module's state. Buffers, by
        default, are persistent and will be saved alongside parameters. This
        behavior can be changed by setting :attr:`persistent` to ``False``. The
        only difference between a persistent buffer and a non-persistent buffer
        is that the latter will not be a part of this module's
        :attr:`state_dict`.

        Buffers can be accessed as attributes using given names.

        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor): buffer to be registered.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Example::

            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")

        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("buffer name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError("cannot assign '{}' object to buffer '{}' "
                            "(torch Tensor or None required)"
                            .format(torch.typename(tensor), name))
        else:
            self._buffers[name] = tensor
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)

結構長いコードですが、上半分はコードの説明、if文elseよりも上の部分はエラーの設定をしているだけなので説明は割愛します。そしてelseではdict型self._buffersに要素を入れていますね。つまり、WeightedLossクラスを定義することで以下のようになりました。

self._buffer = {'weight': weight} #右のweightはTensor型もしくはNone型

torch.nn.CrossEntropyLoss

ようやくここまで疑問点まで帰ってこられました。以下がソースコードです。長いコメントアウトがありますが、一応すべて引用します。

torch.nn.modules.loss
class CrossEntropyLoss(_WeightedLoss):
    r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.

    It is useful when training a classification problem with `C` classes.
    If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
    assigning weight to each of the classes.
    This is particularly useful when you have an unbalanced training set.

    The `input` is expected to contain raw, unnormalized scores for each class.

    `input` has to be a Tensor of size either :math:`(minibatch, C)` or
    :math:`(minibatch, C, d_1, d_2, ..., d_K)`
    with :math:`K \geq 1` for the `K`-dimensional case (described later).

    This criterion expects a class index in the range :math:`[0, C-1]` as the
    `target` for each value of a 1D tensor of size `minibatch`; if `ignore_index`
    is specified, this criterion also accepts this class index (this index may not
    necessarily be in the class range).

    The loss can be described as:

    .. math::
        \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
                       = -x[class] + \log\left(\sum_j \exp(x[j])\right)

    or in the case of the :attr:`weight` argument being specified:

    .. math::
        \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right)

    The losses are averaged across observations for each minibatch. If the
    :attr:`weight` argument is specified then this is a weighted average:

    .. math::
        \text{loss} = \frac{\sum^{N}_{i=1} loss(i, class[i])}{\sum^{N}_{i=1} weight[class[i]]}

    Can also be used for higher dimension inputs, such as 2D images, by providing
    an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`,
    where :math:`K` is the number of dimensions, and a target of appropriate shape
    (see below).


    Args:
        weight (Tensor, optional): a manual rescaling weight given to each class.
            If given, has to be a Tensor of size `C`
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when reduce is ``False``. Default: ``True``
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets.
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
            be applied, ``'mean'``: the weighted mean of the output is taken,
            ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in
            the meantime, specifying either of those two args will override
            :attr:`reduction`. Default: ``'mean'``

    Shape:
        - Input: :math:`(N, C)` where `C = number of classes`, or
          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
          K-dimensional loss.
        - Output: scalar.
          If :attr:`reduction` is ``'none'``, then the same size as the target:
          :math:`(N)`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case
          of K-dimensional loss.

    Examples::

        >>> loss = nn.CrossEntropyLoss()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()
    """
    __constants__ = ['ignore_index', 'reduction']
    ignore_index: int

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

まず__init__メソッドではself.ignore_indexという変数が新たに加えられていますね。そしてforward()という関数もまた定義されています。しかし、__call__メソッドModuleクラス以降定義されていません。よって、Moduleクラス__call__メソッドCrossEntropyLossクラスのインスタンスが関数のように使えていた正体だったのです。

こちらの記事でCrossEntropyLoss()のインスタンスを関数のように扱った際に何が起こるのかを迫っていきたいと思います!

11
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?