概要
TrainTVの大喜利を見ていて、これは面白い?と思うことがあり、AIを使って大喜利を自動生成できないかと試してみることにしました。今回、大喜利AIの作成には、小規模ながらも高性能なSakanaAIのTinySwallow-1.5Bを利用してみました。
1. 大喜利データセット
そもそも大喜利のデータセットなんてあるのか?と思いながらHuggingfaceのデータセットで検索してみた。「ogiri」で探す。2種類あるみたいです。YANS-officialさんのは画像から一言という内容です。今回利用するのは、watashihakobashiさん(「ワタシはコバシ」さんと読むのかな?)のデータセットとなります。

データセットカードには cc-by-nc-nd-4.0
で「教育、研究目的でのみ使用してください。」と表記されている!興味本位の利用はNGなのだろうか?教育利用かつ非商用なのでOKな気もするが、連絡して確認を取りたい。
とりあえず、データの中身を確認すると、次のようなものになります。
なぜそうなる!?みたいな返答もありますが、大喜利感あるかも。
サンプル
お題目:「iPhone」の意外な使い道とは?
ボケ:マカロニの穴をあける
お題目とボケの部分をLLMに学習させてみたいと思います。教師あり学習の演習として考えてください。
2. SakanaAI/TinySwallow-1.5B について
今回利用するベースモデルはsakana.aiが開発した小規模日本語言語モデルの「TinySwallow-1.5B」です。詳細は、sakana.aiのブログ記事に書かれています。
知識蒸留するときに様々な工夫をこらして作成したモデルのようで、日本語が高性能!開発した方々に感謝 モデルカードの部分を見ると、モデルサイズは1.5B。本当にすごいです。このサイズで高性能とは!ちなみに、ベースモデルはQwen/Qwen2.5-1.5Bみたいです。
sakaia.aiから指示チューニング済みのモデルも用意されています。これを利用すれば高度な対話が可能なはず。学習しなくてもうまく動作しそうなのですが、自身の演習(?)を兼ねているので、instruct前のモデルで作成したいと思います。
3. 指示チューニング
TinySwallow-1.5Bにwatashihakobashi/ogiriデータを学習させていきます。
3.1 主要なライブラリと流れ
Huggingfaceのライブラリー群を利用します。CausalLMの教師あり学習なので基本的に次の4つを使います。他の細かいライブラリはコードを参照してください。
- AutoModelForCausalLM
- AutoTokenizer
- SFTConfig
- SFTTrainer
各ライブラリの概要を簡単にまとめておきます。
ライブラリ名 | 所属パッケージ | 主な用途 | 概要 |
---|---|---|---|
AutoModelForCausalLM | transformers | 言語モデルの読み込みと推論 | 事前学習済みの因果言語モデルを簡単に読み込むためのクラス。自動でモデルクラスを選択します。 |
AutoTokenizer | transformers | テキストのトークン化 | テキストを数値表現(トークン)に変換する処理を担当します。 |
SFTConfig | trl | ファインチューニングの設定 | Supervised Fine-Tuning (SFT) のためのハイパーパラメータや学習設定を管理します。 |
SFTTrainer | trl | モデルの学習 | SFTを実行するためのトレーナークラス。 |
続いて、量子化とLoRAを利用して教師あり学習を実行する予定なので量子化とLoRAに関するライブラリも利用します。
- BitsAndBytesConfig
- LoraConfig
各ライブラリの概要を簡単にまとめておきます。
ライブラリ名 | 所属パッケージ | 主な用途 | 概要 |
---|---|---|---|
BitsAndBytesConfig | transformers | モデルの量子化設定 | モデルのメモリ使用量を削減するための量子化設定を管理するクラス。4ビット、8ビット量子化などの設定が可能で、学習や推論時のメモリ効率を向上させます。 |
LoraConfig | peft | LoRA手法の設定 | Low-Rank Adaptation (LoRA) のハイパーパラメータを管理するクラス。少ないパラメータ数で効率的なファインチューニングを実現するための設定を提供します。主な設定項目にはr (ランク)、alpha (スケーリング係数)、target_modules (適用するレイヤー)などがあります。 |
主要な手順
次のような手順でコードを書くことになります。(データの準備を除いて)
- AutoTokenizerでテキストをトークン化
- AutoModelForCausalLMで事前学習済みモデルを読み込み (ここまでの内容を下記にまとめておきます)
- SFTConfigで学習設定を定義 (3.3で紹介)
- SFTTrainerを使用してモデルをファインチューニング
モデルサイズが1.5Bと小さいので量子化しなくても大丈夫そうな気もするのですが、8bitで試してみました。
- ライブラリの読み込み
- モデルの指定
- トークナイザーの指定
- 量子化してモデル読み込み
適切に選んでもらえるauto設定で行けると思います。
# 1.ライブラリの読込
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer,BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
from trl import SFTTrainer, SFTConfig
# 2.モデルの指定
MAX_SEQ_LENGTH = 200 # 200トークンあれば足りるかな?
BASE_MODEL = "SakanaAI/TinySwallow-1.5B" # tinyswallow使うぞ
# 3.Tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.model_max_length = MAX_SEQ_LENGTH
# 4.8bit量子化してモデルを読み込む
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
device_map="auto",
quantization_config=bnb_config,
torch_dtype="auto",
trust_remote_code=True
)
続いてSFTするのですが、LoRAチューニングと呼ばれる手法を使います。ランクとαの値を適当に設定しておきます。target_modulesが学習対象のアダプターが追加されるネットワーク層となります1。
peft_config = LoraConfig(
r=100, # r=16 くらいだけど、今回は大きめ!!
lora_alpha=200,
target_modules = ['up_proj', 'gate_proj', 'o_proj', 'down_proj', 'v_proj', 'k_proj', 'q_proj'],
lora_dropout=0,
bias="none",
task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, peft_config)
# 勾配チェックポイント
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
SFTConfigやSFTTrainは3.3の学習の部分で紹介します。
3.2 データの整形(地味に重要な部分
)
データの内容になります。load_dataset
でwatashihakobashi/ogiriからデータを取得します。学習用と検証用(テスト用)に適当に分けます。
from datasets import load_dataset
from datasets import DatasetDict
# load_datasetでデータを取り込み
# train_test_splitで学習用と検証用に分割します
dataset = load_dataset("watashihakobashi/ogiri", split="train")
train_test_split = dataset.train_test_split(test_size=0.1, seed=55)
dataset_dict = DatasetDict({
'train': train_test_split['train'],
'test': train_test_split['test']
})
# 検証用にボケを削除したデータを用意
test_data = dataset_dict["test"].map(lambda example: {"boke": ""})
# 各データサイズを表示
print(f"学習用データ: {len(dataset_dict['train'])}")
print(f"検証用データ: {len(test_data)}")
#
#
# 学習用データ: 483
# 検証用データ: 54
読み込んだデータを確認してみると、
訓練用のサンプル
{'odai': '犬が飼い主にイラッ。何があった?',
'boke': 'エサを食べてるとき、横目でじろじろ見てくる'}
という形になります。お題とボケの組ですね。test_dataの方は、bokeの部分を利用しないので削除しました。
SFT用にテキストデータの書式を変更します。TinySwallow-1.5Bは chat_templateが定まっていたので、次の形で学習させると効率的 だと思われます2。
sakana_prompt = """<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
{お題目}<|im_end|>
<|im_start|>assistant
{ボケの内容}<|im_end|>
"""
odaiとbokeをそれぞれ該当する位置に挿入して訓練用のデータを作成します。
お題目の前に「以下の文章は大喜利のお題です。お題にあった面白いボケを作成してください。」という文言を挿入しておきます。
- 変数名
sakana_prompt
がちょっとあれですが、chat_templateに準拠した形へデータを変形 - EOS_TOKENに作文の終わりを示す<|im_end|>を割り当てます。これを文末に追加します。
-
formatting_prompts_func
関数で書式を変更。sakana_promptのカッコ {} 部分にodaiとbokeを挿入していきます。ここはもっと格好良い書き方があると思います。 -
dataset_dict["train"].map()
で実際にデータを再構成。load_datasetで読み込んだdatasetはmapメソッドを持っているのでこれを利用します。3と4でデータを再構成の基本パターンかな
# 1. chat_template
sakana_prompt = """<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
以下の文章は大喜利のお題です。お題にあった面白いボケを作成してください。
{}<|im_end|>
<|im_start|>assistant
{}
"""
# 2. 文末にeos_tokenの<|im_end|>を追加の準備
EOS_TOKEN = tokenizer.eos_token # <|im_end|>を挿入
# 3. 書式を変更する関数
def formatting_prompts_func(examples):
inputs = examples["odai"]
outputs = examples["boke"]
texts = []
for input, output in zip(inputs, outputs):
text = sakana_prompt.format(input, output)
if output:
text += EOS_TOKEN
texts.append(text)
return { "text" : texts, }
# 4. データを再構築
train_data = dataset_dict["train"].map(formatting_prompts_func, batched = True)
# 5. text以外のキーを削除したほうがよいらしい。他のキーがあるとバグるときがあるみたい。
train_data = train_data.remove_columns(["odai", "boke"])
train_dataが設定した形になっているのか確認してみます。print文で中にを確認すると下記の形になります。
print(train_data[0]["text"])
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
以下の文章は大喜利のお題です。お題にあった面白いボケを作成してください。
犬が飼い主にイラッ。何があった?<|im_end|>
<|im_start|>assistant
エサを食べてるとき、横目でじろじろ見てくる
<|im_end|>
3.3 学習
いよいよ学習の部分となります。SFTConfigとSFTTrainerの2つを設定して、trainer.train()でOKです。重要なのはdataset_text_field="text"
の部分。train_dataのどの部分を学習するのか明示的に書いておく必要があるようです。あとの数値は学習結果を見ながら適当に決めました。細かいオプションの解説は他の記事や書籍に紹介されているのでここでは省略してしまいます。SFTConfigに学習の主要な設定を記述して、SFTTrainerには、学習対象のモデル、データセット名、設定の3つを指定します。
training_args = SFTConfig(
dataset_text_field="text",
num_train_epochs=1,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
gradient_checkpointing=True,
learning_rate = 2e-5,
optim="adamw_torch_fused",
lr_scheduler_type="cosine",
output_dir="./outputs",
max_seq_length=512,
fp16=True, # fp16
logging_steps=2,
# save_steps=100,
save_only_model=True,
auto_find_batch_size=True
)
trainer = SFTTrainer(
model,
train_dataset=train_data,
args=training_args,
)
model.config.use_cache = False
trainer.train()
これで学習が開始&終了します。データ数が少ないので結構あっという間ですよ
損失が1.5くらいだとあまり面白くない文章が生成されます。
3.4 生成
最後にテスト用のデータを使って検証します。面白いボケになるのか!?出力サンプルを最初に紹介します。
出力例
お題8の「電車」一瞬つまらないぞ〜と思いますが、これは2050年だから?これで面白いのかも
生成コード
徐々に文字が表示されるタイプが個人的に好きなのでTextStreamerを使います。リード文(msg)を追加したお題を指示文として与えます。あとはmodel.generate()
で文を生成するだけです。生成する部分を関数にしておくときれいに見えると思いますが割愛
model.config.use_cache = True
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
msg = "以下の文章は大喜利のお題です。お題にあった面白いボケを作成してください。\n"
model.eval()
for num, odai in enumerate(test_data["odai"]):
user_message = msg+odai
prompt = [{"role": "user", "content": user_message}]
input_ids = tokenizer.apply_chat_template(prompt, add_generation_prompt=True, return_tensors="pt").to(model.device)
print(f"お題{num+1}: {odai}")
with torch.no_grad():
result = model.generate(
input_ids,
max_new_tokens=100,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=False,
streamer=streamer
)
del input_ids
torch.cuda.empty_cache()
print("-"*50)
apply_chat_templateされたinput_idsがどのような形になっているのか中身を確認してみました。"猿が日向ぼっこ.気持ちは?"をお題にして、input_idsの中身を確認すると...
msg = "以下の文章は大喜利のお題です。お題にあった面白いボケを作成してください。\n"
user_message = msg+"猿が日向ぼっこ.気持ちは?"
prompt = [{"role": "user", "content": user_message}]
input_ids = tokenizer.apply_chat_template(prompt, add_generation_prompt=True, return_tensors="pt").to(model.device)
# テンソルの場合、最初に.tolist()を使用
decoded_text = tokenizer.decode(input_ids[0].tolist())
print(decoded_text)
decode_textの値はちゃんと指定した書式になっています。というか指定した書式になっていない場合はやり直しですが
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
以下の文章は大喜利のお題です。お題にあった面白いボケを作成してください。
猿が日向ぼっこ.気持ちは?<|im_end|>
<|im_start|>assistant
model.generate()
のオプションにeos_token_id=tokenizer.eos_token_id
を指定することで、max_new_tokens=100
トークン数以内でも<|im_end|>
で終了となります。
4. 比較
TinySwallow-1.5B-Instructと今回学習したモデルの出力結果を比較してみました。指示文は同じ形式にしています。TinySwallow-1.5B-Instructは「以下の文章は大喜利のお題です。お題にあった面白いボケを作成してください。」で学習されていない分、不利なきもします。TinySwallow-1.5B-Instructは会話形式でボケを連発しているのがユニークmarkdownの形式で出力されています。**や#がついています。100トークンまでの制限で出力しているので文章が途切れています。
番号 | お題 | TinySwallow-1.5B+ogiriデータで学習 | TinySwallow-1.5B-Instruct |
---|---|---|---|
1番 | 「マグカップ」がこんなシーンで役に立つ!? | マグカップの底に穴を開けて、水を飲む | ## マグカップの意外な活用法\n状況: 友子と健太郎はカフェでコーヒーを飲んでいた。\nボケ:\n*こう、マグカップって、実は結構便利なんだよね!\n例えば、友達からもらったマグカップがあったとして、\n「あー、これ誰かからの贈り物だね」\nって思ってたんだけど、\n「でも、このマグカップ、ちょっと…」\nってなる前に、すぐに |
2番 | 「この漫画、単行本の表紙がかっこいい。」どんな表紙なの? | 全部同じ絵柄で描いてるから | ## ボケ案:\n1.\n> この漫画、単行本の表紙がかっこいい! でも、ちょっとだけ気になるのが… 中身もかっこいいのかな?\n2.\n> この漫画、単行本の表紙がかっこいい! でも、その裏側には… 何があるんだろう?\n3.**\n> この漫画、単行本の表紙がかっこいい! でも、その中に |
3番 | 飲み会での主役が参加者にイラッ。何があった? | 「この料理、まずいなぁ…」 | ## 飲み会での主役が参加者にイラッ。\n状況: 主役のAさんが、いつも通りの豪快な料理を前に「これ、最高だぜ!」と叫び始めました。しかし、隣に座っていたBさんは、「ちょっと待ってください」と冷静に言いました。「いや、この味付け、もう少し塩辛い方がいいんじゃないですか?」\nボケ:\n*ー! Aさん、また自分の |
4番 | イベントスタッフが参加者にイラッ。何があった? | 「このクイズ、簡単すぎない?」 | ## イベントスタッフの「イラッ」シーン\n状況: あるイベントで、スタッフが参加者の指示を待っている時。\nスタッフA: 「よし、準備完了!」\n参加者B: 「えっと…ちょっと待ってください。あの、〇〇さんと話したいんですけど…」\nスタッフA: (イライラ) **「あー、またか…。〇〇さんですか?今すぐ来てください |
5番 | 「この靴下、ダメだなぁ…」なぜ? | こんなに色が濃いと履くのが恥ずかしい | ## 大喜利のネタ:靴下\nお題: この靴下、ダメだなぁ…\nボケ:\n* 「この靴下、ダメだなぁ…」\n *「だって、履くと足が蒸れるんだもん!」\n *「いや、それよりも、色が薄いから目立たないのが問題なんだよ…」\n ***「うーん、確かにそうかもね。でも、一番 |
6番 | スーパーの試食コーナーで怒られました。なぜ? | 「この味、ちょっとまずいな…」 | ## スーパーの試食コーナーでのボケ\nA: 「このサラダ、ちょっと酸っぱいかな…」\nB: 「えっ!? そんなこと言わないでよ! 私たちはお客様のためにここにいるんだから!」\nA: 「あ、すみません、失礼しました。でも、確かに少し酸味があるかもしれませんね。でも、これは新鮮な野菜をたっぷり使っているので |
7番 | アイスクリームの新しいフレーバーの名前は? | 「ミルク」 | ## アイスクリーム新フレーバー:「夢幻の果実」\nボケ: 「最近、夢の中で食べたような不思議な味のアイスクリームが流行ってるんです!\nツッコミ: 「それは『夢幻の果実』って言うらしいですよ!」 |
8番 | 2050年の交通手段で一番多く使われるものは何? | 電車 | ## 2020年交通手段ランキング\n1位:自動運転バス\n - 理由: 都市部の渋滞緩和と環境負荷低減に貢献。\n2位:ドローンタクシー\n - 理由: 短時間での移動や高齢者・障害者の利用を可能にする。\n3位:電気自動車(EV)\n - 理由: |
大喜利のデータだけで指示チューニング行ったのですが、他の質問だとどうなるのか?という汎化性能も確認してみました。また別の記事で結果を紹介したいと思います
注
-
target_modulesってどうやって見つけるの?については、npakaさんの記事を参考にしてください。 ↩
-
他のパターンでも試してみましたがなかなか学習が進まないのと、出力結果が芳しくないようです。テンプレートに従うのが良さそうでした。 ↩