LoginSignup
15
18

More than 5 years have passed since last update.

ゼロから作るDeep Learning 〜Softmax-with-Lossレイヤ〜

Posted at

はじめに

「ゼロから作るDeep Learning」の解説記事です。今回はSoftmax-with-Lossレイヤの概要と逆伝播の計算方法、Pythonの実装方法について説明していきます。

Softmax-with-Lossとは?

ニューラルネットワークで分類の問題の推論処理(例.手書き数字の推定)を行う際に、入力データ(例.画像データ)をネットワークに入力して、出力(例.要素数10のone-hotベクトル)が得られます。この出力の最大値がこのニューラルネットワークの推論する値です。

Softmax-with-Lossとは、ニューラルネットワークが学習する際に、出力と教師データを使ってLossを計算する方法です。その名の通り、出力のSoftmaxを計算して、教師データとのLossを計算します。後述しますが、逆伝播の出力が「Softmaxの出力ー教師データ」となるのがポイントです。(逆に言うと、単純な差分として計算できるように、Lossが定義されています)

ニューラルネットワークの出力をa、Softmaxの出力をy、教師データをt、LossをLとすると計算式は以下のようになります。tの要素は正解の一つが1で他は0です。

y_k = \frac{\exp(a_k)}{\sum_{i=1}^n \exp(a_i)}, \\
L = -\sum_k t_k \log(y_k)

Softmax関数は出力の合計が1になるのが特徴です。これは推論となる出力の正しさを確率として示せることを意味しています。LossはCross Entropy Errorと言うものです。正解の確率を対数として示したものです。先日、誤差逆伝播法について記載しましたが(記事へのリンク)、ここで書いているLがこのLossに該当します。

Lossの計算はあくまでニューラルネットワークの学習の際に行われるものであり、推論の時は不要な処理です。最大となるaを利用すれば推論結果が分かるためです。

逆伝播が出力と教師データの差分になる理由

ここでは逆伝播について解説します。本書では計算グラフを書いて各ノードごとに計算を進めていますが、この記事ではLをaの式で示して、一発で微分を計算して求めます。まずLとaの関係式は以下のようになります。

\begin{align}
L &= -\sum_k t_k \log \biggl (\frac{\exp(a_k)}{\sum_i \exp(a_i)} \biggr ) \\
&= -\sum_k t_k a_k + \sum_k t_k \log \biggl (\sum_i \exp(a_i) \biggr ) \\
&= -\sum_k t_k a_k + \log \biggl (\sum_i \exp(a_i) \biggr ) \\
\end{align}

最後の式変形は、logの中身は走る変数のkに依存しないこと、及びtの合計は1であることを利用しています。この式から、aの一要素の偏微分を計算します。

\begin{align}
\frac{\partial L}{\partial a_k}
&= -t_k + \frac{\exp(a_k)}{\sum_i \exp(a_i)} \\
&= y_k - t_k
\end{align}

以上により、Softmax-with-Lossレイヤの逆誤差伝播は「Softmaxの出力ー教師データ」であることが示されました。

Pythonによる実装

以上の内容をPythonで記述すると以下のようになります。
上述の説明だと、

  • データは1つのみでバッチ処理はしない
  • 教師データはone-hot-vectorでありindexではない

ことを前提としていますが、本書のコードは汎用性を持たせるためにバッチ処理もindex形式の教師データに対しても計算できるようになっています。この前提の違いをカバーするために、本書のコードの中に補足コメントを追記しています。

def softmax(x):
    if x.ndim == 2: # バッチ処理をする場合
        x = x.T # 縦方向(axis=0)にMAX,SUMを計算するため一旦転置する
        x = x - np.max(x, axis=0) # オーバーフロー対策
        y = np.exp(x) / np.sum(np.exp(x), axis=0)
        return y.T 
    x = x - np.max(x) # オーバーフロー対策
    return np.exp(x) / np.sum(np.exp(x))

def cross_entropy_error(y, t):
    if y.ndim == 1: # データは横に並べる
        t = t.reshape(1, t.size)
        y = y.reshape(1, y.size)

    # 教師データがone-hot-vectorの場合、正解ラベルのインデックスに変換
    if t.size == y.size:
        t = t.argmax(axis=1)
    batch_size = y.shape[0] # 縦方向の数=バッチ処理するデータの数
    return -np.sum(np.log(y[np.arange(batch_size), t])) / batch_size
    # np.arangeで各データを取り出す。そのt番目の要素を取り出してlogを計算
    # この計算のためにtはインデックスである必要がある

class SoftmaxWithLoss:
    def __init__(self):
        self.loss = None
        self.y = None # softmaxの出力
        self.t = None # 教師データ

    def forward(self, x, t):
        self.t = t
        self.y = softmax(x)
        self.loss = cross_entropy_error(self.y, self.t)
        return self.loss

    def backward(self, dout=1):
        batch_size = self.t.shape[0]
        if self.t.size == self.y.size: # 教師データがone-hot-vectorの場合
            dx = (self.y - self.t) / batch_size
            # forward()のLossの計算だとsumを計算してbatch_sizeで割っているため、
            # backward()の計算だとbatch_sizeで割って(乗算の逆伝播)、1を掛ける(sumの逆伝播)
        else:
            dx = self.y.copy()
            dx[np.arange(batch_size), self.t] -= 1 # 正解の時だけt=1を引く
            dx = dx / batch_size
        return dx
15
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
18