2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

拡散モデルにGRPOを使ってファインチューニングしてみた

Last updated at Posted at 2025-03-20

はじめに

最近拡散モデルとGRPOの記事を書きまして、その時ふと拡散モデルをGRPOでSFT(教師ありファインチューニング)したら面白いのでは?と思いました。
その内容となります。

過去の記事

・拡散モデル関係の記事
[拡散モデル入門] ゼロから理解する拡散モデルの最新理論(図解付き)

入門①:DDPMの理論とMNISTの実装
入門②:SDE/ODEの基礎理論(Tensorflow実装付き)
入門③:EDMの解説とMNISTの実装
入門④:条件付きU-Net(MNIST実装付き)
応用編:ここ

・GRPOの記事
話題のDeepSeekで使われている強化学習GRPOを調べてみた

コード

Githubに上げています。

拡散モデル(オイラー法)と強化学習

拡散モデルとオイラー法の関係は前回の記事を見てください。

まず、拡散モデルの(後退)オイラー法と強化学習の軌跡の対応関係を考えます。
ここでは、完全にノイズがのった画像を0ステップ目とし、オイラー法による1回の復元処理を1ステップ、最終的に画像が完全に復元されるまでの過程を1エピソードとし以下のように対応させます。
(報酬は後述)

強化学習 拡散モデル
状態 $s_t$ ノイズ入り画像 $x_t$
アクション $a_t$ 1ステップ復元後のノイズ入り画像 $x_{t-1}$
方策 $\pi_{\theta}(a_t|s_t)$ 逆拡散過程(denoise)$p_{\theta}(x_{t-1}|x_t)$
次の状態 $x_{t+1} = a_t$ $x_{t-1}$

bb-ページ10.drawio.png

方策(PPO)

ノイズを除去する関数 Denoiser はノイズ入り画像 $x_t$ と現在のノイズレベル $\sigma$ を入力として、元画像の期待値を推定する関数として近似されます。

$$
D_{\theta}(x_t, \sigma) \approx \mathbb{E}[x_0 | x_t]
$$

拡散モデルではこれを平均として次の画像が生成されます。(分散は固定)

$$
p(x_{t-1}|x_t) = \mathcal{N}(D_{\theta}(x_t, t), \sigma^2 I)
$$

方策の学習としては Denoiser が最適な正規分布の平均になるように調整することになります。

正規分布の学習について

分かる人は飛ばしてください。
正規分布の学習についてイメージをつかむために復習しています。

PPOの学習の形は以下です。(clipは省略)

L(\theta) = \mathbb{E} \left[ 
    \frac{\pi_\theta(a|s)}{\pi_{\theta_{\text{old}}}(a|s)}
    \hat{A} \right]

説明に必要な部分だけ見ると以下です。

\pi_\theta(a|s)\hat{A}

$\pi_\theta(a|s)$ は方策となり確率・確率密度を表し、$\hat{A}$ は良いか悪いかの目安を表します。
例えば $\hat{A}=1$ とプラス方向の場合、$\pi_\theta(a|s)$ が大きくなるように学習します。(良い状態になる確率を増やすイメージ)
逆に $\hat{A}=-1$ などマイナス方向の場合、$\pi_\theta(a|s)$ は小さくなるように学習します。(悪い状態になる確率を減らすイメージ)

確率密度と尤度

確率と確率密度の違いは、分布が離散値か連続値の違いで、今回は正規分布を扱うのでこの記事では確率密度で統一します。

確率密度と尤度は、同じ分布と観測データでは同じ値を取りますが、捉え方が異なります。
確率密度は確率分布に対して観測データの出やすさを表す値で、尤度は特定の観測データに対して確率分布がどれだけ適合しているかを表した値です。

bb-ページ1.drawio.png

ここでは尤度が高く/低くなるイメージを見ていきます。
正規分布はパラメータが平均$\mu$と分散$\sigma^2$の2つあるのでそれぞれを見ていきたいと思います。

まず平均で尤度が高くなるイメージは以下です。

bb-ページ2のコピー.drawio.png

山自体が移動して尤度が高くなるイメージです。
低くする場合は逆を行います。

次に分散です。

bb-ページ2のコピーのコピー.drawio.png

分散を小さくするとより尖った形になります。
尤度の変化は少し複雑で、平均に近い場所では尤度は高くなり、遠い場所では尤度は低くなります。

平均と分散の値を調整し、良い場所では尤度を高くして悪い場所では尤度を低く学習するのがPPOの基本です。
ただ、拡散モデルでは分散は固定なので平均のみの学習になります。

正規分布の対数尤度

最後に正規分布の対数尤度を計算するコードを書いておきます。(データ1個のみを仮定)

\begin{align}

\log(f(x|\mu, \sigma)) &= \log \Bigg(\frac{1}{\sqrt{2 \pi \sigma^2 } }
\exp(- \frac{(x - \mu)^2}{ 2 \sigma^2} ) \Bigg) \\
&= -\frac{1}{2}\log(2\pi) - \log(\sigma) - \frac{(x - \mu)^2}{ 2 \sigma^2 }\\

\end{align}

計算内容:https://ja.wolframalpha.com/input?i=Log%5B1%2FSqrt%5B2%CF%80%CF%83%5E2%5DExp%5B-%28x-%CE%BC%29%5E2%2F%282%CF%83%5E2%29%5D%5D

def log_likelihood_normal(x, mu, sigma):
    return -0.5 * math.log(2 * math.pi) - tf.math.log(sigma) - 0.5 * (((x - mu) / sigma) ** 2)

報酬

生成した画像に対して画像を評価します。
今回はOCRとTFの2つの方法で見てみました。

OCRで読み取れるようにファインチューニングする事を目指します。

OCR

OCR(Optical Character Recognition; 光学文字認識)とは画像から文字を読み取る技術です。
OCRで生成画像から数字を読み取り、その結果で報酬を決める方法となります。

import pytesseract

def ocr(img, scale: float = 1) -> str:
    img = cv2.resize(img, (int(28 * scale), int(28 * scale)), interpolation=cv2.INTER_LINEAR)
    _, img = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY_INV)
    return str(pytesseract.image_to_string(img, config="--psm 10 -c tessedit_char_whitelist=0123456789")).strip()

# 画像を生成
category = 5
x = generate(policy_net, category)

# OCRの結果で報酬を設定
label = ocr(x)
if label == str(category):
    reward += 1.0
elif label == "":
    reward -= 0.1

教師あり学習

教師あり学習で学習したモデルを元に評価する方法です。
簡単なCNNモデルでMNISTを教師あり学習させ、生成後の画像に対してこのモデルでカテゴリ分類させます。
その結果で報酬を決めます。

# 事前に教師あり学習させる
dataset_imgs, dataset_categories = mnist.load_dataset()
reward_model = CNNModel()
reward_model.fit(dataset_imgs, dataset_categories)

# 画像を生成
category = 5
x = generate(policy_net, category)

# 予測結果と確率を取得
label, percent = reward_model.predict_image(x)
r = 0
if label == category:
    # 予測精度で報酬を変える
    if percent > 0.99:
        r += 0.5
    elif percent > 0.95:
        r += 0.4
    else:
        r += 0.1
else:
    # 予測に失敗したら0
    r += 0

学習サイクル

GRPO

GRPOはDeepSeekに実装されている強化学習アルゴリズムです。
(詳細は以前書いた記事を参照)

GRPOで重要な点がアドバンテージ $\hat{A}$ の計算方法で、今回の場合やLLMなどのエピソード終了時のみに報酬が発生する環境に対して適用できる手法です。

・アドバンテージの計算

\hat{A_i} = \tilde{r_i} = \frac{r_i - \text{mean}(r_1,r_2,...,r_G)}{\text{std}(r_1,r_2,...,r_G)}

各エピソードを $i$ とし、エピソード$G$回の結果を報酬グループとして、報酬を標準化します。
この報酬グループですが、LLMだと1つの質問に対する複数の回答をグループとしていました。
MNISTで見ると各カテゴリ毎に生成した内容でしょうか。

GRPOを踏まえた学習サイクル

GRPOの制約上、ある程度のまとまりを作らないと報酬(正確にはアドバンテージ)が計算できません。
今回は以下のように実装してみました。
(ただGithubのコードでは、学習が大変だったので5のカテゴリのみで学習しています)

buffer = []  # PPOは軌跡を使いまわせる
for epoch in range(エポック数):

    # --- collect trajectory
    # 各カテゴリ毎にデータを生成する(MNIST)
    for category in range(10):

        # 数エピソードまわして軌跡を集める
        trajectory_list = []
        r_group = []
        for episode in range(任意のエピソード回数):

            # 画像生成
            generated_img, trajectory = 方策モデルで画像を生成
            trajectory_list.append(trajectory)  # 軌跡を保存

            # 出来た画像を評価
            reward = generated_imgの報酬を計算
            r_group.append(reward)

        # group reward
        r_mean = np.mean(r_group)
        r_std = np.std(r_group)
        軌跡にアドバンテージを計算して追加

        buffer.append(軌跡をbuffer追加)

    # --- training
    for i in range(任意の学習回数):
        batch = random.sample(buffer, batch_size)  # ランダムにバッチ数取り出す

        with tf.GradientTape() as tape:
            # 今の方策でlogpiを計算
            denoised_img = denoise(policy_net, *state)
            new_logpi = log_likelihood_normal(x=action, mu=denoised_img, sigma=sigma)

            # PPO
            PPOのlossやKL lossentropy loss を計算

        policy_netの勾配を更新

SFT(教師ありファインチューニング)のKL loss

この KL loss は、EDM で事前学習された元のモデルと、新しく学習するモデルが大きく乖離しないようにするためのペナルティ項です。
KL ダイバージェンスは二つの確率分布の差異を測る指標であり、値が0であれば同じ確率分布、値が大きくなるほど異なる分布を表します。
そのため、KL ダイバージェンスが過度に大きくならないように制約をかけることで、報酬モデルの影響を受けすぎないように調整する役割を持ちます。

ただ普通に計算すると分散が大きく安定しないので、GRPOでは分散を減らす工夫がされています。

詳細は以下をどうぞ
LLMチューニングのための強化学習:GRPO(Group Relative Policy Optimization)(どこから見てもメンダコ)

計算式は以下です。
$$
D_{\mathrm{KL}}(\pi_{\theta} \parallel \pi_{ref}) = \frac{\pi_{ref}(o_i|q)}{\pi_{\theta}(o_i|q)} - \log \frac{\pi_{ref}(o_i|q)}{\pi_{\theta}(o_i|q)}-1
$$

entropy loss

この正則化項は、行動の多様性を維持することで、探索の促進と過学習の防止を目的としています。
エントロピーは確率分布のランダム性を測る指標であり、その値が大きいほど行動がランダムになり、小さいほど決定的な行動を取ります。

強化学習では、特定のアクションに偏ってしまうと学習が進みにくくなります。
そのため、PPOでは、適度なランダム性を持たせるためにエントロピー項を追加し、探索を促進します。

今回の確率分布は正規分布であり、ランダム性は分散によって決まります。
しかし、拡散モデルでは分散がタイムスケジュールによって固定されており、調整ができません。

なので、エントロピー項の影響はほとんどないと思うのでこの項の割合を 0 にしています。

結果

・報酬の遷移

train_sft_reward.png

平均報酬(青)がちゃんと増えてますね。
ただまだ上がりそうで学習の余地はありそうです。

・学習前の画像たち

plot_ocr_images_edm.png

画像の上がOCRの結果です。
ほとんど認識できていないですね。

・学習後の画像たち

plot_ocr_images_policy.png

結構予測できるようになりました。
ただ見た目は悪くなっているような…。
報酬設計が重要なのが分かります。

おわりに

今度こそ拡散モデルは終わりなはず…
報酬モデル次第で生成の傾向を変えれそうな気がします。
誰かの参考になれば幸いです。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?