機械学習の世界では、モデルのトレーニングにおいて「損失関数(Loss Function)」が重要な役割を果たします。特に分類問題では、クロスエントロピー損失(Cross Entropy Loss)が広く使われています。この記事では、Rust製の機械学習フレームワーク「Burn」のソースコードを読み解きながら、クロスエントロピー損失の実装と応用について学んでいきます。
Burnは、Rustで書かれた高性能な機械学習フレームワークで、型安全性とゼロコスト抽象化を提供することで知られています。ソースコードを読み解くことで、実装の詳細だけでなく、クロスエントロピー損失関数のバリエーションや応用方法についても理解を深めます。
クロスエントロピー損失とは
クロスエントロピー損失は、予測された確率分布と実際の分布(通常は正解ラベル)の間の差異を測定する関数です。数学的には、以下のように表現されます:
$$
CE(y, p) = -\sum_{i=0}^{C} y_i \log(p_i)
$$
ここで、C
はクラス数、y_i
は正解ラベル(通常は0または1)、p_i
は予測された確率です。
正解クラスと不正解クラスの扱い
クロスエントロピー損失がどのように働くのか、正解クラスと不正解クラスに分けて考えてみましょう:
正解クラスの場合(y_i = 1):
- モデルの予測確率p_iが1に近いほど、log(p_i)は0に近づく
- したがって、
-1 * log(p_i)
も0に近づく(低い損失) - 逆に予測確率が0に近いほど、log(p_i)は負の無限大に近づく
- この場合、
-1 * log(p_i)
は非常に大きな正の値(高い損失)になる
不正解クラスの場合(y_i = 0):
-
y_i = 0
なので、-0 * log(p_i) = 0
- つまり、通常の設定では不正解クラスの予測確率は損失に寄与しない
このメカニズムにより、モデルは正解クラスに高い確率を割り当てるように学習します。特に誤って低い確率を割り当てた場合、大きなペナルティ(損失)を受けるため、学習は効果的に進みます。
Burnにおけるクロスエントロピー損失の設計
Burnのソースコードを見ると、クロスエントロピー損失はCrossEntropyLossConfig
とCrossEntropyLoss
の2つの構造体で実装されていることがわかります。これはBurnのモジュールシステムの特徴で、Config
構造体がパラメータの設定を担当し、実際のLoss
構造体がモデルの学習で使用されます。
CrossEntropyLossConfig
まず、設定を担当するCrossEntropyLossConfig
構造体を見てみましょう:
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
/// Create padded cross entropy.
///
/// Prevents pad tokens from impacting loss calculation.
pub pad_tokens: Option<Vec<usize>>,
/// Create weighted cross-entropy.
///
/// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1,
///
/// # Pre-conditions
/// - The order of the weight vector should correspond to the label integer assignment.
/// - Targets assigned negative Int's will not be allowed.
pub weights: Option<Vec<f32>>,
/// Create cross-entropy with label smoothing.
///
/// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes.
/// Alpha = 0 would be the same as default.
pub smoothing: Option<f32>,
/// Create cross-entropy with probabilities as input instead of logits.
///
#[config(default = true)]
pub logits: bool,
}
この構造体から、Burnのクロスエントロピー損失が以下の機能をサポートしていることがわかります:
-
パディングトークンの無視:
pad_tokens
フィールドで指定されたトークンを損失計算から除外できます -
重み付けされた損失:
weights
で各クラスに異なる重みを設定できます -
ラベルスムージング:
smoothing
パラメータを使って、正解ラベルを少しぼかすことができます -
入力形式の選択:
logits
フラグで、入力がロジット(未正規化のスコア)か確率かを指定できます
CrossEntropyLoss
次に、実際の損失関数を実装するCrossEntropyLoss
構造体を見てみましょう:
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
/// Pad tokens to ignore in the loss calculation.
pub pad_tokens: Option<Vec<usize>>,
/// Weights for cross-entropy.
pub weights: Option<Tensor<B, 1>>,
/// Label smoothing factor.
pub smoothing: Option<f32>,
/// Use logits as input.
pub logits: bool,
}
この構造体は、CrossEntropyLossConfig
で設定された値を保持していますが、weights
がTensor
型になっています。これは、計算をBackendで効率的に行うためでしょう。
実際の損失計算を追う
Burnのコードを見ると、クロスエントロピー損失の計算は主に2つのメソッドで行われています:
通常の損失計算(forward_default
)
ラベルスムージング(後述)を使用しない場合、forward_default
メソッドで計算されます:
fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
let [batch_size] = targets.dims();
let mask = self.padding_mask(&targets);
let tensor = log_softmax(logits, 1);
let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));
match &self.weights {
Some(weights) => {
let weights = weights.clone().gather(0, targets);
let tensor = tensor.reshape([batch_size]) * weights.clone();
let tensor = Self::apply_mask_1d(tensor, mask);
tensor.sum().neg() / weights.sum()
}
None => {
let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
tensor.mean().neg()
}
}
}
このメソッドでは:
- まず
log_softmax
でログ確率を計算 -
gather
操作で各サンプルの正解クラスのログ確率だけを取得 - 最後に負号をつけて合計(または平均)を取る
この流れが、まさに数式 CE(y, p) = -Σ y_i * log(p_i) の実装になっています。正解クラス(y_i=1)についてだけlog(p_i)を合計するという効率的な方法です。
ラベルスムージングの場合(forward_smoothed
)
ラベルスムージングを使用する場合は、forward_smoothed
メソッドで計算されます:
fn forward_smoothed(
&self,
logits: Tensor<B, 2>,
targets: Tensor<B, 1, Int>,
alpha: f32,
) -> Tensor<B, 1> {
let mask = self.padding_mask(&targets);
let tensor = if self.logits {
log_softmax(logits, 1)
} else {
logits.log()
};
let [batch_size, nr_classes] = tensor.dims();
let tensor = tensor
* Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);
match &self.weights {
Some(weights) => {
let tensor = tensor
* weights
.clone()
.reshape([1, nr_classes])
.repeat_dim(0, batch_size);
let weights = weights.clone().gather(0, targets);
let tensor = Self::apply_mask_2d(tensor, mask);
tensor.sum().neg() / weights.sum()
}
None => {
let tensor = Self::apply_mask_2d(tensor, mask);
tensor.sum_dim(1).mean().neg()
}
}
}
このメソッドでは、ラベルをスムージングした上で同様の計算を行います。特に重要なのは、スムージングされたターゲットの計算部分です:
fn compute_smoothed_targets(
shape: [usize; 2],
targets: Tensor<B, 1, Int>,
alpha: f32,
) -> Tensor<B, 2> {
let [batch_size, nr_classes] = shape;
let device = &targets.device();
let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
1,
targets.reshape([batch_size, 1]),
Tensor::ones([batch_size, 1], device),
);
targets_matrix * (1. - alpha) + alpha / nr_classes as f32
}
機能解説
Burnのクロスエントロピー損失には、いくつか注目に値する機能があります。
1. パディングトークンの無視
NLPタスクでは、バッチ処理のために文章の長さを揃える必要があり、短い文章には「パディングトークン」が追加されます。これらのトークンは学習に寄与すべきではないため、損失計算から除外できます。
let mask = self.padding_mask(&targets);
// 中略
let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
実装を見ると、パディングトークンを含む位置にマスクを適用し、それらの位置の損失を0に設定していることがわかります。
2. 重み付けされた損失
クラス不均衡(一部のクラスのサンプル数が極端に少ない状態)の問題に対処するために、各クラスに異なる重みを設定できます。例えば、データセットで少数派のクラスに大きな重みを与えることで、モデルがそのクラスをより重視するように調整できます。
match &self.weights {
Some(weights) => {
let weights = weights.clone().gather(0, targets);
let tensor = tensor.reshape([batch_size]) * weights.clone();
let tensor = Self::apply_mask_1d(tensor, mask);
tensor.sum().neg() / weights.sum()
}
// 中略
}
重みが設定されている場合、各サンプルの損失に対応する重みをかけ、正規化のために重みの合計で割り算しています。これにより、例えば医療データのように「陽性」クラスが非常に少ないデータセットでも、モデルが少数派クラスを無視せずに学習できるようになります。
3. ラベルスムージング
ラベルスムージングは、モデルの過学習を防ぎ、一般化性能を向上させるための手法です。通常、正解ラベルは「1」、不正解ラベルは「0」という硬い(hard)値を使いますが、これを少し「柔らかく(soft)」します。
fn compute_smoothed_targets(
shape: [usize; 2],
targets: Tensor<B, 1, Int>,
alpha: f32,
) -> Tensor<B, 2> {
let [batch_size, nr_classes] = shape;
let device = &targets.device();
let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
1,
targets.reshape([batch_size, 1]),
Tensor::ones([batch_size, 1], device),
);
targets_matrix * (1. - alpha) + alpha / nr_classes as f32
}
この実装では:
- 正解クラスのラベル値:
1
→1-alpha
(例:0.9) - 不正解クラスのラベル値:
0
→alpha/nr_classes
(例:0.01)
これにより不正解クラスも損失計算に寄与するようになります。
- 正解クラス:
-(1-alpha) * log(p_正解)
- 不正解クラス:
-(alpha/nr_classes) * log(p_不正解)
これにより:
- 正解クラスに対して100%の確信度を持つことにわずかな「ペナルティ」を与える
- 不正解クラスに対して少なくとも小さな確率を割り当てることを「報酬」として扱われる
結果として、モデルは予測時により「控えめ」な確信度を持つようになります。例えば、過信したモデルなら「99.9%の確率でクラスAです」と予測するところを、ラベルスムージングを適用したモデルは「95%の確率でクラスAです」というように、より現実的な確信度で予測するようになります。
特に新しいデータや、訓練データとは少し異なるデータに対して、この「謙虚さ」が役立ち、モデルの汎化性能(未知のデータに対する性能)が向上します。
4. ロジットと確率の扱い
モデルの出力は通常、正規化されていない「ロジット」ですが、既に確率に変換されている場合もあります。Burnではlogits
フラグで、入力がどちらの形式かを指定できます。
let tensor = if self.logits {
log_softmax(logits, 1)
} else {
logits.log()
}
ロジットの場合はlog_softmax
関数を使って確率の対数に変換し、既に確率の場合は単に対数を取るだけです。この機能は、異なる形式の出力を生成するモデルと連携する際に便利です。
実装のポイント
Burnのクロスエントロピー損失実装から学べる重要なポイントをいくつか紹介します。
テンソル操作の効率化
let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));
gather
操作は、指定された次元に沿って、指定されたインデックスの要素を集めてきます。この操作で、クロスエントロピー損失の計算を非常に効率的に行っています。例えば10クラス分類で、正解クラスが3の場合:
-
log_softmax
で全クラスのログ確率を計算:[-2.5, -3.1, -1.2, -0.3, -4.0, ...]
-
gather
で正解クラス(インデックス3)のログ確率だけを取り出す:-0.3
- 負の符号をつける:
0.3
これにより、正解クラスのログ確率だけを使って損失を計算できます。通常のクロスエントロピー計算(y_i=0の項は0になるため)と数学的に等価ですが、計算量が大幅に削減されます。
マスキング処理
fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
if let Some(mask) = mask {
tensor = tensor.mask_fill(mask, 0);
}
tensor
}
パディングトークンを扱う際のマスキング処理は、mask_fill
メソッドを使用して、マスクがtrue
の位置の値を0に置き換えています。これは例えば、異なる長さの文章を処理する自然言語処理において、短い文章に追加されたパディングトークンを無視するために使われます。
文章1: "I love machine learning" → [1, 2, 3, 4, 0, 0]
文章2: "Deep learning is awesome" → [5, 3, 6, 7, 8, 0]
ここで0がパディングトークンなら、それを含む位置の損失は計算に含めるべきではありません。マスキング処理により、有効なトークンだけで損失を計算できます。
実際のユースケース
Burnのクロスエントロピー損失の特徴は様々なケースに対応できます:
医療画像分類
医療画像の分類では、「異常あり」のサンプルが「異常なし」より圧倒的に少ないことが一般的です。
// クラス0: 正常 (90%)
// クラス1: 異常 (10%)
let loss_fn = CrossEntropyLossConfig::new()
.with_weights(Some(vec![1.0, 9.0])) // 異常クラスに9倍の重み
.init(&device);
重みを適切に設定することで、少数派クラス(異常)の識別性能を向上させることができます。
自然言語処理
文章の長さが異なる場合、短い文章にはパディングトークンが追加されます。
// パディングトークン(ID: 0)を無視
let loss_fn = CrossEntropyLossConfig::new()
.with_pad_tokens(Some(vec![0]))
.init(&device);
パディングトークンを無視することで、実際の文章内容のみに基づいてモデルを学習できます。
画像分類の一般化性能向上
画像分類では、モデルが訓練データに過学習することがよくあります。
// ラベルスムージングでモデルの過度な自信を抑制
let loss_fn = CrossEntropyLossConfig::new()
.with_smoothing(Some(0.1))
.init(&device);
ラベルスムージングにより、モデルはより確率的な予測を行うようになり、未知のデータに対する一般化性能が向上します。
まとめ
Burnのソースコードを通じて、クロスエントロピー損失の実装とその応用方法について学びました。主な機能として:
- 効率的な損失計算: テンソル操作を活用した最適化
- パディングトークンの無視: 有効なデータのみで学習
- クラス重みによる不均衡データへの対応: 少数派クラスの認識率向上
- ラベルスムージングによる過学習の軽減: より堅牢なモデルの実現
- ロジットと確率の両形式への対応: 様々なモデル出力に対応
これらの機能は、様々な機械学習タスクにおいて、モデルの性能向上に役立ちます。Burnの実装は、効率性と柔軟性を両立させており、Rustの型システムを活かした安全性も備えています。