はじめに
この記事では、自然言語処理(NLP)モデルのファインチューニングで一般的に使われるクロスエントロピー損失(Cross Entropy Loss)の計算方法について分かりやすく解説します。
1. クロスエントロピーとは?
クロスエントロピー損失は、モデルが予測した確率分布と正解データ(ラベル)のズレを測るための指標です。モデルが正解ラベルに高い確率を割り当てるほど、損失は小さくなります。
単一のサンプルに対する数式で表すと以下のようになります。
$\text{Cross Entropy Loss} = -\log P(\text{正解ラベル})$
複数クラスの分類問題の一般的な形式では:
$\text{Cross Entropy Loss} = -\sum_{i=1}^{C} y_i \log(\hat{y}_i)$
ここで、$C$はクラス数、$y_i$は正解ラベル(one-hot)、$\hat{y}_i$は予測確率です。
2. ファインチューニングでの利用例
自然言語処理のファインチューニングでは、通常以下のような手順でクロスエントロピー損失を計算します。
(1)データ準備
モデルに与えるデータは、通常「プロンプト(入力)」と「ターゲット(正解出力)」から構成されます。
例:
プロンプト:「好きな果物は何ですか?」
ターゲット:「りんごです。」
モデルはプロンプトを受け取り、ターゲットを予測します。
(2)予測確率を計算
モデルは、ターゲット部分の各トークン(単語など)について、次に来る可能性のある単語に対して確率分布を出力します。
例えば、ターゲットの次の単語が「りんご」だとすると、以下のような予測確率を出します。
単語 | 予測確率 |
---|---|
りんご | 0.7 |
バナナ | 0.2 |
オレンジ | 0.1 |
(3)クロスエントロピー損失を計算
このとき正解は「りんご」なので、クロスエントロピー損失は以下のようになります。
$\text{損失} = -\log(0.7) \approx 0.357$
ターゲット全体に対してこれを繰り返し、トークンごとの損失を合計・平均します。
全体の損失は以下のように計算されます:
$\text{Total Loss} = \frac{1}{N} \sum_{t=1}^{N} \text{Loss}_t$
ここで、$N$はターゲットシーケンス内のトークン数です。
3. 損失計算での注意点
-
プロンプト部分の損失は計算しません。
- プロンプトは「与えられた条件」であり、モデルの出力性能には関係ないためです。
- 通常、マスク処理(-100などの特殊な値を使う)で損失計算から除外します。
-
正解ラベルは単なるインデックス(数字)です。
- 実際の実装では正解単語のインデックスを指定するだけでよく、one-hotベクトルを自分で作成する必要はありません。ライブラリが自動で処理します。
4. まとめ
ファインチューニングで用いるクロスエントロピー損失の計算はシンプルで、モデルが「正しいトークンを予測する確率」を高めるように学習を促進します。
これが自然言語モデルを微調整する基本的な仕組みです。