0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Rust】Burnのソースコードで学ぶクロスエントロピー損失

Last updated at Posted at 2025-04-30

機械学習の世界では、モデルのトレーニングにおいて「損失関数(Loss Function)」が重要な役割を果たします。特に分類問題では、クロスエントロピー損失(Cross Entropy Loss)が広く使われています。この記事では、Rust製の機械学習フレームワーク「Burn」のソースコードを読み解きながら、クロスエントロピー損失の実装と応用について学んでいきます。

Burnは、Rustで書かれた高性能な機械学習フレームワークで、型安全性とゼロコスト抽象化を提供することで知られています。ソースコードを読み解くことで、実装の詳細だけでなく、クロスエントロピー損失関数のバリエーションや応用方法についても理解を深めます。

クロスエントロピー損失とは

クロスエントロピー損失は、予測された確率分布と実際の分布(通常は正解ラベル)の間の差異を測定する関数です。数学的には、以下のように表現されます:

$$
CE(y, p) = -\sum_{i=0}^{C} y_i \log(p_i)
$$

ここで、Cはクラス数、y_iは正解ラベル(通常は0または1)、p_iは予測された確率です。

正解クラスと不正解クラスの扱い

クロスエントロピー損失がどのように働くのか、正解クラスと不正解クラスに分けて考えてみましょう:

正解クラスの場合(y_i = 1):

  • モデルの予測確率p_iが1に近いほど、log(p_i)は0に近づく
  • したがって、-1 * log(p_i)も0に近づく(低い損失)
  • 逆に予測確率が0に近いほど、log(p_i)は負の無限大に近づく
  • この場合、-1 * log(p_i)は非常に大きな正の値(高い損失)になる

不正解クラスの場合(y_i = 0):

  • y_i = 0なので、-0 * log(p_i) = 0
  • つまり、通常の設定では不正解クラスの予測確率は損失に寄与しない

このメカニズムにより、モデルは正解クラスに高い確率を割り当てるように学習します。特に誤って低い確率を割り当てた場合、大きなペナルティ(損失)を受けるため、学習は効果的に進みます。

Burnにおけるクロスエントロピー損失の設計

Burnのソースコードを見ると、クロスエントロピー損失はCrossEntropyLossConfigCrossEntropyLossの2つの構造体で実装されていることがわかります。これはBurnのモジュールシステムの特徴で、Config構造体がパラメータの設定を担当し、実際のLoss構造体がモデルの学習で使用されます。

CrossEntropyLossConfig

まず、設定を担当するCrossEntropyLossConfig構造体を見てみましょう:

#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
    /// Create padded cross entropy.
    ///
    /// Prevents pad tokens from impacting loss calculation.
    pub pad_tokens: Option<Vec<usize>>,

    /// Create weighted cross-entropy.
    ///
    /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1,
    ///
    /// # Pre-conditions
    ///   - The order of the weight vector should correspond to the label integer assignment.
    ///   - Targets assigned negative Int's will not be allowed.
    pub weights: Option<Vec<f32>>,

    /// Create cross-entropy with label smoothing.
    ///
    /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes.
    /// Alpha = 0 would be the same as default.
    pub smoothing: Option<f32>,

    /// Create cross-entropy with probabilities as input instead of logits.
    ///
    #[config(default = true)]
    pub logits: bool,
}

この構造体から、Burnのクロスエントロピー損失が以下の機能をサポートしていることがわかります:

  1. パディングトークンの無視pad_tokensフィールドで指定されたトークンを損失計算から除外できます
  2. 重み付けされた損失weightsで各クラスに異なる重みを設定できます
  3. ラベルスムージングsmoothingパラメータを使って、正解ラベルを少しぼかすことができます
  4. 入力形式の選択logitsフラグで、入力がロジット(未正規化のスコア)か確率かを指定できます

CrossEntropyLoss

次に、実際の損失関数を実装するCrossEntropyLoss構造体を見てみましょう:

#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
    /// Pad tokens to ignore in the loss calculation.
    pub pad_tokens: Option<Vec<usize>>,
    /// Weights for cross-entropy.
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing factor.
    pub smoothing: Option<f32>,
    /// Use logits as input.
    pub logits: bool,
}

この構造体は、CrossEntropyLossConfigで設定された値を保持していますが、weightsTensor型になっています。これは、計算をBackendで効率的に行うためでしょう。

実際の損失計算を追う

Burnのコードを見ると、クロスエントロピー損失の計算は主に2つのメソッドで行われています:

通常の損失計算(forward_default

ラベルスムージング(後述)を使用しない場合、forward_defaultメソッドで計算されます:

fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
    let [batch_size] = targets.dims();

    let mask = self.padding_mask(&targets);
    let tensor = log_softmax(logits, 1);
    let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));

    match &self.weights {
        Some(weights) => {
            let weights = weights.clone().gather(0, targets);
            let tensor = tensor.reshape([batch_size]) * weights.clone();
            let tensor = Self::apply_mask_1d(tensor, mask);
            tensor.sum().neg() / weights.sum()
        }
        None => {
            let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
            tensor.mean().neg()
        }
    }
}

このメソッドでは:

  1. まずlog_softmaxでログ確率を計算
  2. gather操作で各サンプルの正解クラスのログ確率だけを取得
  3. 最後に負号をつけて合計(または平均)を取る

この流れが、まさに数式 CE(y, p) = -Σ y_i * log(p_i) の実装になっています。正解クラス(y_i=1)についてだけlog(p_i)を合計するという効率的な方法です。

ラベルスムージングの場合(forward_smoothed

ラベルスムージングを使用する場合は、forward_smoothedメソッドで計算されます:

fn forward_smoothed(
    &self,
    logits: Tensor<B, 2>,
    targets: Tensor<B, 1, Int>,
    alpha: f32,
) -> Tensor<B, 1> {
    let mask = self.padding_mask(&targets);
    let tensor = if self.logits {
        log_softmax(logits, 1)
    } else {
        logits.log()
    };
    let [batch_size, nr_classes] = tensor.dims();
    let tensor = tensor
        * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);

    match &self.weights {
        Some(weights) => {
            let tensor = tensor
                * weights
                    .clone()
                    .reshape([1, nr_classes])
                    .repeat_dim(0, batch_size);
            let weights = weights.clone().gather(0, targets);
            let tensor = Self::apply_mask_2d(tensor, mask);
            tensor.sum().neg() / weights.sum()
        }
        None => {
            let tensor = Self::apply_mask_2d(tensor, mask);
            tensor.sum_dim(1).mean().neg()
        }
    }
}

このメソッドでは、ラベルをスムージングした上で同様の計算を行います。特に重要なのは、スムージングされたターゲットの計算部分です:

fn compute_smoothed_targets(
    shape: [usize; 2],
    targets: Tensor<B, 1, Int>,
    alpha: f32,
) -> Tensor<B, 2> {
    let [batch_size, nr_classes] = shape;
    let device = &targets.device();
    let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
        1,
        targets.reshape([batch_size, 1]),
        Tensor::ones([batch_size, 1], device),
    );
    targets_matrix * (1. - alpha) + alpha / nr_classes as f32
}

機能解説

Burnのクロスエントロピー損失には、いくつか注目に値する機能があります。

1. パディングトークンの無視

NLPタスクでは、バッチ処理のために文章の長さを揃える必要があり、短い文章には「パディングトークン」が追加されます。これらのトークンは学習に寄与すべきではないため、損失計算から除外できます。

let mask = self.padding_mask(&targets);
// 中略
let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);

実装を見ると、パディングトークンを含む位置にマスクを適用し、それらの位置の損失を0に設定していることがわかります。

2. 重み付けされた損失

クラス不均衡(一部のクラスのサンプル数が極端に少ない状態)の問題に対処するために、各クラスに異なる重みを設定できます。例えば、データセットで少数派のクラスに大きな重みを与えることで、モデルがそのクラスをより重視するように調整できます。

match &self.weights {
    Some(weights) => {
        let weights = weights.clone().gather(0, targets);
        let tensor = tensor.reshape([batch_size]) * weights.clone();
        let tensor = Self::apply_mask_1d(tensor, mask);
        tensor.sum().neg() / weights.sum()
    }
    // 中略
}

重みが設定されている場合、各サンプルの損失に対応する重みをかけ、正規化のために重みの合計で割り算しています。これにより、例えば医療データのように「陽性」クラスが非常に少ないデータセットでも、モデルが少数派クラスを無視せずに学習できるようになります。

3. ラベルスムージング

ラベルスムージングは、モデルの過学習を防ぎ、一般化性能を向上させるための手法です。通常、正解ラベルは「1」、不正解ラベルは「0」という硬い(hard)値を使いますが、これを少し「柔らかく(soft)」します。

fn compute_smoothed_targets(
    shape: [usize; 2],
    targets: Tensor<B, 1, Int>,
    alpha: f32,
) -> Tensor<B, 2> {
    let [batch_size, nr_classes] = shape;
    let device = &targets.device();
    let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
        1,
        targets.reshape([batch_size, 1]),
        Tensor::ones([batch_size, 1], device),
    );
    targets_matrix * (1. - alpha) + alpha / nr_classes as f32
}

この実装では:

  • 正解クラスのラベル値:11-alpha(例:0.9)
  • 不正解クラスのラベル値:0alpha/nr_classes(例:0.01)

これにより不正解クラスも損失計算に寄与するようになります

  • 正解クラス:-(1-alpha) * log(p_正解)
  • 不正解クラス:-(alpha/nr_classes) * log(p_不正解)

これにより:

  • 正解クラスに対して100%の確信度を持つことにわずかな「ペナルティ」を与える
  • 不正解クラスに対して少なくとも小さな確率を割り当てることを「報酬」として扱われる

結果として、モデルは予測時により「控えめ」な確信度を持つようになります。例えば、過信したモデルなら「99.9%の確率でクラスAです」と予測するところを、ラベルスムージングを適用したモデルは「95%の確率でクラスAです」というように、より現実的な確信度で予測するようになります。

特に新しいデータや、訓練データとは少し異なるデータに対して、この「謙虚さ」が役立ち、モデルの汎化性能(未知のデータに対する性能)が向上します。

4. ロジットと確率の扱い

モデルの出力は通常、正規化されていない「ロジット」ですが、既に確率に変換されている場合もあります。Burnではlogitsフラグで、入力がどちらの形式かを指定できます。

let tensor = if self.logits {
    log_softmax(logits, 1)
} else {
    logits.log()
}

ロジットの場合はlog_softmax関数を使って確率の対数に変換し、既に確率の場合は単に対数を取るだけです。この機能は、異なる形式の出力を生成するモデルと連携する際に便利です。

実装のポイント

Burnのクロスエントロピー損失実装から学べる重要なポイントをいくつか紹介します。

テンソル操作の効率化

let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));

gather操作は、指定された次元に沿って、指定されたインデックスの要素を集めてきます。この操作で、クロスエントロピー損失の計算を非常に効率的に行っています。例えば10クラス分類で、正解クラスが3の場合:

  1. log_softmaxで全クラスのログ確率を計算: [-2.5, -3.1, -1.2, -0.3, -4.0, ...]
  2. gatherで正解クラス(インデックス3)のログ確率だけを取り出す: -0.3
  3. 負の符号をつける: 0.3

これにより、正解クラスのログ確率だけを使って損失を計算できます。通常のクロスエントロピー計算(y_i=0の項は0になるため)と数学的に等価ですが、計算量が大幅に削減されます。

マスキング処理

fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
    if let Some(mask) = mask {
        tensor = tensor.mask_fill(mask, 0);
    }

    tensor
}

パディングトークンを扱う際のマスキング処理は、mask_fillメソッドを使用して、マスクがtrueの位置の値を0に置き換えています。これは例えば、異なる長さの文章を処理する自然言語処理において、短い文章に追加されたパディングトークンを無視するために使われます。

文章1: "I love machine learning" → [1, 2, 3, 4, 0, 0]
文章2: "Deep learning is awesome" → [5, 3, 6, 7, 8, 0]

ここで0がパディングトークンなら、それを含む位置の損失は計算に含めるべきではありません。マスキング処理により、有効なトークンだけで損失を計算できます。

実際のユースケース

Burnのクロスエントロピー損失の特徴は様々なケースに対応できます:

医療画像分類

医療画像の分類では、「異常あり」のサンプルが「異常なし」より圧倒的に少ないことが一般的です。

// クラス0: 正常 (90%)
// クラス1: 異常 (10%)
let loss_fn = CrossEntropyLossConfig::new()
    .with_weights(Some(vec![1.0, 9.0]))  // 異常クラスに9倍の重み
    .init(&device);

重みを適切に設定することで、少数派クラス(異常)の識別性能を向上させることができます。

自然言語処理

文章の長さが異なる場合、短い文章にはパディングトークンが追加されます。

// パディングトークン(ID: 0)を無視
let loss_fn = CrossEntropyLossConfig::new()
    .with_pad_tokens(Some(vec![0]))
    .init(&device);

パディングトークンを無視することで、実際の文章内容のみに基づいてモデルを学習できます。

画像分類の一般化性能向上

画像分類では、モデルが訓練データに過学習することがよくあります。

// ラベルスムージングでモデルの過度な自信を抑制
let loss_fn = CrossEntropyLossConfig::new()
    .with_smoothing(Some(0.1))
    .init(&device);

ラベルスムージングにより、モデルはより確率的な予測を行うようになり、未知のデータに対する一般化性能が向上します。

まとめ

Burnのソースコードを通じて、クロスエントロピー損失の実装とその応用方法について学びました。主な機能として:

  1. 効率的な損失計算: テンソル操作を活用した最適化
  2. パディングトークンの無視: 有効なデータのみで学習
  3. クラス重みによる不均衡データへの対応: 少数派クラスの認識率向上
  4. ラベルスムージングによる過学習の軽減: より堅牢なモデルの実現
  5. ロジットと確率の両形式への対応: 様々なモデル出力に対応

これらの機能は、様々な機械学習タスクにおいて、モデルの性能向上に役立ちます。Burnの実装は、効率性と柔軟性を両立させており、Rustの型システムを活かした安全性も備えています。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?