はじめに
ひよこです。生成 AI や LLM などで大活躍している深層学習ですが、では、深層学習の本質とは何でしょうか? これはなかなか難しい問いですが、私なら「深層学習の本質は損失関数の最適化である」と答えます。入力や出力が画像や文章など何でもよい柔軟性や、ネットワークの膨大な重みに蓄えられた知識構造の深遠さに比べると、単なる関数最適化は随分と矮小化されているように見えますが、それでも本質はこちらだと考えます。
$y=f(x)$ の最適化は数学の世界の話に思え、深層学習の凄さを十分に表していないように感じられるかもしれません。しかし、深層学習の凄いところは、$x$ や $y$ に相当する変数部分に画像や文章を当てはめる具体的な方法論が確立されている点です。さらに、どんなに複雑な入出力関係であっても、それを「損失関数の最適化」という単純明快な数学的タスクに置き換えられる点が、深層学習の発展を支えているのです。
損失関数の最適化が本質であるという意見には異論もあるかもしれませんが、少なくとも損失関数の重要性を否定する人は皆無でしょう。というわけで、今回は深層学習で広く用いられる損失関数を、初心者向けに整理します。本記事では、それぞれの名称、数学表現、使用例、そして PyTorch による実装例を示します。
回帰タスク向け損失
回帰タスクでは予測値 $\hat{y}_i$ と実測値 $y_i$ の誤差を定量化します。典型的には L1 ノルムや L2 ノルムに基づく指標が使われます。
ここで注意したい点は、L1 や L2 といったノルムを用いた「損失」には「合計」と「平均」があり、用語の使い方が文脈によって異なりがちです。
- L1 ノルムに基づく合計の絶対誤差を L1 ロスと呼び、その平均が MAE (Mean Absolute Error)
- L2 ノルムに基づく合計の二乗誤差を L2 ロスと呼び、その平均が MSE (Mean Squared Error)
実用上は、MAE と L1 ロス、MSE と L2 ロスをほぼ同義で扱うこともありますが、数学的には平均を取るか否かの違いがある点を念頭に置くと混乱が少なくなります。
🐣 L1 や L2 の "L" は Lebesgue ノルム(ルベーグノルム)に由来します
Mean Squared Error (MSE)
名称: 平均二乗誤差 (MSE)
数学表現:
\text{MSE} = \frac{1}{N}\sum_{i=1}^{N}(y_i - \hat{y}_i)^2
MSE は L2 ロス(合計の二乗誤差)を平均したもの。
使用例: 一般的な回帰タスクで広く用いる
PyTorch での例:
import torch
loss_fn = torch.nn.MSELoss()
loss = loss_fn(pred, target)
Mean Absolute Error (MAE)
名称: 平均絶対誤差 (MAE)
数学表現:
\text{MAE} = \frac{1}{N}\sum_{i=1}^{N}|y_i - \hat{y}_i|
MAE は L1 ロス(合計の絶対誤差)を平均したもの。外れ値に対して MSE より頑健です。
使用例: 外れ値が多い回帰タスク、スパース性を重視したい場面
PyTorch での例:
loss_fn = torch.nn.L1Loss() # PyTorch の L1Loss は MAE を計算
loss = loss_fn(pred, target)
Smooth L1 Loss (Huber Loss)
名称: スムース L1 損失 (ハバーロス)
数学表現:
\text{SmoothL1}(y, \hat{y}) = \begin{cases}
0.5(y-\hat{y})^2 & \text{if } |y-\hat{y}| < \delta \\[6pt]
\delta|y-\hat{y}| - 0.5\delta^2 & \text{otherwise}
\end{cases}
MSE (L2) と MAE (L1) の中間的特性を持ち、外れ値への強さと安定性をバランスさせます。
使用例: 回帰タスク全般での安定した学習
PyTorch での例:
loss_fn = torch.nn.SmoothL1Loss()
loss = loss_fn(pred, target)
分類タスク向けのクロスエントロピー系損失
分類タスクでは、予測確率分布と真のクラス分布の差異を測定するためのクロスエントロピーが基本です。LLM のトークン予測などでも広く用いられます。
Cross Entropy Loss
名称: クロスエントロピー損失 (CE)
数学表現:
\text{CE}(y, \hat{p}) = -\sum_{k=1}^{K} y_k \log(\hat{p}_k)
$y_k$ は one-hot 表現、$\hat{p}_k$ は予測確率。
使用例: K クラス分類タスク一般
PyTorch での例:
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(pred, target) # target はクラスインデックス
Negative Log Likelihood Loss (NLLLoss)
CE と等価な目的を持つが、ログ確率(log softmax 出力)を直接受け取る。
使用例: クロスエントロピーと同様の分類タスク
PyTorch での例:
loss_fn = torch.nn.NLLLoss()
loss = loss_fn(pred_log_probs, target)
Binary Cross Entropy (BCE)
名称: バイナリクロスエントロピー損失
数学表現:
\text{BCE}(y, \hat{p}) = -\frac{1}{N}\sum_{i=1}^{N}[y_i\log(\hat{p}_i) + (1 - y_i)\log(1 - \hat{p}_i)]
2 クラス分類に特化した CE の一種。
使用例: バイナリ分類タスク
PyTorch での例:
loss_fn = torch.nn.BCELoss()
loss = loss_fn(pred, target)
BCEWithLogitsLoss
BCE にシグモイド適用を内部で包含し、数値的安定性を向上させたもの。
使用例: ロジットを直接出力する 2 クラス分類
PyTorch での例:
loss_fn = torch.nn.BCEWithLogitsLoss()
loss = loss_fn(logits, target)
分布間距離・マージンベース損失など
Kullback–Leibler Divergence (KL Divergence)
名称: KL ダイバージェンス
数学表現:
D_{KL}(P \| Q) = \sum_{x} P(x)\log\frac{P(x)}{Q(x)}
確率分布間の差異を測定。LLM 蒸留などで使用。
使用例: 分布合わせ、知識蒸留
PyTorch での例:
loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
loss = loss_fn(pred_log_probs, target_probs)
Hinge Loss
名称: ヒンジ損失
数学表現:
\text{Hinge}(y, s) = \max(0, 1 - y \cdot s)
$y \in {+1, -1}$ のラベルでマージンを確保する。SVM 由来。
使用例: マージンベースの分類器
PyTorch での例:
loss_fn = torch.nn.HingeEmbeddingLoss()
loss = loss_fn(pred, target)
🐣 ReLU 的な発想に似ています(歴史的には逆)
Cosine Embedding Loss
名称: コサイン埋め込み損失
数学表現:
\text{CosineEmb}(x_1, x_2, y) = \begin{cases}
1 - \cos(x_1,x_2) & \text{if } y=1\\[6pt]
\max(0,\cos(x_1,x_2)-\text{margin}) & \text{if } y=-1
\end{cases}
埋め込み空間での類似・非類似ペア学習に利用。検索や埋め込み学習で有用。
PyTorch での例:
loss_fn = torch.nn.CosineEmbeddingLoss(margin=0.0)
loss = loss_fn(x1, x2, target)
発展的な応用
VAE (Variational AutoEncoder) では再構成誤差 (MSE や MAE など) と KL ダイバージェンスを併用し、GAN (Generative Adversarial Network) では生成器と識別器が対抗的な損失を持ちます。Stable Diffusion のようなモデルでも複雑な潜在空間と確率的構造を扱う損失が研究されています。
こうした高度なシナリオでは、タスクに応じて複数の損失を併用し、モデル特性に合わせたカスタムロスを導入することが一般的です。
おわりに
本記事では回帰タスクから分類タスク、分布間距離、さらに埋め込み学習まで、代表的な損失関数を紹介しました。損失関数はモデル学習の方向性を定める鍵であり、適切な選択が性能向上につながります。
初心者の方は MSE や CE のような基本的な損失から理解を深め、徐々に用途に応じて他の損失を使い分けてみてください。実務や研究ではここで挙げた以外にも多種多様な損失が存在し、日々新たな提案がなされています。
損失関数を正しく理解し、適切に活用することで、より優れたモデルを構築できる可能性が広がります。