LoginSignup
50
47

日本語LLMをPPOでファインチューニングする

Last updated at Posted at 2023-12-28

TL;DR

  • 3.6Bパラメータの日本語LLMに対し全パラメータをSupervised Fine Tuning (SFT)をした
  • さらにLoRAを使用してProximal Policy Optimization (PPO)を行った
  • 精度を定量評価できるようなタスクでSFT, PPOを行い、PPOにより確かに精度が向上することを確かめた
  • 学習はすべてGoogle ColabのA100 GPU1枚を用いて行った

はじめに

GPT-3.5などのLLMの学習は以下の3段階で行われています。

  1. Pre-traininig: 大規模なコーパスを用いた言語モデルの事前学習
  2. Supervised Fine Tuning (SFT): 対話形式や指示・応答形式のデータセットを用いたファインチューニング
  3. Policy Optimization: 人間にとって好ましい応答をさせるためのファインチューニング(ポリシー最適化)

2番目の段階のSFTでは入力文に対する一つの返答が”正解”として与えられています。事前学習とLLMの汎化能力が合わさることで、SFTだけでも学習データに無いような対話を行うことができるようになりますが、この段階ではLLMが応答をするうえでの”価値基準”や”指針”のようなものが無いため、不適切な発言など人間にとって好ましくない応答を示すことがあります。
それを修正するためにLLMに”価値基準”や”指針”を教えるのが第3段階のポリシー最適化です。

LLMのRLHF(Reinforcement Learning with Human Feedback)学習では人間にとって好ましいかどうかが”価値基準”として使われますが、必ずしも人間による評価である必要もありませんし、LLM以外の生成モデルに応用することもできます。1

ポリシー最適化は通常の教師あり学習のように決まった正解があるわけではなく、モデルが生成した結果を何らかの方法で評価して、それをモデルにフィードバックして学習させる必要があるので、強化学習を使って行われます。特にOpenAIが開発したProximal Policy Optimization (PPO)という強化学習のアルゴリズムが用いられるのが主流です。PPOを行うためには、まずLLMに価値基準を教えるための報酬モデルを作成したうえで、強化学習によりLLMを学習させる必要があります。このような手続きは煩雑かつ、計算コストも高く、さらに学習が不安定になりやすく、大規模なモデルに対して実行するのは敷居の高い手法でした。
実際、記事執筆時点で公開されている日本語LLMのほとんどは事前学習のみ、あるいはSFTまでしか行っていないものです。2

海外で公開された英語中心のLLMにおいてもポリシー最適化まで行ったモデルは少ないという状況だったと思いますが、ここ数か月の間にzephyr-7b-beta, tulu-v2-dpoなどポリシー最適化まで行われたLLMが次々公開されているという印象があります。それらのモデルの学習にはPPOではなくDirect Policy Optimazation (DPO)という手法が用いられているのが特徴です。DPOは最適化する目的関数はPPOと同一でありながら、データから直接的に価値基準をLLMに学習させる手法で、事前に報酬モデルを作成する必要がなく、強化学習を行う必要もないという画期的な手法です。DPOの発展手法も続々登場しており、それらはtrlライブラリに実装されています。

さて、当初はDPOを日本語LLMに対して試してみるつもりだったのですが、DPOのありがたさを理解するためには先にPPOを体験した方が良かろうということで、この記事ではPPOを使った学習の実験結果を紹介します。
PPOを使ってみたという技術記事の多くはtrlの公式リポジトリにある感情分析モデルを報酬モデルにしたGPT-2の学習例のノートブックに基づくものが多く、パラメータ数が1B以上のある程度大きなモデルを扱っているものや対話式データでSFTしたモデルをPPOするというケースはあまり見かけませんでした。また、PPOをやってみたけどうまくいかなかったという結論だったり、やってみたけど効果はよくわからなかったという結論の記事も多い印象です。
そこで、この記事ではLINEが公開した3.6BパラメータのLLM japanese-large-lm-3.6bを日本語instruction tuning用データセットでSFTしてからPPOでチューニングすることとし、PPOの効果を定量的に確認しやすいトイプロブレムを扱うことにします。

環境

実験はすべてGoogle Colab (Pro)上で行いました。
SFTおよびPPOの計算はA100 GPU (Syetem RAM 83.5GB, GPU RAM 40.0GB)を、他の処理は適宜ノートブックを分けてCPUインスタンスかT4 GPUインスタンスを使用して行いました。

ライブラリのインストール

!pip install transformers datasets trl accelerate bitsandbytes optimum peft sentencepiece
!pip install deepspeed ninja
!pip install fugashi unidic_lite  # cl-tohoku/bert-base-japanese-v3モデルのために必要
!pip install wandb  # PPOTrainerのログに使用

Pythonおよび主要ライブラリのバージョン

Python 3.10.12
accelerate                       0.25.0
bitsandbytes                     0.41.3.post2
datasets                         2.16.0
deepspeed                        0.12.5
ninja                            1.11.1.1
optimum                          1.16.1
peft                             0.7.1
torch                            2.1.0+cu121
transformers                     4.35.2
trl                              0.7.6

Google ColabでA100 GPUがなかなか割り当てられず、実験を進めるのに時間がかかったため、一部の計算はこれより古いバージョンのライブラリで行われている可能性があります。

以下では明示しませんが、各ノートブックでgoogle driveをマウントしているものとします。

(2024/3/12追記)
deepspeedのバージョンを明記していなかったので追記しました。
原因が特定できていないのですが、deepspeed==0.13.4を使うとSFTモデルの精度が下がってしまう現象を観測しています。

タスク説明とデータセット準備

SFTおよびPPOでの学習データとしては基本的に日本語instruction tuning用データセットのdatabricks-dolly-15k-jaを使用します。
LLMが指示に従う能力を定量的に測るために、プロンプトの中に語尾を指定するタグを埋め込んで、タグに応じた語尾で返答するようにモデルを学習させることとします。
そのために、dolly-15k-jaデータセットの語尾を「ござる」に変えた、通称ござるデータセットdatabricks-dolly-15k-ja-gozaruも合わせて使用します。

データセット読み込み

from datasets import load_dataset

dataset_0 = load_dataset("kunishou/databricks-dolly-15k-ja", split="train")
print(len(dataset_0))  # -> 15015
dataset_1 = load_dataset("bbz662bbz/databricks-dolly-15k-ja-gozaru", split="train")
print(len(dataset_1))  # -> 15015

これらのデータセットは、指示文である'instruction'と補助的な情報の'input'、そして回答の'output'が収められたものです。'input'が存在しないデータも存在します。

PPOの計算コストを抑えるために、データセットのテキスト長でフィルターをかけることにします。以下のコードで、instructionとinputの合計文字数およびoutputの文字数がどちらも10以上200以下のデータのみにdataset_0を限定します。厳密に系列長を揃える意図はないので、トークン化前の日本語の文字数でカウントしています。dataset_1に対しても同じ行をフィルターしています。

min_length = 10
max_length = 200

def measure_length(sample):
    if "input" in sample.keys():
        sample["len_in"] = len(sample["instruction"]) + len(sample["input"])
    else:
        sample["len_in"] = len(sample["instruction"])
    sample["len_out"] = len(sample["output"])
    return sample


dataset_0 = dataset_0.map(measure_length)
keep_index = [i for i, sample in enumerate(dataset_0) if min_length <= sample["len_in"] <= max_length and min_length <= sample["len_out"] <= max_length]
len(keep_index)  # -> 7263

dataset_0 = dataset_0.select(keep_index).remove_columns(["len_in", "len_out"])
dataset_1 = dataset_1.select(keep_index)

データセットに語尾を指定するためのタグを追加します。せっかくLLMを自分でファインチューニングするので、少し変わった指定の仕方をしてみましょう。
オリジナルのdolly-15k-jaデータセットの通常の語尾は絵文字の😐で指定することにします。

dataset_0 = dataset_0.map(lambda x: {"tag": "😐"})

さらに、タスクを難しくするためにござるデータセットの一部の語尾を「ござる」から確率0.3で「おっぱっぴー」、確率0.2で「ござるよドゥフフw」に置き換えて、それぞれ絵文字の🫡、🤪、🤓で指定します。

import random

def mutate_gozaru(sample):
    r = random.random()
    if r < 0.5:
        # 「ござる」そのまま
        sample["tag"] = "🫡"
    elif r < 0.8:
        sample["output"] = sample["output"].replace("ござる", "おっぱっぴー")
        sample["tag"] = "🤪"
    else:
        sample["output"] = sample["output"].replace("ござる", "ござるよドゥフフw")
        sample["tag"] = "🤓"
    return sample

dataset_1 = dataset_1.map(mutate_gozaru, batched=False)

こんな語尾を選んだ理由は、「ござるよドゥフフw」に関しては「ござる」と直交しないものがあった方がタスクが難しくなりそうだったからです。
「おっぱっぴー」に関しては自分でも意味が分かりませんが、PPOとおっぱっぴーの類似度が高いせいかもしれません。

ござるを指定する絵文字として初めは忍者の絵文字を使おうとしていたのですが、環境によっては表示されないことに気づいたのと黄色い顔文字で揃えた方が収まりがよいかと思って、あまりござるっぽくはないですが🫡としました。

dataset_1のoutputの一部を表示すると以下のようになっています。

'ラクダは、長時間にわたってエネルギーと水分で満たされた状態を保つために、腰の脂肪を利用していますでござるよドゥフフw。',
'ジョン・デイヴィソン・ロックフェラーでござる。',
'常にスマホを操作しているのが良くないためでござる。',
'多角形は、幾何学における形の一つです。多角形は閉折れ線あるいは閉曲線を成す、線分の閉じた有限鎖で囲まれた平面図形である。多角形の頂点は、2つの辺が交わるところに形成されますでござるよドゥフフw。多角形の例としては、六角形、五角形、八角形がありますでござるよドゥフフw。辺や頂点を含まない平面は、多角形ではありません。多角形でないものの例として、円がありますでござるよドゥフフw。',
'シーズン4の第4話「誓約を果たすもの」、第5話「新王誕生」をそれぞれ監督したでおっぱっぴー。',
'カーテンのタイバックを作るために使用できるアイテムはたくさんあります。チェーン、サークル、ピン、ジュートロープ、木製ビーズロープ、ネックラック、ブレスレット、ドアノブ、革ベルト、マクラメロープ、造花などがそのアイデアでござる。',
'グドークは弦楽器、カンタローは打楽器でござるよドゥフフw。',
'Titleist、Taylormade、Callaway、Ping、Cobraでござるよドゥフフw。',
'マイケル・フェルプス選手は歴代最多の23個の金メダルを獲得していますでござるよドゥフフw。',
'太陽が沈むと、夜が始まるでござる。'

これらのデータセットをtrain用とtest用に分けます。

indices = list(range(len(dataset_0)))
random.shuffle(indices)

test_size = 200
test_indices = indices[:test_size]
train_indices = indices[test_size:]

# dataset_0とdataset_1を同じインデックスで分割
train_dataset_0 = dataset_0.select(train_indices)
test_dataset_0 = dataset_0.select(test_indices)
train_dataset_1 = dataset_1.select(train_indices)  # test_dataset_1は不要

trainデータはdataset_0とdataset_1を結合してシャッフルしておきます。

from datasets import concatenate_datasets

# 結合
train_dataset = concatenate_datasets([train_dataset_0, train_dataset_1])
# シャッフル
indices = list(range(len(train_dataset)))
random.shuffle(indices)
train_dataset = train_dataset.select(indices)

trainデータのタグの分布は以下のようになりました。

import pandas as pd

pd.Series(train_dataset["tag"]).value_counts(dropna=False)
# 😐    7063
# 🫡    3532
# 🤪    2187
# 🤓    1344

いったん保存しておきます。

train_dataset.save_to_disk("train_dataset.hf")
test_dataset_0.save_to_disk("test_dataset_0.hf")

次に、SFT用にデータ処理を行います。

トークナイザ準備

import torch
from transformers import AutoTokenizer

model_name = "line-corporation/japanese-large-lm-3.6b"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

データセットをプロンプトのテンプレートの形に整える関数を準備して、train_datasetに適用し、それをトークナイザでトークン化します。

def format_prompt(sample):
    instruction = f"### Instruction\n{sample['instruction']}"
    input = f"### Context\n{sample['input']}" if len(sample["input"]) > 0 else None
    tag = f"### Tag\n{sample['tag']}"
    output = f"### Answer\n{sample['output']}"
    # join all the parts together
    prompt = "\n\n".join([i for i in [instruction, input, tag, output] if i is not None])
    return prompt

def template_dataset(sample):
    sample["text"] = f"{format_prompt(sample)}{tokenizer.eos_token}"
    return sample

dataset = train_dataset.map(template_dataset, remove_columns=list(train_dataset.features))
tokenized_dataset = dataset.map(
    lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)
)

さらに、SFTを効率よく行うために特定の系列長までデータを詰め込むpackingを行います。コードは以前書いた記事で使ったのと同じものです。

長いので折り畳んでおきます。
# empty list to save remainder from batches to use in next batch
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}

def chunk(sample, chunk_length=2048):
    # define global remainder variable to save remainder from batches to use in next batch
    global remainder
    # Concatenate all texts and add remainder from previous batch
    concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
    concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
    # get total number of tokens for batch
    batch_total_length = len(concatenated_examples[list(sample.keys())[0]])

    # get max number of chunks for batch
    if batch_total_length >= chunk_length:
        batch_chunk_length = (batch_total_length // chunk_length) * chunk_length

    # Split by chunks of max_len.
    result = {
        k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
        for k, t in concatenated_examples.items()
    }
    # add remainder to global variable for next batch
    remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
    # prepare labels
    result["labels"] = result["input_ids"].copy()
    return result

# tokenize and chunk dataset
packed_dataset = tokenized_dataset.map(
    partial(chunk, chunk_length=2048),
    batched=True,
)

# Print total number of samples
print(f"Total number of samples: {len(packed_dataset)}")  # -> Total number of samples: 613

# Save
packed_dataset.save_to_disk("train_dataset_2048packed_line_tokenized.hf")

注意として、このように事前にpacking処理を行うと学習時に全てのエポックで固定された組み合わせでパックされたデータでモデルが学習されることとなります。
こちらの記事で紹介されているように、trlライブラリのSFTTrainerにpackingの処理は実装されており、そちらを使えばエポックごとに組み合わせが変わります。しかし、使ってみたところ学習でロスは減少するものの生成結果のテキストがめちゃくちゃになるという事象が発生したので、SFTTrainerは使わずに事前にpacking処理を行うことにしました。(私のSFTTrainerの使い方が何か間違っていたのだと思います。)今回SFTは2エポックしか学習しないので、組み合わせが固定される悪影響は少ないと考えます。

(2024/1/21追記)
こちらの記事に書いたのですが、transformersライブラリに組み込まれたFlash Attention 2を使用すれば、packingを行わなくても効率よく学習を行うことができます。

SFT

まずは、事前学習済みのモデルをSFTで学習することから始めます。
モデルはLINEのjapanese-large-lm-3.6bを使用します。

FlashAttentionとDeep Speedを活用することで、RAMが40GBのGPU1枚で3.6BパラメータのLLMをフルファインチューニングすることが可能になります。手順は以前の記事に書いた通りですが、こちらでもコードは一通り記載しておきます。

初めに、ノートブック上でDeep Speedを使うための環境変数の設定をします。

import os

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "9994"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

Deep Speedを使うためにはjsonの設定ファイルをTrainerに渡す必要があるので、作成します。

%%writefile zero_train.json
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e9,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e9,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}

モデルとトークナイザーを読み込みます。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "line-corporation/japanese-large-lm-3.6b"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    torch_dtype=torch.float16,
    )

# FlashAttentionを使うための処理
model.to_bettertransformer()

(2024/1/21追記)
こちらの記事に書いた通り、bettertransformerの代わりにfrom_pretrained()の引数にattn_implementation="flash_attention_2"を与えることでより簡便にFlash Attention 2を使用することができます。

作成した学習用データセットの読み込み

from datasets import load_from_disk

dataset = load_from_disk("train_dataset_2048packed_line_tokenized.hf")

言語モデルの学習のためのcollator

from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

Trainerの設定と学習実行

training_args = TrainingArguments(
    output_dir="/content/tmp/model",
    overwrite_output_dir=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=32,
    learning_rate=5e-5,
    weight_decay=0.01,
    num_train_epochs=2,
    logging_steps=2,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=5,
    save_strategy="no",
    fp16=True,
    deepspeed="./zero_train.json",
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)

# 学習実行
trainer.train()
trainer.save_state()

エポック数をはじめとしたハイパーパラメータの設定は適当なものです。
データ数が少ないので、2エポックの学習が16分ほどで完了しました。

モデルを保存する前にbettertransformer形式からもとに戻す必要があります。

from optimum.bettertransformer import BetterTransformer

model = BetterTransformer.reverse(model)

また、モデル保存時にOOMが起きて学習が台無しになったことがあるので、キャッシュをクリアしてからモデルを保存します。

import gc
import deepspeed

del dataset
del trainer

gc.collect()
deepspeed.runtime.utils.empty_cache()
torch.cuda.empty_cache()

# モデル保存
# マウントしたdriveに直接保存すると時間がかかるので、一旦ローカルディスクに保存
model.save_pretrained("/content/tmp/model", safe_serialization=False)
# driveにコピー
!mkdir ./line_sft
!cp -r /content/tmp/model/* ./line_sft

transformersのバージョン4.35.0からデフォルトのモデル保存形式が従来のpytorch_model.binからsafetensor形式に変更されているのですが、今回のモデルはsafetensor形式で保存して再読み込みをするとモデルが壊れて生成されるテキストがめちゃくちゃになるという事象が発生しました。
そのため、save_pretrained()safe_serialization=Falseという引数を与えて従来のpytorch_model.binを保存するようにしています。

さて、学習したモデルに適当な入力を与えて推論させてみましょう。

from transformers import pipeline

# パイプライン作成
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

# 生成の設定
generator_params = dict(
    max_length = 64,
    do_sample = True,
    temperature = 0.7,
    top_p = 0.9,
    top_k = 0,
    repetition_penalty = 1.1,
    num_beams = 1,
    pad_token_id = tokenizer.pad_token_id,
)

一つの入力文に対して4種類のタグで生成させます。

input_text = "四国の県名を全て列挙してください。"

for tag in ["😐", "🫡", "🤓", "🤪"]:
    prompt = f"### Instruction\n{input_text}\n\n### Tag\n{tag}\n\n### Answer\n"
    output = generator(
        prompt,
        **generator_params,
    )
    print(output[0]["generated_text"])
    print("------------------------")

生成結果はそれぞれ次のようになりました。(### Instructionの部分は省略しています。)

### Tag
😐

### Answer

1.香川県、2.愛媛県、3.徳島県、4.高知県でござる。

### Instruction
次の各項目を
### Tag
🫡

### Answer

徳島県、香川県、愛媛県でござるよドゥフフw。
以上でござるよドゥフフw。

### Context
四国は
### Tag
🤓

### Answer

四国は、徳島県、愛媛県、香川県、高知県でござるよドゥフフw。徳島県は四国地方の南端に位置し、愛媛と香川は本州
### Tag
🤪

### Answer

四国は、徳島県、香川県、愛媛県、高知県の4つの国で構成されていますでおっぱっぴー。このうち、徳島と高知は県として承認

😐と🫡に関しては期待した語尾にはなっていません。また、Answerが終わった後でEOSトークンを出力せずに、Contextなどの他のフィールドを出力してしまっている場合もあります。

他にも、「ござるよドゥフフw」にちなんで例のコピペの質問をしてみます。(このコピペを知っている人が現代にどれだけいるのかわかりませんが・・)

結果は折り畳んでおきます。
input_text = "ふーん、で、君は涼宮ハルヒのキャラで誰が好きなの?"

for tag in ["😐", "🫡", "🤓", "🤪"]:
    prompt = f"### Instruction\n{input_text}\n\n### Tag\n{tag}\n\n### Answer\n"
    output = generator(
        prompt,
        **generator_params,
    )
    print(output[0]["generated_text"])
    print("------------------------")

生成結果

### Instruction
ふーん、で、君は涼宮ハルヒのキャラで誰が好きなの?

### Tag
😐

### Answer

長門有希は私に似ている。私は彼女のように完璧ではありませんが、彼女が共有するものを共有しています:私の心
### Instruction
ふーん、で、君は涼宮ハルヒのキャラで誰が好きなの?

### Tag
🫡

### Answer

SOS団で最も人気が高いのは長門有希でござる。彼女は宇宙人であるにもかかわらず、人間のように振る舞うことで
### Instruction
ふーん、で、君は涼宮ハルヒのキャラで誰が好きなの?

### Tag
🤓

### Answer

長門有希はツンデレ属性を持っているでござるよドゥフフw。キョンへの愛にあふれ、
### Instruction
ふーん、で、君は涼宮ハルヒのキャラで誰が好きなの?

### Tag
🤪

### Answer

私はキョンが好きでござるよドゥフフw。彼は完璧な存在ですでおっぱっぴー。さらに、

🤓が期待通りの振る舞いをしてくれてうれしいです。

結果確認①

SFTで学習したモデルを別ノートブックにて再読み込みして、指示にどれだけ従っているか定量的に評価します。

モデル読み込み

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained(
    "line-corporation/japanese-large-lm-3.6b",
    use_fast=False,
    padding_side="left",  # バッチで推論するために必要
    )
tokenizer.pad_token = tokenizer.eos_token
model_dir = "./line_sft"  # 学習済みモデルを保存したディレクトリ
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.float16,
    device_map="cuda",
    )
model.eval()

推論をバッチ単位で行いたいので、トークナイザを立ち上げる時にpadding_side="left"を指定しています。
長さの異なるデータをバッチでモデルに入力するためにはパディングを行って長さを揃える必要がありますが、テキスト生成のタスクでは入力文の後ろにパディングトークンを挿入すると、続きを生成することができなくなってしまいます。そのため、入力文の左側、つまり先頭にパディングトークンを挿入してやる必要があるわけです。
テキスト生成では推論をバッチ化しても高速化されるとは限らないという記述をどこかで見かけたことがありますが、今回の実験ではバッチサイズ1~16の範囲で推論時間はバッチサイズにほぼ反比例して短くなりました。

DataCollatorもパディングを行うためのものを使用します。

from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

テスト用に保存しておいたデータセットを読み込んで、推論のための設定をします。

from datasets import load_from_disk

dataset = load_from_disk("./test_dataset_0.hf")

# 入力プロンプト作成のための関数
def format_prompt_for_inference(sample, tag):
    instruction = f"### Instruction\n{sample['instruction']}"
    input = f"### Context\n{sample['input']}" if len(sample["input"]) > 0 else None
    tag = f"### Tag\n{tag}"
    # join all the parts together
    prompt = "\n\n".join([i for i in [instruction, input, tag] if i is not None])
    prompt += "\n\n### Answer\n"
    return {"text": prompt}

generator_params = dict(
    max_new_tokens = 64,
    do_sample = True,
    temperature = 0.7,
    top_p = 0.9,
    top_k = 0,
    repetition_penalty = 1.1,
    num_beams = 1,
    pad_token_id = tokenizer.pad_token_id,
)

以下のコードで、4種類のタグに対してテストデータ200件から入力プロンプトを作成して推論を行います。
ループ内でDataLoaderを作成して、バッチで推論を行っています。
また、生成のランダム性により定量評価の結果がどれだけぶれるかを確認するために、全体の計算を5回繰り返しています。

from tqdm import tqdm
from torch.utils.data import DataLoader

batch_size = 16
n_repeat = 5
tag_list = ["😐", "🫡", "🤓", "🤪"]

results_list = []
for _ in range(n_repeat):
    results = {}
    for tag in tag_list:
        print(f"Start {tag}")
        test_dataset = dataset.map(lambda sample: format_prompt_for_inference(sample, tag))
        test_dataset = test_dataset.map(
        lambda sample: tokenizer(sample["text"], add_special_tokens=False), batched=True, remove_columns=list(test_dataset.features)
        )
        dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=data_collator)
        output_texts = []
        for batch in tqdm(dataloader):
        batch = {k: v.to(model.device) for k, v in batch.items()}
        with torch.no_grad():
            output_ids = model.generate(
            **batch,
            **generator_params,
            )
            output_texts += tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        results[tag] = output_texts
    results_list.append(results)

T4 GPUを使って合計4000件の推論に20分ほどかかりました。

上記コードで、model.generate()にはbatch["input_ids"]だけでなくbatch["attention_mask"]も渡されています。パディングを行っていますので、attention_maskを渡すことは必須です。
SFTを行う際にデータをpackingして関係ないテキストの後にEOSトークンを挟んでプロンプトを入力していましたので、プロンプトの前にPADトークン(=EOSトークン)が並んでいても関係なく生成を行ってくれると期待したのですが、attention_maskが無いと生成結果がめちゃくちゃになってしまいました。

全ての生成結果はresults_listに収められていますので、指示通りの語尾になっているかを評価していきます。指示通りの語尾になっているかどうかは、以下の関数でルールベースで判断することができます。

def classifier(text):
    have_gozaru = "ござる" in text
    have_oppappy = "おっぱっぴー" in text
    have_dufufu = "ござるよドゥフフw" in text
    if not have_gozaru and not have_oppappy and not have_dufufu:
        return "😐"
    elif have_gozaru and not have_oppappy and not have_dufufu:
        return "🫡"
    elif have_gozaru and not have_oppappy and have_dufufu:
        return "🤓"
    elif not have_gozaru and have_oppappy and not have_dufufu:
        return "🤪"
    else:
        return None

実際の生成結果を見ると複数の語尾を混ぜて使っている文も多いですが、そのような文に対してはNoneを返す仕様となっています。
この関数を適用して、5回の繰り返しの生成結果それぞれに対して全体正解率と混同行列を計算します。

import collections
import pandas as pd

df_list = []
acc_list = []
for results in results_list:
    count_dict = {}
    for tag, texts in results.items():
        classified_list = [classifier(t) for t in texts]
        counts = [classified_list.count(t) for t in tag_list]
        count_dict[tag] = counts + [len(classified_list) -sum(counts)]
    df = pd.DataFrame(count_dict).T
    df.columns = pd.MultiIndex.from_arrays([["Prediction"] * (len(tag_list) + 1), tag_list + ["others"]])
    df.index = pd.MultiIndex.from_arrays([["Truth"] * len(tag_list), tag_list])
    df_list.append(df)
    acc = sum([df.iloc[i, i] for i in range(len(tag_list))]) / df.sum().sum()
    acc_list.append(acc)

5回のうち1回の混同行列は以下のようになりました。Othersとなっているのは複数の語尾が混じっている件数です。
conf_mat_SFT.png

やはり、🫡と🤓は間違いやすいようです。また、🤪を指定すると他の語尾と混じりやすい傾向もあるようです。他の4回も同様の傾向でした。

全体正解率の5回の平均と標準偏差は

mean: 0.666, std: 0.006

でした。
エポック数などSFTのハイパラを変えればもっと良い精度が得られる可能性はありますが、今回の目的はPPOを行うことなので、このまま先に進むこととします。
この精度はLLMが指示に従う能力を定量的に表しています。以下ではPPOによって、指示追従能力を向上させられるかを検証します。

報酬モデル学習

PPOを行うためにはLLMの生成結果を評価する報酬モデルが必要となります。
報酬モデルは必ずしも深層学習モデルである必要もなく、今の問題設定では上で定義したルールベースの分類関数classifier()を用いて報酬を決めても問題ないはずですが、通常のPPOの手続きを体験するために、深層学習モデルベースの報酬モデルを学習させることにします。
作るべきモデルは、語尾を指示するタグとテキストを入力として受け取って、テキストの語尾が指示通りになっていれば高いスコアを返すモデルです。

報酬モデルを学習対象のLLMと同じアーキテクチャで作成するという場合もありますが、それだとモデルが大きくなりすぎますし、今の簡単な問題にはオーバースペックですので、日本語BERTをベースにすることとします。

学習データセットとして最初に用意しておいたものを使います。(中身はSFTに用いたものと同じです。)

from datasets import load_from_disk

dataset = load_from_disk("train_dataset.hf")

BERTとしては東北大が公開しているbert-base-japanese-v3を用いることとし、そのトークナイザを読み込みます。

from transformers import AutoTokenizer

model_name = "cl-tohoku/bert-base-japanese-v3"
tokenizer = AutoTokenizer.from_pretrained(model_name)

4種類の語尾を分類するモデリングの方法についてですが、SFTモデルの間違えた推論結果を見ると複数種類の語尾を混ぜてしまっているケースが散見されましたので、その場合も検出できるようにマルチラベル分類を用います。
各語尾に対して、下記のように3次元のベクトルを正解ラベルとして用意します。

tag2label = {
    "😐": [0.0, 0.0, 0.0],
    "🫡": [1.0, 0.0, 0.0],
    "🤪": [0.0, 1.0, 0.0],
    "🤓": [1.0, 0.0, 1.0],
    }  # intではなくfloatである必要がある

def add_labels(sample):
    sample["labels"] = tag2label[sample["tag"]]
    return sample

dataset = dataset.map(add_labels)

3つの成分の意味は、第0成分から順に、「ござる」を含む、「おっぱっぴー」を含む、「ドゥフフw」を含むです。
従って、😐は全成分が0、🤓は第0と第2が1になっています。
このやり方は正確ではなく、「ござる」と「ドゥフフw」を両方含めば「ござるよドゥフフw」ではなくても🤓に分類されてしまいますが、そこには目をつぶることとします。

報酬モデルに入力するテキストはoutput部分のみですので、outputをトークン化します。

tokenized_dataset = dataset.map(
    lambda sample: tokenizer(sample["output"]),
    batched=True,
    remove_columns=['instruction', 'category', 'index', 'output', 'input', 'tag'],
)

# 最初に長いテキストは除いているため実際には存在しないが、念のためBERTの最大系列長以上のデータを削除
tokenized_dataset = tokenized_dataset.filter(lambda sample: len(sample["input_ids"]) < 512)

データをtrain/val/testに分割します。

import random

indices = list(range(len(tokenized_dataset)))
random.shuffle(indices)

test_size = 1000
val_size = 3000

test_indices = indices[:test_size]
val_indices = indices[test_size: test_size+val_size]
train_indices = indices[test_size+val_size:]

test_dataset = tokenized_dataset.select(test_indices)
val_dataset = tokenized_dataset.select(val_indices)
train_dataset = tokenized_dataset.select(train_indices)

モデルは引数にproblem_type="multi_label_classification"を渡してAutoModelForSequenceClassificationクラスとして読み込めば、マルチラベル分類用のヘッドと対応するロスの計算機能の付いたモデルが得られます。ラベルの数もnum_labels引数に渡してやります。

import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3, problem_type="multi_label_classification")
model.cuda()

学習には必須ではないですが、モデルの出力を理解する目的も兼ねてTrainerに渡すcompute_metrics関数を用意しておきましょう。
モデルの出力であるロジットは3次元で、それをシグモイド関数に通したものが3つのラベルの予測値となります。その値が閾値(今は0.5とします)を超えたら、そのラベルがTrueであると予測されたと解釈し、正解の0/1ラベルと予測のFalse/Trueが一致すれば正解とします。最終的に全体正解率のaccuracyを返します。

import torch
from torch.nn.functional import sigmoid
import numpy as np

def compute_metrics(result):
    labels = result.label_ids
    pred_prob = sigmoid(torch.tensor(result.predictions)).numpy()
    pos_pred = pred_prob > 0.5
    neg_pred = 1 - pos_pred
    acc = np.mean(pos_pred * labels + neg_pred * (1 - labels))
    return {
        "accuracy": acc,
    }

最後にcollatorを準備して、Trainerクラスを用いて学習を行います。

from transformers import DataCollatorWithPadding, Trainer, TrainingArguments


data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 学習用パラメーター
training_args = TrainingArguments(
    output_dir = "./reward_model",
    evaluation_strategy = "steps",
    logging_strategy = "steps",
    save_strategy = "steps",
    eval_steps = 10,
    logging_steps = 10,
    save_steps = 10,
    save_total_limit = 1,
    load_best_model_at_end = True,
    num_train_epochs = 1,
    learning_rate = 5e-5,
    per_device_train_batch_size = 16,
    per_device_eval_batch_size = 16,
    gradient_accumulation_steps = 8,
    weight_decay = 0.01,
    fp16=True,
)

# Trainerの初期化
trainer = Trainer(
    model = model,
    args = training_args,
    data_collator = data_collator,
    compute_metrics = compute_metrics,
    train_dataset = train_dataset,
    eval_dataset = val_dataset,
)

# 学習実行
trainer.train()

# 保存
trainer.save_state()
trainer.save_model()

タスクが非常に簡単なので、1エポックのみの学習としましたが、最初の10ステップの時点でaccuracyが0.97となっており、最終的なモデルのテストデータに対する精度を計算してみると1.0となりました。

results = trainer.predict(test_dataset, ignore_keys=['loss', 'last_hidden_state', 'hidden_states', 'attentions'])
print(compute_metrics(results))
# -> {'accuracy': 1.0}

BERTのレイヤーを減らすなどしても良かったかもしれません。

最終的にはタグを予測してほしいので、3次元テンソルのlogitsをタグと結びつける関数を用意して、適当な文章に対する推論結果を見てみましょう。
以下のpredict_labels()関数の中身は先に定義した分類関数のclassifier()とほとんど同じです。複数語尾が混じっている場合はothersを出力します。

def predict_labels(logits):
    pred_prob = sigmoid(logits).numpy()
    preds = pred_prob > 0.5
    pred_tags = []
    for p in preds:
        if not p[0] and not p[1] and not p[2]:
            pred_tags.append('😐')
        elif p[0] and not p[1] and not p[2]:
            pred_tags.append('🫡')
        elif not p[0] and p[1] and not p[2]:
            pred_tags.append('🤪')
        elif p[0] and not p[1] and p[2]:
            pred_tags.append('🤓')
        else:
            pred_tags.append('others')
    return pred_tags

texts = [
    "今日はいい天気ですね",
    "拙者るろうにでござる",
    "はい!おっぱっぴー!チントンシャンテントーン",
    "拙者オタクでござるよドゥフフw",
    "ドゥフフでござるwおっぱっぴー",
    ]
input_ids = tokenizer(texts, padding=True, return_tensors="pt")
input_ids = input_ids.to("cuda")
outputs = model(**input_ids)
logits = outputs.logits.detach().cpu()
print(predict_labels(logits))
# -> ['😐', '🫡', '🤪', '🤓', '🫡']

最後の複数語尾が混じっているテキストに対してはothersを返してほしかったのですが、🫡になってしまっています。これは、学習時には複数語尾が混じっているテキストを学習させていないことが原因だと思われます。しかし、予測の値を見てみると[0.80240077,0.27347738,0.01596765]となっており、「おっぱっぴー」に対応する第1成分の値も比較的高くなっています。(「ドゥフフ」は含まれていますが、「ドゥフフw」は含まれていないので第2成分の値が小さいのは正しいです。)
複数語尾を混ぜたデータを用意して学習させなおすことも考えましたが、報酬の値には0/1のラベルではなく予測値を使用するので上記の場合でもある程度は語尾の間違いを検出できているということと、現実的な応用例では報酬モデルを精度よく学習させることが困難な場合も多いという理由で、このモデルを使ってPPOを行うこととします。

PPO

ここまで長い道のりでしたが、ようやくPPOを行うところまでたどり着きました。
SFTにより、指示された語尾で返答するように学習したモデルに対して、さらにPPOを行うことで指示に従う能力を向上させることを目指します。
本来はPPOは人間の好みのような曖昧なデータ上の傾向をモデルに学習させる場合に有効な手法で、今回のようなルールベースでも判定できるようなタスクをLLMに教えるために使うのは、ある意味技術の無駄遣いとも言えますが、その分PPOの効果を定量的にわかりやすく測定することができます。

PPOの学習における最適化問題は以下のような式で表されます。

\underset{\pi_\theta}{\text{max}}\big\{ \mathbb{E}_{x\sim \mathcal{D},\, y\sim \pi_\theta (y|x)} \left[ r(x,y)\right] 
-\beta \mathbb{D}_\text{KL} \left[ \pi_\theta (y|x) || \pi_\text{ref} (y|x)\right]\big\}

$\mathcal{D}$ が入力プロンプトのデータセットで$x$がそこからサンプルした入力プロンプト、$\pi_\theta (y|x)$が最適化対象のLLMであり、$x$が与えられたときの出力$y$の条件付確率分布とも見なせます。$r(x,y)$は報酬で、上式の第1項は報酬を最大化するようにLLM $\pi_\theta (y|x)$を最適化(学習)するという意味になります。
報酬最大化だけを行おうとすると、報酬関数に抜け道のようなものが存在する場合、報酬は高いけれど文章としては支離滅裂なものを出力するように学習が進んでしまうことがあります。それを抑制するのが第2項のKLダイバージェンスの役割です。$\pi_\text{ref} (y|x)$は学習対象のLLMとは別に用意する参照モデルで、学習対象のLLMの出力の確率分布が参照モデルの確率分布から離れすぎると、KLダイバージェンスによりペナルティが与えられます。
学習対象のモデルの初期状態はSFTで学習されたモデルを用い、参照モデルとしても同じSFTモデルを用いることが多いと思います。従って、PPOの最初の段階ではKLダイバージェンスは0ですが、報酬を上げるためにモデルのパラメータが更新されると、2つの確率分布がずれるのでKLダイバージェンスが大きくなります。第2項には負の符号が付いていますので、第1項の報酬を上げようとすれば、第2項が小さくなるので、両者のバランスをうまくとりながら全体が大きくなるように上手く最適化する必要があるというわけです。この説明からもPPOの学習の難しさがうかがい知れると思います。

PPOの学習の流れは、以下のようになります。

Step 1. データセットのクエリを学習対象のモデルに入力し、続きを生成させる。
Step 2. 入力クエリと生成された出力を報酬モデルに入力し、生成結果に対する報酬を計算する。
Step 3. 報酬と出力結果を使って、報酬が大きくなる方向にモデルのパラメータを更新する。

これらのプロセスはtrlライブラリのPPOTrainerクラスを使えば簡単に実行することができます。(とは言え、tranformersライブラリの通常のTrainerクラスに比べると複雑です。)

はじめにPPOConfigを設定しておきます。

from trl import PPOConfig

config = PPOConfig(
    model_name="line-corporation/japanese-large-lm-3.6b",
    learning_rate=5.e-5,
    batch_size=64,
    mini_batch_size=4,
    gradient_accumulation_steps=16,
    ppo_epochs=4,  # これはデフォルトのまま
    init_kl_coef=0.1,
    log_with="wandb", # 使わないならNone
)

batch_sizeの設定は適当ですが、mini_batch_sizeとgradient_accumulation_stepsはA100 GPU1枚で学習できるように調整しています。ちなみにmini_batch_sizeとgradient_accumulation_stepsの積がbatch_sizeに一致しないとエラーになります。
learning_rateとinit_kl_coefは数回の予備実験の結果を受けて調整していますが、計算コストの高さ(とA100 GPUがなかなか使えないこと)が原因で系統的なチューニングはできていません。
また、ppo_epochsは通常の意味のdataloader1周分のエポックではなく、PPOのステップ内部で使われるパラメータです。

log_with引数はデフォルト値はNoneですが、"wandb"を与えるとweights&biasesで学習のメトリックを確認できるようになります。
PPOはメトリックが多いのでwandbを使うのが便利です。wandbを使う場合は以下のコードで認証を通しておきます。

import wandb

wandb.init()

Step 1で使用するための入力データを準備します。
データセットは再び最初に用意したtrain_datasetを使用しますが、PPOは計算コストが高いので件数を3000件まで絞りました。

from datasets import load_from_disk

dataset = load_from_disk("train_dataset.hf")

dataset = dataset.select(list(range(3000)))

LLMに入力するクエリを成形します。この時、key名はqueryとしておく必要があります。

def format_query(sample):
    instruction = f"### Instruction\n{sample['instruction']}"
    input = f"### Context\n{sample['input']}" if len(sample["input"]) > 0 else None
    tag = f"### Tag\n{sample['tag']}"
    prompt = "\n\n".join([i for i in [instruction, input, tag] if i is not None])
    prompt += "\n\n### Answer\n"
    sample["query"] = prompt
    return sample

dataset = dataset.map(format_query, remove_columns=list(dataset.features))

次に、トークナイザでクエリをトークン化します。

from transformers import AutoTokenizer

model_name = "line-corporation/japanese-large-lm-3.6b"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

dataset = dataset.map(
    lambda sample: tokenizer(sample["query"], add_special_tokens=False), batched=True
)
print(dataset)
# Dataset({
#     features: ['query', 'input_ids', 'attention_mask'],
#     num_rows: 3000
# })

input_idsだけでなくqueryも必要になるので、削除せず残しておきます。

コレーターはシンプルなものを使います。

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

ここで、学習済みの報酬モデルを読み込みます。報酬モデルは推論を行うだけですので、メモリ節約のためload_in_4bit=Trueをつけて読み込みます。
付随してクエリの中のタグの絵文字からマルチラベル分類の正解ラベルを作成する関数も定義します。

import torch
import torch.nn as nn
from torch.nn.functional import sigmoid
from transformers import AutoTokenizer, AutoModelForSequenceClassification


def query2tag_vec(query: str) -> list:
    if "🫡" in query:
        return [1, 0, 0]
    elif "🤓" in query:
        return  [1, 0, 1]
    elif "🤪"in query:
        return [0, 1, 0]
    else:
        return [0, 0, 0]


class RewardModel(nn.Module):
  def __init__(self, model_path, model_name = "cl-tohoku/bert-base-japanese-v3"):
    super().__init__()
    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
    self.model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        num_labels=3,
        problem_type="multi_label_classification",
        load_in_4bit=True,
        )

  def forward(self, input_ids, attention_mask, tag_v):
    outputs = self.model(input_ids, attention_mask)
    pred_prob = sigmoid(outputs.logits)
    score = torch.mean(pred_prob * tag_v + (1 - pred_prob) * (1 - tag_v), dim=1)
    rewards = 3 * score - 2
    return rewards

reward_model = RewardModel(model_path="./reward_model").eval()

報酬の計算方法についてですが、モデルが出力する3クラス分のロジットをシグモイドに通して予測値に変換し、3クラス分の正解ラベルに対応した予測値の平均値をscoreとしています。scoreは0~1の間の値を取り、タグに従った語尾なら1に近い値をとるような量です。
報酬は正負の値をとるものなので、scoreを線形変換したものをrewardsとします。初めは報酬が-1~1の範囲の値を取るようにrewards=2*score-1と線形変換していたのですが、学習させるモデルがすでにSFT済みで、ある程度目的のタスクを解くことが出来ているモデルなので、報酬が初めから高くて学習が進みにくくなってしまいました。
そこで、悪い報酬を強調するためにrewards=3*score-2という変換を採用しました。こうすると、rewardsは-2~1の範囲の値を取ります。

(追記・修正)
上記の「報酬は正負の値をとるもの」というのは間違いで、scoreからrewardsへの線形変換には実質的な意味はありませんでした。
PPOの最適化問題には解析的な厳密解が存在することが知られており、

\pi (y|x) = \frac{1}{Z(x)} \pi_\text{ref} (y|x) \exp \left( \frac{1}{\beta} r(x,y)\right)

という形になります。$Z(x)=\sum_y \pi_\text{ref} (y|x) \exp \left( \frac{1}{\beta} r(x,y)\right)$は規格化のための分配関数です。3 この表式から、報酬$r$に定数を加えても結果は変わらないことがわかります。(わざわざ厳密解の表式を持ち出さなくとも、定数を加えても最適化関数の微分には影響しないことから自明ですが・・)また、報酬を定数倍すると、結果は変わりますが、係数$\beta$をリスケールすれば吸収できるので、こちらも実質的な意味はありませんでした。
(追記終わり)

次に、PPOで学習させるLLMを読み込みます。
SFTはフルファインチューニングを行いましたが、PPOでは参照モデルとして同じモデルをもう一つメモリに載せる必要がありますので、40GB RAMのGPU1枚で3.6Bパラメータのモデルをフルファインチューニングするのは困難です。そこでLoRAを使うことにします。
PPOTrainerに渡すモデルはtransformersのAutoModelForCausalLMクラスではなく、trlのAutoModelForCausalLMWithValueHeadクラスである必要があります。
AutoModelForCausalLMWithValueHeadクラスはAutoModelForCausalLMクラスと同様に使うことができます。
参照モデルも同様にAutoModelForCausalLMWithValueHeadクラスとして立ち上げておきます。参照モデルはKLダイバージェンス項を計算するために推論で使われるだけなので、LoRAの設定は必要ありません。参照モデルの重みは学習対象モデルの初期重みと同じくSFT後のモデルの重みを使用しています。
PPOTrainerでは参照モデルを明示的に渡さない場合は、内部で学習対象モデルの情報を使って自動で作成してくれるようですが、自前のSFT済みモデルの重みを使用する場合やLoRAを使う場合の挙動を把握できていないので、明示的に渡すことにしました。

from trl import AutoModelForCausalLMWithValueHead
from peft import LoraConfig

sft_model_path = "./line_sft"

target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
lora_config = LoraConfig(
    target_modules=target_modules,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# 学習対象モデル
model = AutoModelForCausalLMWithValueHead.from_pretrained(sft_model_path, peft_config=lora_config)

# 参照モデル
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(sft_model_path)

モデルをcudaに移す操作などはPPOTrainerがよしなに行ってくれるのでここでは必要ありません。

ここまで準備したものを渡してPPOTrainerを立ち上げます。さらに、PPO中の推論(テキスト生成)の設定も用意しておきます。

from trl import PPOTrainer

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

# 推論の設定
gen_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "max_new_tokens": 64,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    }

以上で準備が整ったので、いよいよPPOの学習を実行します。
PPOTrainerでは、学習のループは自分で書く必要があります。
ppo_trainer.dataloaderは通常のdataloaderと同様のものなので、それを使って学習ループを実行します。以下のコードでは、通常の意味でのエポック数は1としています。
複数エポック学習させたいときは、dataloaderのループの外側にもう一つエポックのループを書く必要があります。4

import time

for step, batch in enumerate(ppo_trainer.dataloader):
    t0 = time.time()
    query_tensors = [torch.tensor(l).to("cuda") for l in batch["input_ids"]]

    #### Generate responses
    response_tensors = []
    for qts in query_tensors:
        response = ppo_trainer.generate(qts, return_prompt=False, **gen_kwargs)
        response_tensors.append(response.squeeze())
    batch["response"] =  [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute sentiment score
    tokenized_texts = reward_model.tokenizer(batch["response"], padding=True, return_tensors="pt")
    reward_input_ids = tokenized_texts["input_ids"].to("cuda")
    reward_attention_mask = tokenized_texts["attention_mask"].to("cuda")
    tag_v = torch.tensor([query2tag_vec(text) for text in batch["query"]]).to("cuda")
    rewards = reward_model(reward_input_ids, reward_attention_mask, tag_v)
    rewards = [torch.tensor(r) for r in rewards.clone().detach()]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
    t1 = time.time()
    print(f"[Step {step}] reward_mean: {stats['ppo/mean_scores']}, kl: {stats['objective/kl']}, entropy: {stats['objective/entropy']}, total loss: {stats['ppo/loss/total'][0]}, time: {(t1 - t0)/60:.1f} min.")

ppo_trainer.dataloaderが返すbatchはdatasetをcollatorに通したもので、サイズはPPOConfigで設定したbatch_sizeに等しいです。

ループの内部では、まずbatchに収められたinput_idsをモデルに渡して続きのテキストを生成させています。response = ppo_trainer.generate(qts, return_prompt=False, **gen_kwargs)の部分です。今回のタスクでは入力プロンプト部分は報酬モデルに渡す必要はないので、return_prompt=Falseを指定しています。それをデコードして通常のテキストに変換したうえで、batchのresponseというフィールドに格納します。responseというキー名はPPOTrainerでは固定です。

実際に実行してみるとわかるのですが、PPOの実行時間の大半を占めるのはテキスト生成の部分です。生成に時間がかかる理由の一つはバッチ処理を行わずにfor qts in query_tensorsループの部分で1件1件直列で処理を行っていることです。結果確認①の項で行ったのと同様に生成をバッチで行おうとしたのですが、ppo_trainer.generate()はattention_maskを引数に持たないので、異なる系列長の入力をバッチ処理することはできないようです。

次に、生成結果のresponseを報酬モデルに渡して報酬を計算します。いまは報酬モデルのtokenizerと学習対象のLLMのtokenizerが異なっていますので、混同しないよう注意が必要です。報酬モデルは自前で用意したものですので、パディングを行ってバッチで処理することが出来ています。

最後にppo_trainer.step()を実行して、PPOによりLLMのパラメータの更新を行います。

今回の学習設定では1ステップの学習に4.6~4.9分ほどかかり、1エポック(45ステップ)の学習に210分ほどかかりました。データ件数を3000件まで絞ってLoRAも使っているのに、これだけかかっていますので、やはりPPOの計算はコストが高いことがわかりました。

学習後のモデルは通常のtransformersのモデルと同様に

model.save_pretrained("./line_ppo")

で保存することができます。modelはtrl.AutoModelForCausalLMWithValueHeadクラスとして立ち上げたものですが、transformers.AutoModelForCausalLMクラスとして、

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("./line_ppo")

のように読み込むことができます。
SFTの後で保存するときにはsafe_serialization=False引数を与えないとモデルが壊れてしまったのですが、今回はpeftを使っていてLoRA重みだけが保存されるからか、safetensorsとして保存してもモデルが壊れることはありませんでした。

wandbの画面で確認できる報酬平均とKL項の平均値の推移は下図のようになりました。傾向としては報酬はじわじわ増加しているように見えます。
ただし、負の値を強調する設計にしたものの、学習開始時から0.6程度と大きな値になっており、増加幅はかなり小さいです。
KL項については、PPOではKL項が極端に大きくなって学習に失敗するということがあると聞いたことがありますが、今回はそのようなことは起こりませんでした。SFTの時点である程度出来ているタスクを解かせていることが影響しているのだと思います。

結果確認②

まず、SFTの時と同じ入力での生成結果を見てみましょう。

四国の県名を全て列挙してください。
### Tag
😐

### Answer

愛媛県、香川県、徳島県です。これらの県は、本州と四国の間に位置し、瀬戸内海に面しています。また、高知県や岡山県などの他の地域
### Tag
🫡

### Answer

四国は、愛媛県、高知県、徳島県で構成されていますでござる。それぞれ、四国の4つの州に相当しますでござる。さらに、香川県、徳島県、愛媛県
### Tag
🤓

### Answer

四国は4つの国に分かれていますでござるよドゥフフw。四国には、1059の島がありますでござるよドゥフフw
### Tag
🤪

### Answer

四国は4つの国で構成され、その総面積は約78万平方キロメートル(32.1万平方マイル)でおっぱっぴー。
ふーん、で、君は涼宮ハルヒのキャラで誰が好きなの?
### Tag
😐

### Answer

ハルヒは「長門有希」として生まれたので、彼女しかいません。また、キョンは彼女の
### Tag
🫡

### Answer

長門有希、キョン、ハルヒでござる。これらのキャラクターは、独特のユーモアセンスとストーリーテ
### Tag
🤓

### Answer

私はキョンが好きですでござるよドゥフフw。彼は、自分の信念や情熱に忠実であり、感情
### Tag
🤪

### Answer

長門有希ちゃんが好きなおっぱっぴー。彼女はとても知的でおっぱっぴー。クールでおっぱ

ここに貼った8件の結果はすべて指定通りの語尾となっています。
少数の例だけでは何とも言えませんので、結果確認①の項で行ったのと同じ定量評価を行いました。5回の評価のうち1回の混同行列は下のようになりました。
conf_mat_PPO.png

結果の全体正解率の5回の平均と標準偏差は

mean: 0.839, std: 0.009

となりました。SFT後のモデルではmean: 0.666, std: 0.006でしたので、有意に指示追従能力が向上しています。

ただし、指示追従能力が向上しても言語モデルとしての一般的な性能が下がっている可能性はありますので、例えばこちらの記事で行われているように、最適化を行ったモデルを多角的なベンチマークで評価することが重要だと思います。今回はそこまで行うことはできていません。

まとめ

簡便な定量評価が可能な問題設定において、3.6Bパラメータの日本語LLMに対してSFTとPPOを行い、PPOによって指示追従能力を向上させることができるということを確認しました。

もともとの目的はDPOの前にPPOを試してみようということだったので、次は同じ問題設定でDPOと派生手法を試してみようと思います。

  1. 例えば、arXiv:2311.12908arXiv:2311.13231では画像生成の拡散モデルがポリシー最適化(DPO)で学習されていますし、arXiv:2302.08242ではコンピュータビジョンの色々なモデルを強化学習で最適化するということが行われています。

  2. rinnaのjapanese-gpt-neox-3.6b-instruction-ppoのような例外もあります。

  3. この式は形式的な解であり、分配関数を実際に計算するのは困難なので、実用的には使えません。ただし、この式はDPOの導出では重要な役割を果たしています。

  4. trlの公式リポジトリにあるノートブック例ではepoch, batch in enumerate(ppo_trainer.dataloader)と書いてあって紛らわしいですが、通常の学習に照らし合わせればdataloaderのループはstepと呼ぶのが適切だと思います。

50
47
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
50
47