インスタンス=関数????
つくりながら学ぶ!PyTorchによる発展ディープラーニングという本を読んでいると1-3の転移学習のところにこのような記述がありました。(筆者GitHubで全てのコードが見られます)
# パッケージのimport
import glob
import os.path as osp
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms
(中略)
# 損失関数の設定
criterion = nn.CrossEntropyLoss()
(中略)
def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):
# epochのループ
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch+1, num_epochs))
print('-------------')
# epochごとの学習と検証のループ
for phase in ['train', 'val']:
if phase == 'train':
net.train() # モデルを訓練モードに
else:
net.eval() # モデルを検証モードに
epoch_loss = 0.0 # epochの損失和
epoch_corrects = 0 # epochの正解数
# 未学習時の検証性能を確かめるため、epoch=0の訓練は省略
if (epoch == 0) and (phase == 'train'):
continue
# データローダーからミニバッチを取り出すループ
for inputs, labels in tqdm(dataloaders_dict[phase]):
# optimizerを初期化
optimizer.zero_grad()
# 順伝搬(forward)計算
with torch.set_grad_enabled(phase == 'train'):
outputs = net(inputs)
loss = criterion(outputs, labels) # 損失を計算
_, preds = torch.max(outputs, 1) # ラベルを予測
# 訓練時はバックプロパゲーション
if phase == 'train':
loss.backward()
optimizer.step()
# イタレーション結果の計算
# lossの合計を更新
epoch_loss += loss.item() * inputs.size(0)
# 正解数の合計を更新
epoch_corrects += torch.sum(preds == labels.data)
# epochごとのlossと正解率を表示
epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
epoch_acc = epoch_corrects.double(
) / len(dataloaders_dict[phase].dataset)
print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc))
ここで注目していただきたいのが、criterionです。これはnn.CrossEntropyLoss()のインスタンスとして以下のように定義されています。
criterion = nn.CrossEntropyLoss()
そして筆者は関数のようにcriterionを扱っています。
loss = criterion(outputs, labels)
しかしながら、torch.nn.CrossEntropyLossのソースコードを確認してみると、__call__メソッド
の記述はないのです!
では、なぜCrossEntropyLoss()のインスタンスを関数のように扱えるのでしょうか?
この謎を解決したいというのが本記事の趣旨です。
なぜ__call__メソッド
の有無が重要なのかはこちらを参照してください。
クラスの継承について
CrossEntropyLossクラス
のソースコードの冒頭は以下のように書かれています。
class CrossEntropyLoss(_WeightedLoss):
まずPythonでclass
を定義するときに括弧の中に何かを入れているのはどういう意味なのでしょうか。
これはクラスの継承というもので、別のクラスで定義した関数やメソッドをそのまま呼び出すときに使います。(以下の具体例はこちらを引用しました)
#継承
class MyClass:
def hello(self):
print("Hello")
class MyClass2(MyClass):
def world(self):
print("World")
a = MyClass2()
a.hello() # Hello
a.world() # World
ここで注意点なのですが、親クラスと子クラスに同じ名前のメソッドが定義されていた場合には、子クラスのメソッドが上書きされます。これをオーバーライドと言います。
#オーバーライド
class MyClass:
def hello(self):
print("Hello")
class MyClass2(MyClass):
def hello(self): # 親クラスの hello() メソッドを上書き
print("HELLO")
a = MyClass2()
a.hello() # HELLO
そして、子クラスのメソッドに親クラスで定義したものを使いたい!と思ったときに使えるのがsuper()
関数です。
class MyClass1:
def __init__(self):
self.val1 = 123
class MyClass2(MyClass1):
def __init__(self):
super().__init__()
self.val2 = 456
a = MyClass2()
print(a.val1) # 123
print(a.val2) # 456
話を戻すと、CrossEntropyLossクラス
は_WeightedLossクラス
を継承しているということです。ちなみに、CrossEntropyLossのコードをもう少し確認してみると、
class CrossEntropyLoss(_WeightedLoss):
__constants__ = ['ignore_index', 'reduction']
ignore_index: int
def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
reduce=None, reduction: str = 'mean') -> None:
super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
self.ignore_index = ignore_index
def forward(self, input: Tensor, target: Tensor) -> Tensor:
return F.cross_entropy(input, target, weight=self.weight,
ignore_index=self.ignore_index, reduction=self.reduction)
super(CrossEntropyLoss, self)
とあって上の例とは少し記述が異なっていますが、Python公式を参照すると両者の意味が全く同じであることがわかります。
class C(B):
def method(self, arg):
super().method(arg) # This does the same thing as:
# super(C, self).method(arg)
では、_WeitedLossクラス
の記述を見てみます。
class _WeightedLoss(_Loss):
これより、_WeitedLoss
は_Loss
を継承していることがわかります。
では、_WeitedLossクラス
の記述を見てみます。
class _Loss(Module):
これより、_Loss
はModule
を継承していることがわかります。
では、Moduleクラス
の記述を見てみます。
class Module:
Module
は何も継承していません!というわけでModule
の内容から確認していきます。
torch.nn.Module
_Lossクラス
にModuleクラスの__init__メソッド
が継承されているのでこれだけ確認してみます。
#注:全てのコードは載せていません
from collections import OrderedDict, namedtuple
class Module:
_version: int = 1
training: bool
dump_patches: bool = False
def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
ここではたくさんのOrderedDict()
が定義されているのがわかります。OrederedDict()
について詳しくはこちらを参照していただきたいのですが、簡単に言うと名前の通り、**「順番が保持された(Ordered)、空の辞書(dict)」**のことです。つまり、このクラスでは空の辞書がたくさん定義されているというだけです。
そしてですね、実はここで問題の__call__メソッド
が定義されているのです!
def _call_impl(self, *input, **kwargs):
for hook in itertools.chain(
_global_forward_pre_hooks.values(),
self._forward_pre_hooks.values()):
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in itertools.chain(
_global_forward_hooks.values(),
self._forward_hooks.values()):
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in itertools.chain(
_global_backward_hooks.values(),
self._backward_hooks.values()):
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
__call__ : Callable[..., Any] = _call_impl
最終行の__call__ : Callable[..., Any] = _call_impl
で__call__
の内容を_call_impl
としているのでインスタンスを関数のように呼び出すと上記の関数が実行されます。Callable[..., Any]
のところの意味が分からない方はこちらが参考になります。またこのコロンは関数アノテーションというもので詳しくはこちらを参照してください。簡単に言うと、「関数の引数や返り値にアノテーション(注釈)となる式を記述」しているだけです。
このコードの意味についてはこちらの記事で追っていきたいと思います。
Moduleクラス
には上記以外にもいくつかメソッドが定義されているので必要があれば確認してください。
以下は流し読みで結構です。
torch.nn._Loss
_WeightedLossクラス
に_Lossクラスの__init__メソッド
が継承されているのでこれを確認してみます。
class _Loss(Module):
reduction: str
def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
super(_Loss, self).__init__()
if size_average is not None or reduce is not None:
self.reduction = _Reduction.legacy_get_string(size_average, reduce)
else:
self.reduction = reduction
ここでは新たにself.reduction
を導入しているのがわかります。そしてその値はsize_average
,reduce
の値に左右されるようですね。
torch.nn.__WeightedLoss
CrossEntropyLossクラス
に_WeightedLossクラスの__init__メソッド
が継承されているのでこれを確認してみます。
class _WeightedLoss(_Loss):
def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
self.register_buffer('weight', weight)
ここでweight
の関数アノテーションにOptional[Tensor]
というものが指定されています。こちらの解説がわかりやすいです。簡単に言うと、weight
はTensor型
かNone型
のどちらが入ってもいいという意味です。
本題に戻ります。ここでは新たにself.register_buffer
という関数がありますが、これはModule
クラスで定義されている関数です。以下、ソースコードです。
forward: Callable[..., Any] = _forward_unimplemented
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
r"""Adds a buffer to the module.
This is typically used to register a buffer that should not to be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the module's state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting :attr:`persistent` to ``False``. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module's
:attr:`state_dict`.
Buffers can be accessed as attributes using given names.
Args:
name (string): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor): buffer to be registered.
persistent (bool): whether the buffer is part of this module's
:attr:`state_dict`.
Example::
>>> self.register_buffer('running_mean', torch.zeros(num_features))
"""
if persistent is False and isinstance(self, torch.jit.ScriptModule):
raise RuntimeError("ScriptModule does not support non-persistent buffers")
if '_buffers' not in self.__dict__:
raise AttributeError(
"cannot assign buffer before Module.__init__() call")
elif not isinstance(name, torch._six.string_classes):
raise TypeError("buffer name should be a string. "
"Got {}".format(torch.typename(name)))
elif '.' in name:
raise KeyError("buffer name can't contain \".\"")
elif name == '':
raise KeyError("buffer name can't be empty string \"\"")
elif hasattr(self, name) and name not in self._buffers:
raise KeyError("attribute '{}' already exists".format(name))
elif tensor is not None and not isinstance(tensor, torch.Tensor):
raise TypeError("cannot assign '{}' object to buffer '{}' "
"(torch Tensor or None required)"
.format(torch.typename(tensor), name))
else:
self._buffers[name] = tensor
if persistent:
self._non_persistent_buffers_set.discard(name)
else:
self._non_persistent_buffers_set.add(name)
結構長いコードですが、上半分はコードの説明、if文
のelse
よりも上の部分はエラーの設定をしているだけなので説明は割愛します。そしてelse
ではdict型
のself._buffers
に要素を入れていますね。つまり、WeightedLossクラス
を定義することで以下のようになりました。
self._buffer = {'weight': weight} #右のweightはTensor型もしくはNone型
torch.nn.CrossEntropyLoss
ようやくここまで疑問点まで帰ってこられました。以下がソースコードです。長いコメントアウトがありますが、一応すべて引用します。
class CrossEntropyLoss(_WeightedLoss):
r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.
It is useful when training a classification problem with `C` classes.
If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
assigning weight to each of the classes.
This is particularly useful when you have an unbalanced training set.
The `input` is expected to contain raw, unnormalized scores for each class.
`input` has to be a Tensor of size either :math:`(minibatch, C)` or
:math:`(minibatch, C, d_1, d_2, ..., d_K)`
with :math:`K \geq 1` for the `K`-dimensional case (described later).
This criterion expects a class index in the range :math:`[0, C-1]` as the
`target` for each value of a 1D tensor of size `minibatch`; if `ignore_index`
is specified, this criterion also accepts this class index (this index may not
necessarily be in the class range).
The loss can be described as:
.. math::
\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
= -x[class] + \log\left(\sum_j \exp(x[j])\right)
or in the case of the :attr:`weight` argument being specified:
.. math::
\text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right)
The losses are averaged across observations for each minibatch. If the
:attr:`weight` argument is specified then this is a weighted average:
.. math::
\text{loss} = \frac{\sum^{N}_{i=1} loss(i, class[i])}{\sum^{N}_{i=1} weight[class[i]]}
Can also be used for higher dimension inputs, such as 2D images, by providing
an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`,
where :math:`K` is the number of dimensions, and a target of appropriate shape
(see below).
Args:
weight (Tensor, optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size `C`
size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
the losses are averaged over each loss element in the batch. Note that for
some losses, there are multiple elements per sample. If the field :attr:`size_average`
is set to ``False``, the losses are instead summed for each minibatch. Ignored
when reduce is ``False``. Default: ``True``
ignore_index (int, optional): Specifies a target value that is ignored
and does not contribute to the input gradient. When :attr:`size_average` is
``True``, the loss is averaged over non-ignored targets.
reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
losses are averaged or summed over observations for each minibatch depending
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
batch element instead and ignores :attr:`size_average`. Default: ``True``
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
be applied, ``'mean'``: the weighted mean of the output is taken,
``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in
the meantime, specifying either of those two args will override
:attr:`reduction`. Default: ``'mean'``
Shape:
- Input: :math:`(N, C)` where `C = number of classes`, or
:math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
in the case of `K`-dimensional loss.
- Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
:math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
K-dimensional loss.
- Output: scalar.
If :attr:`reduction` is ``'none'``, then the same size as the target:
:math:`(N)`, or
:math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case
of K-dimensional loss.
Examples::
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
"""
__constants__ = ['ignore_index', 'reduction']
ignore_index: int
def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
reduce=None, reduction: str = 'mean') -> None:
super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
self.ignore_index = ignore_index
def forward(self, input: Tensor, target: Tensor) -> Tensor:
return F.cross_entropy(input, target, weight=self.weight,
ignore_index=self.ignore_index, reduction=self.reduction)
まず__init__メソッド
ではself.ignore_index
という変数が新たに加えられていますね。そしてforward()
という関数もまた定義されています。しかし、__call__メソッド
はModuleクラス
以降定義されていません。よって、Moduleクラス
の__call__メソッド
がCrossEntropyLossクラス
のインスタンスが関数のように使えていた正体だったのです。