Pytorch の Softmax+CrossEntropy の実装 に関する記事です。以下のチュートリアルを説明していく感じです。「(torch.nnなしで)ニューラルネットワークを構築」 部分の説明です。
PyTorchチュートリアル/[6] torch.nnの解説
Softmax+CrossEntropy の実装 の周りはディープラーニングの入門部分でも、私にとっては最も面白い部分でした。
【関連記事】
MNIST手書き数字のCNN画像認識 - Qiita
CNN 畳み込み層のメモ - Qiita
Softmax+CrossEntropy の実装 - Qiita
機械学習のウォーミングアップ(Numpy) - Qiita
1. 分類問題のニューラルネットワークの概略
例えば、MNISTの数字認識ネットワークは以下のようになります。10個の分類問題となります。
- X: 入力データ、784要素の画像ベクトル
- Y: 線形ノードの出力、10要素のベクトル
- Z: Softmax関数の出力、各要素が $0<z_i<1$、 $\sum z_i=1$ の確率値となる。
- このモデルの出力を $\log z_i$ と考える。全て負の数となるが、一番大きい値が予測値となる。
\begin{align}
\begin{pmatrix}
x_1 \\
x_2 \\
... \\
x_{784} \\
\end{pmatrix}
線形
→
\begin{pmatrix}
y_1 \\
y_2 \\
... \\
y_{10} \\
\end{pmatrix}
Softmax→
\begin{pmatrix}
z_1 \\
z_2 \\
... \\
z_{10} \\
\end{pmatrix}
log→
\begin{pmatrix}
\log z_1 \\
\log z_2 \\
... \\
\log z_{10} \\
\end{pmatrix}
損失→
- \log z_T \\
\end{align}
ここで Softmax は以下のように定義される。
\begin{align}
& \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad\\
& z_i \equiv \frac{e^{y_i}}{\sum_{k=1}^{10} e^{y_k}} 1 \leqq \forall i \leqq 10\\
\\
&この定義より以下が成り立つ。つまり z_i は確率値と考えられる。\\
& 0<z_i<1 \qquad \sum_{i=1}^{10} z_i =1\\
\end{align}
このネットワークの損失計算 (nil)は、教師データ T に基づいて T 番目の要素を選び出すことになります。例えばT=5の場合、最終出力の $ \begin{pmatrix}
\log z_1 \
\log z_2 \
... \
\log z_{10} \
\end{pmatrix} $ から $\log z_5$ を選び出す計算が損失計算となります。
交差エントロピー関数
どこからどこまでを損失関数とみるかは、実装の仕方によって異なってきますが、通常は log計算をするところから最後の $\log z_T$ を選出する nil 関数までを損失計算として、交差エントロピー関数とよびます。しかし Pytorch では指数関数と対数関数を同時に行うことで計算の安定を図っていますので(LogSoftmax 関数)、この場合は nil 関数だけを切り出して損失関数とみています。
logits
logits とは Softmax 関数に通す前のニューラルネットワークの出力です。上の図で言えば $\begin{pmatrix}
y_1 \
y_2 \
... \
y_{10} \
\end{pmatrix} $ に相当します。 Pytorch の cross_entropy 関数のドキュメントなどに出てくるので調べてみました。
ですから、$- \log z_i$ のグラフは以下のようになります。
これは確率値 $z_i$ が正解の1に近づけば近づくほど、$- \log z_i$ はゼロに近づきます。逆に正解から離れて、0に近づけば近づくほど $- \log z_i$ は大きくなります。つまり $- \log z_i$ を正解からの距離(損失)として考えることができます。
2. 対数の計算
PytorchではLogとSoftmaxを一緒に計算しておくことで、計算結果を安定させている、と言われています。
Kerasを勉強した後にPyTorchを勉強して躓いたこと
まずは以下の対数の2公式を確認しておきます。
\begin{align}
&\qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad\\
& \log \frac{x}{y} = \log x - \log y \\
\\
& x = \log ({e^x})
\end{align}
つぎに公式を使ってLogSoftmaxを計算してみましょう。
\begin{align}
LogSoftmax = \log z_i & = \log \Biggl( \frac{e^{y_i}}{\sum_{k=1}^n e^{y_k}} \Biggr) \\
\\
& = \log \bigl( e^{y_i} \bigr) - \log \biggl( {\sum_{k=1}^n e^{y_k}} \biggr) \\
\\
& = y_i - \log \biggl( {\sum_{k=1}^n e^{y_k}} \biggr) \\
\end{align}
3. Pythonによる実装
以下に torch.nn なしの機械学習の全体の実装を見ていきます。上に解説した部分をなぞればあまり難しくはないと思いますが、少し解説を加えていきたいと思います。
PyTorchチュートリアル/[6] torch.nnの解説
以下、 以下のような MNISTデータ x_train, y_train を前提とします。
print(x_train.shape)
print(y_train.shape)
(50000, 784)
(50000,)
3-1. 線形計算
まず、線形計算に必要な、weights(行列)と bias を定義します。
import math
weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)
def model(xb):
return log_softmax(xb @ weights + bias)
@ は行列積を表しています。
Pytorch の 自動微分 を使うため requires_grad 属性を True にします。weights の初期値では、 「Xavier の初期値」 を用いているので、math.sqrt(784) で割り算を行っています。
model は上のニューラルネットワークで、損失計算の手前までの計算を行います。入力 xb に対する予測値を出力する計算です。log_softmaxに関しては以下に述べます。
3-2. 活性化関数 LogSoftmax
LogSoftmax の計算は上で行ったことにより、以下のようになります。
def log_softmax(x):
return x - x.exp().sum(-1).log().unsqueeze(-1)
実際の計算はバッチで行います。ここでバッチサイズを64とすると、log_softmaxの計算過程は以下のようになります。
配列 | shape | 説明 |
---|---|---|
x | [64, 10] | バッチサイズ 64、クラス数10 |
x.exp().sum(-1) | [64] | 最終次元の10要素を sum でスカラー化 |
x.exp().sum(-1).log() | [64] | shapeは変化なし |
x.exp().sum(-1).log().unsqueeze(-1) | [64, 1] | 最終次元を追加 |
最後に、x - x.exp().sum(-1).log().unsqueeze(-1) は、 [64, 10] 型配列から [64, 1] 型配列を引くブロードキャスト演算になります。
これまでの定義により、model の予測値 は以下のように計算できます。
bs = 64 # batch size
xb = x_train[0:bs] # a mini-batch from x
preds = model(xb) # predictions
3-3. 損失計算 nil
損失関数 は nil とネーミングされています。損失計算は、最終出力の $ \begin{pmatrix}
\log z_1 \
\log z_2 \
... \
\log z_{10} \
\end{pmatrix} $ から正解 T に対応した $\log z_T$ を選び出す計算となります。
def nll(input, target):
return -input[range(target.shape[0]), target].mean()
loss_func = nll
実際の計算はバッチで行います。ここでバッチサイズを64とすると以下のようになります。
input.shape = torch.Size([64, 10])
target.shape = torch.Size([64])
input は $ \begin{pmatrix}
\log z_1 \
\log z_2 \
... \
\log z_{10} \
\end{pmatrix} $ が64行あるイメージです。
target は正解のリストなので $[n_1, n_2, ...,n_{64}] \qquad 0 \leq n_i \leq 9$ という形です。
ファンシーインデックスを使うと、input[range(target.shape[0]), target] は input の64行の各行からtarget にある正解のインデックスを取り出す操作を表し、結果は64要素のリストになります。ここではそのリストの平均値(スカラー)をとっています。
ファンシーインデックスについて
ここではバッチサイズ=4 として、(4,10)の配列にファンシーインデックスを適用するサンプルプログラムを実行してみます。
import numpy as np
m_np = np.array([[10,11,12,13,14,15,16,17,18,19],
[20,21,22,23,24,25,26,27,28,29],
[30,31,32,33,34,35,36,37,38,39],
[40,41,42,43,44,45,46,47,48,49]])
print(m_np[[0,1,2,3],[5,2,8,6]])
出力は以下の通り。4行のバッチから、それぞれ正解 [5,2,8,6] の要素を取り出しています。
[15 22 38 46]
3-4. 訓練とモデルの予測精度
全体の流れの中で LogSoftmax と 損失計算 がうまく機能していることを確認します。
accuracy はこの model の予測の精度を計るものです。予測値がどの程度、正解とあっているかを計算していきます。
def accuracy(out, yb):
preds = torch.argmax(out, dim=1)
return (preds == yb).float().mean()
out は model の出力ベクトルです。バッチ計算なので、$ \begin{pmatrix}
\log z_1 \
\log z_2 \
... \
\log z_{10} \
\end{pmatrix} $ というベクトルが64行あるイメージです。torch.argmax(out, dim=1) は各行から最大値を持つ要素の index を求める演算です。3番目の要素が最大であれば、3という数字を返します。最大値の $\log z_i$ が、最も確率が1に近い、最大値の $z_i$ です。上の $\log z_i$ のグラフを参照してください。
weights と bias のパラメータが初期値のまま、つまり未学習のままの accuracy を求めてみます。
print(accuracy(preds, yb))
予想通り accuracy は低い値です。ゼロに近いです。
tensor(0.0156)
以下の繰り返し計算を行うことで、構築した model を基に、weights と bias のパラメータを訓練(学習)します。
from IPython.core.debugger import set_trace
bs = 64 # batch size
lr = 0.05 # learning rate
epochs = 10 # how many epochs to train for
for epoch in range(epochs):
for i in range((n - 1) // bs + 1):
# set_trace()
start_i = i * bs
end_i = start_i + bs
xb = x_train[start_i:end_i]
yb = y_train[start_i:end_i]
pred = model(xb)
loss = loss_func(pred, yb)
loss.backward()
with torch.no_grad():
weights -= weights.grad * lr
bias -= bias.grad * lr
weights.grad.zero_()
bias.grad.zero_()
訓練された weights と bias のパラメータを基に、再度予測を行います。
print(loss_func(model(xb), yb), accuracy(model(xb), yb))
accuracy が 0.0156 から 0.8750 まで上がりました。
tensor(0.3731, grad_fn=<NegBackward0>) tensor(0.8750)
結果的にニューラルネットワークの中で、LogSoftmax と 損失計算 がうまく機能していることが確認できました。
今回は以上です。