8
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

QAタスクでRinna 3.6BをファインチューニングしてChatGPTの性能を目指す

Last updated at Posted at 2023-08-06

はじめに

 2023年も様々な大規模言語モデル(LLM)が公開されています。ビッグテックによる多言語対応モデルだけでなく、2023年の5月にはサイバーエージェント社とrinna社がそれぞれ日本語特化のLLMを公開しました。これらの日本語特化モデルはローカルにモデルをダウンロードして自由に動かすことができます。そのため、オンプレ環境でモデルを自由にカスタマイズしたり、APIに投げたくないデータを扱った学習などを行うことができます。

 ただ、これらの日本語特化モデルを実際に動かした方たちの記事を拝見した感じでは、現状では文章生成の性能はOpenAIのモデル(ChatGPT, GPT-4等)に対してかなり劣っているような印象を受けました。そのため、これらの日本語特化モデルを使用する際、場合によってはファインチューニングが必要になるかと思います。しかし、個人でこれらのモデルをファインチューニングしてChatGPT等と同等性能のモデルを構築するのはけっこう厳しいものがあると思います。ただ、様々なタスクでzero-shotやfew-shotでChatGPT等と同じレベルのことをさせるのは難しいかもしれませんが、特定のタスクに限定してファインチューニングさせることで、そのタスクにおいてはChatGPT等に近い性能を実現できるのでは?と思いました。

 そのため、この記事では、rinna社のjapanese-gpt-neox-3.6b-instruction-ppoモデルでQAタスクのファインチューニングを行い、ChatGPT等のモデルの性能にどの程度近づけることができたかを書きたいと思います。

使用データ

 データにはJGLUEのJSQuADを使用しました。JGLUE(Japanese General Language Understanding Evaluation)は、自然言語処理の有名なベンチマークのGLUEの日本語版であり、Yahoo JAPAN社の研究所と早稲田大学 河原研究室が構築・公開して下さったものです。今回はその中のQAタスクであるJSQuADを使用させていただきました。

 JSQuADは文章を読み、その文章に対する質問に回答するというタスクのデータセットです。JSQuADの学習データは例えば以下のようになっています(shunk031様がHuggingfaceのdatasetsにJGLUEを登録して下さっているため、今回はそちらを利用させていただいています)。

from datasets import load_dataset

dataset = load_dataset("shunk031/JGLUE", name="JSQuAD")

data = dataset['train'][1000]
print('文章:', data['context'])
print('質問:', data['question'])
print('解答:', data['answers']['text'][0])
文章: 高度 [SEP] 人体は、呼吸や心拍数を速め、血液組成を変化させて高度に順応することができる。高度への順応には、数日から数週間を要する。しかし、8,000mを超えると、人体は適応できず、死に至ることもある。
質問: 高度への順応は何mを超えると、人体は適応できなくなるか?
解答: 8,000m

 Trainセットとして62,859件、Devセットとして4,442件が提供されています。ファインチューニングではTrainセットを使用し、評価時にDevセットを使用しました。

各種モデルのスコア

 既にいろいろな方や組織がJGLUEを使用した言語モデルの評価を実施・公開して下さっています。まず、JGLUEのデータセットと併せてBERT系のモデルのスコアが公開されています。また、Stability-AI社のリポジトリでは日本語特化モデル含めた様々なGPT系モデルの性能が公開されており、TIS社の技術ブログでもChatGPTなどの評価結果を公開して下さっています。

 各サイトから評価結果を一部抜粋させていただき、今回使用するJSQuADのスコア(Exact Match)をグラフにまとめました。なお、それぞれの結果は評価時の条件(プロンプトの内容や使用した評価データのサンプルサイズなど)が異なっているため、各モデルの性能差を正確に表したものではないことに注意して下さい。
JSQuAD性能比較.png
<下記リンク先の情報をもとに作成>
[1] JGLUE: 日本語言語理解ベンチマーク
[2] JP Language Model Evaluation Harness
[3] ChatGPT vs BERT:どちらが日本語をより理解できるのか?

 [1]のBERT系のモデルはTrainセットでファインチューニングしています。一方、[2], [3]のGPT系のモデルはファインチューニングをしておらず、few-shotでのスコアとなっています。この結果を見る限り、ChatGPTはfew-shotで高いスコアを出していますが、ファインチューニング済みのBERT系モデルの精度にはあと少し届いていません。また、[2]のようなローカルで動かせるGPT系のモデルたちとChatGPTの間には現状ではそこそこ大きなスコアの差があるようです。

 そして、今回はfew-shotで約半数を正解させているRinna社のjapanese-gpt-neox-3.6b-instruction-ppoに対してファインチューニングを実施し、ChatGPTのスコアを目指します。

プロンプトの作成

 ファインチューニングに向け、JSQuADのデータをプロンプトに変換します。今回使用するモデルはRLHF(人間のフィードバックからの強化学習)をさせており、プロンプトは"ユーザー"と"システム"による対話の形式を取ります。ファインチューニングすることを前提として、以下のようにシンプルな文言を使用したプロンプトに変換しています。

def prompt_template(data, dataset_type):
    context = data['context']
    question = data['question']
    if dataset_type == 'train':
        answer = data['answers']['text'][0]
    elif dataset_type == 'test':
        answer = ''
    start_idx = context.find(' [SEP] ') + 7
    result = f"ユーザー: 質問に対する回答を文章から抽出してください。<NL>文章:{context[start_idx:]}<NL>質問:{question}<NL>システム: {answer}"
    return result

prompt = prompt_template(dataset['train'][1000], 'train')
print(prompt.replace('<NL>', '\n'))
ユーザー: 質問に対する回答を文章から抽出してください。
文章:人体は、呼吸や心拍数を速め、血液組成を変化させて高度に順応することができる。高度への順応には、数日から数週間を要する。しかし、8,000mを超えると、人体は適応できず、死に至ることもある。
質問:高度への順応は何mを超えると、人体は適応できなくなるか?
システム: 8,000m

 JSQuADは文章の頭にその文章のタイトルが入っているためその部分を削除し([SEP]の前部分)、文章と質問の間などには改行を表す<NL>を挿入しています。そして、学習データには"システム"の発話として正解の文字列をセットしています(評価データでは与えない)。また、ファインチューニングを前提としているため、今回は文章/質問/解答のサンプルを与えず、zero-shot形式のプロンプトにしました。

LoRAファインチューニング

 LLMのファインチューニングでは、効率的に一部のパラメータを更新させて下流タスクに適応させるPEFT(Parameter-Efficient Fine-Tuning)というアプローチがよく取られます。今回もPEFTのひとつであるLoRA(Low-Rank Adaptation)でモデルを学習させます。LoRAは事前学習済みモデルに対する低ランク行列を用いて効率的に学習を行います。もとの事前学習済みモデルのパラメータは更新せずに、追加したパラメータのみを学習させることで、学習コストを大幅に抑えることができます(下の図のオレンジ色のパラメータを更新する)。
LoRA.png
引用:LoRA: Low-Rank Adaptation of Large Language Models のFigure 1

 実装については、Rinna社のjapanese-gpt-neox-3.6bモデルでLoRAファインチューニングを試されているnpaka様のコードを参考にさせていただきました。本来であれば使用するモデルやデータに合わせて各種設定値を適切に決める必要がありますが、今回は低ランク行列の次元数などの設定値含めコードをほぼそのまま使用させていただきました。

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

model = AutoModelForCausalLM.from_pretrained(
    'rinna/japanese-gpt-neox-3.6b-instruction-ppo',
    load_in_8bit=True,
    device_map='auto'
)

lora_config = LoraConfig(
    # Lora attention dimension
    r=8,
    # Lora scaling
    lora_alpha=16,
    # The names of the modules to apply Lora to
    target_modules=["query_key_value"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

 こちらのLoRAモデルをJSQuADのTrainセット(62,859件)で2エポック学習させました。Google Colab上でTesla A100 GPUで学習させたのですが、8時間ほどかかりました。

結果

 ファインチューニングさせたモデルで評価を行ったところ、Exact Matchで正解率は 約64.7% となりました。先程の各種モデルのスコアでも取り上げた通りStability-AI社の評価ではfew-shotで 51.8% だったため、ファインチューニングすることでfew-shotと比較して13%程度高いスコアを出せたことになります(今回定義したプロンプトでファインチューニング前のモデルを評価した際は正解率が約13.2%だったため、そこと比較すると大幅に精度が向上している)。

 ということで、残念ながら今回はChatGPTの81.8%というスコアには全然届きませんでした。原因は色々考えられますが、まずファインチューニングの内容が不十分ということがあげられると思います。今回はとりあえずLoRAを試してみたといった感じで、各種設定値や学習方法も今回のモデルやデータセットに合っていなかった可能性があります。もっと試行錯誤をすることで、もう少しスコアを上げることはできたかなと思います(LoRAといえどそこそこ学習コストがかかってしまい、あまり試行錯誤する余裕がなかったというのが実情です...)。

 また、実際にRinna 3.6Bをファインチューニングした素直な感想として、今回のようなタスクであればBERT系のモデルがやっぱりコスパが良さそうだなと感じました。しかし、まだまだ精度を上げる余地はあると思いますので、もう少し挑戦してみようかなと思っています。

おまけ

 QAタスクのファインチューニング前後で、通常の文章生成にどのような変化が起きるのかを確認してみました。以下のプロンプトで生成された文章を比較します。

prompt = "ユーザー: 自然言語処理について教えて下さい。<NL>システム:"

ファインチューニング前

もちろんです。自然言語処理とは、人間の話し言葉や書き言葉を理解し、処理する技術のことです。自然言語処理は、人工知能研究の中心的な分野の1つであり、様々な研究が行われています。

ファインチューニング後

自然言語処理とは、コンピュータが人間とやり取りする際に使用する言葉を理解し、その意図を推論する技術である。この技術は、様々な文脈で人間の会話の理解やコミュニケーションに役立っている。

 ファインチューニング前の方が自然な文章に感じますが、JSQuADに引っ張られておかしな内容になるという事態はそこまで起きてない?ように見えます。

参考

https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-ppo
https://github.com/yahoojapan/JGLUE
https://www.anlp.jp/proceedings/annual_meeting/2022/pdf_dir/E8-4.pdf
https://huggingface.co/datasets/shunk031/JGLUE
https://github.com/Stability-AI/lm-evaluation-harness/tree/jp-stable)
https://fintan.jp/page/9126/
https://arxiv.org/abs/2106.09685
https://note.com/npaka/n/nc387b639e50e

8
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?