0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ファインチューニングにおけるクロスエントロピー損失の計算方法

Posted at

はじめに

この記事では、自然言語処理(NLP)モデルのファインチューニングで一般的に使われるクロスエントロピー損失(Cross Entropy Loss)の計算方法について分かりやすく解説します。

1. クロスエントロピーとは?

クロスエントロピー損失は、モデルが予測した確率分布と正解データ(ラベル)のズレを測るための指標です。モデルが正解ラベルに高い確率を割り当てるほど、損失は小さくなります。

単一のサンプルに対する数式で表すと以下のようになります。

$\text{Cross Entropy Loss} = -\log P(\text{正解ラベル})$

複数クラスの分類問題の一般的な形式では:

$\text{Cross Entropy Loss} = -\sum_{i=1}^{C} y_i \log(\hat{y}_i)$

ここで、$C$はクラス数、$y_i$は正解ラベル(one-hot)、$\hat{y}_i$は予測確率です。

2. ファインチューニングでの利用例

自然言語処理のファインチューニングでは、通常以下のような手順でクロスエントロピー損失を計算します。

(1)データ準備

モデルに与えるデータは、通常「プロンプト(入力)」と「ターゲット(正解出力)」から構成されます。

例:

プロンプト:「好きな果物は何ですか?」
ターゲット:「りんごです。」

モデルはプロンプトを受け取り、ターゲットを予測します。

(2)予測確率を計算

モデルは、ターゲット部分の各トークン(単語など)について、次に来る可能性のある単語に対して確率分布を出力します。

例えば、ターゲットの次の単語が「りんご」だとすると、以下のような予測確率を出します。

単語 予測確率
りんご 0.7
バナナ 0.2
オレンジ 0.1

(3)クロスエントロピー損失を計算

このとき正解は「りんご」なので、クロスエントロピー損失は以下のようになります。

$\text{損失} = -\log(0.7) \approx 0.357$

ターゲット全体に対してこれを繰り返し、トークンごとの損失を合計・平均します。

全体の損失は以下のように計算されます:

$\text{Total Loss} = \frac{1}{N} \sum_{t=1}^{N} \text{Loss}_t$

ここで、$N$はターゲットシーケンス内のトークン数です。

3. 損失計算での注意点

  • プロンプト部分の損失は計算しません。

    • プロンプトは「与えられた条件」であり、モデルの出力性能には関係ないためです。
    • 通常、マスク処理(-100などの特殊な値を使う)で損失計算から除外します。
  • 正解ラベルは単なるインデックス(数字)です。

    • 実際の実装では正解単語のインデックスを指定するだけでよく、one-hotベクトルを自分で作成する必要はありません。ライブラリが自動で処理します。

4. まとめ

ファインチューニングで用いるクロスエントロピー損失の計算はシンプルで、モデルが「正しいトークンを予測する確率」を高めるように学習を促進します。
これが自然言語モデルを微調整する基本的な仕組みです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?