2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Llama2をFinetuneして友達をクローンした

Last updated at Posted at 2024-06-28

※この記事は2023年12月に執筆されたものです。

目的

特定の友人とのLINEトーク履歴を教師データとしてLLMをファインチューニングし、友人のクローンBotを作成してみたい。
今回は、長年の付き合いで他愛のないラインのやりとりがコンスタントに続いている友人(以下friend氏としている)とのトーク履歴を使用する。
friend氏っぽさをLlama2が身につけることができたら成功としよう。
※ 本人にデータ使用の許可は得ています。

まずは結果

とりあえずは成功といえるレベルに友人(そしてfriend氏とチャットしている自分)が再現できた。
十分なトーク履歴さえあれば元カレ元カノ亡きあの人、誰でもチャットで蘇らせて延々とおしゃべりできるようになる気がする。

作業記録

Finetune前に、まずはfew-shot性能を見てみる

今回finetuneを試みる日本語llama2モデルelyza/ELYZA-japanese-Llama-2-7b-instructのfew-shot性能を確認。

from transformers import AutoTokenizer, AutoConfig
import transformers
import torch

model_name= "elyza/ELYZA-japanese-Llama-2-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
config.pretraining_tp = 1

model=transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # Passing device_map = 0 means putting the whole model on GPU 0. Other inputs could be cpu, cuda:1, etc. Setting device_map = auto will let accelerate compute the most optimized device_map automatically.
    torch_dtype=torch.float16)
if torch.cuda.is_available():
    model = model.to("cuda")

#推論
model.eval()

text = """
'me\t修論飽きておじゃる丸とか見てた\n'
"""
prompt = f"""
### Chat sequence:
'friend\t明日どこがいい?\n',
'me\t一応古着屋いくって話だったよね\n',
'friend\t[スタンプ]\n',
'me\t原宿でよい?\n',
'friend\tおっけー\n',
'me\tなにさがすの?\n',

### Response:
'friend\tcoatとかズボン\n',
'friend\tシャツでもいいのがあれば\n',
<end>

### Chat sequence:
'friend\tおきた\n',
'me\t[スタンプ]\n',
'me\tくさい\n',

### Response:
'friend\t臭い?\n',
'friend\t\n\n']
<end>


### Chat sequence:
{text}
### Response:

"""
inputs = tokenizer.encode(prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
    inputs.to(model.device),
    temperature=0.6,
    max_new_tokens=256)
print(tokenizer.decode(outputs.tolist()[0][inputs.size(1) :], skip_special_tokens=True))

出力結果

'friend おじゃる丸いいよ
'
### Chat sequence:
'friend 明日は休み?
'
### Response:
'me 休み
'
### Chat sequence:
'friend 今日はどこ行く?
'
### Response:
'me 古着屋
'
### Chat sequence:
'friend 行っちゃおう
'
### Response:
'me 行っちゃう
'
### Chat sequence:
'friend 明日は何しよう
'
### Response:
'me 映画
'
### Chat sequence:
'friend いいね
'
### Response:
'me

延々とチャットの続きを生成してくれた。
Few-shot例だけでも我々のチャットの空気感を汲み取ってくれていて凄い。が、やっぱなんか違うというか、「既視感」は無い内容に思った。

教師データの作成

LINEのトーク履歴は、txt形式でダウンロードすることができる。
txt形式なので、スタンプや画像は全て [スタンプ] [写真] といったテキストに置換された状態になっている。

例:friendとのトーク.txt
2019/01/01(火)
18:16	friend	あけましておめでとう
19:34	me	おめでとうー

2019/01/02(水)
10:50	friend	宝くじ三百円すら当たらんかった
12:18	me	自分で買ってないからセーフ
12:18	me	[スタンプ]
12:18	me	[スタンプ]
12:34	friend	[写真]
12:35	friend	300円すら1/10で買う気にならんw

トーク履歴は約10年分あるのだが、全て使用するのも多すぎるので、今回は2019.1~2021.3までの27ヶ月分のみを使用。

前処理

以下の手順で前処理をし、教師データを作成した。

  1. 日付ごとに分けてリストに格納。
  2. 通話がある日のデータは除く(通話しながらのチャットは脈絡がなくなりがちなので)
  3. タイムスタンプは除去し、ユーザー名と発話内容はタブで区切る
  4. やり取りが me で終わる日のデータは除く(最後のfriend氏による発話を正解データとして学習させたいので)
  5. URLを全て [URL] に置換する。
  6. 同じユーザからの [動画] [写真] [スタンプ] [URL] の連投は、それぞれ1つにまとめる
with open("/content/drive/MyDrive/friendtalk_2019Jan_2021March.txt") as f:
  text = f.read()

# 日付ごとのデータに分けてリストに格納。空データは削除
chats = re.split(r'\d{4}/\d{2}/\d{2}\([月火水木金土日]\)\n', text)
chats = [c for c in chats if c!=""]

# 通話履歴が含まれる日のデータは削除
chats_without_call = [chat for chat in chats if
                      "☎ 通話をキャンセルしました" not in chat and
                      "☎ 通話時間" not in chat and
                      "☎ 通話に応答がありませんでした" not in chat]

# 各会話をクリーニングする関数
def merge_continuous_values(list):
  """リストの中で、同じ値が連続している場合は一つにまとめる。

  Args:
    list: 値のリスト

  Returns:
    まとめられたリスト
  """
  merged_list = []
  for i, value in enumerate(list):
    if i == 0:
      merged_list.append(value)
    elif value == merged_list[-1]:
      pass
    else:
      merged_list.append(value)

  return merged_list

url_re = re.compile('https?://[^\s]+')

cleaned_chats = []
for chat in chats_without_call:
  msgs = re.split(r'\d{2}:\d{2}\t', chat)
  new_msgs = []
  while not msgs[-1].startswith("friend") and len(msgs)>1:
    try:
      msgs.pop()
    except IndexError:
      break
  for m in msgs:
    if m != "":
      m = url_re.sub('[URL]', m) # URLは[URL]に置換
      new_msgs.append(m)

#同じユーザからの動画 写真 スタンプ URLの連投は1つにする
cleaned_chats.append(merge_continuous_values(new_msgs)) 
cleaned_chats = ["".join(c) for c in cleaned_chats if len(c) > 1]

cleaned_chats[0]の内容はこんな感じ

friend	宝くじ三百円すら当たらんかった
me	自分で買ってないからセーフ
me	[スタンプ]
friend	[写真]
friend	300円すら1/10で買う気にならんw

これを、Llama向けのプロンプトフォーマットにしてcsvファイルにする。
ちなみに、この時点でサンプル数は614件(614日分のチャット履歴)。

import csv
cleaned_chats_for_instruct_model = []

for chat in cleaned_chats:
 msgs = chat.split("\n")
 new_msgs = []
 for m in msgs:
   if len(m)!= 0:
     new_msgs.append(m)

 context = "\n".join(new_msgs[:-1])
 completion = new_msgs[-1][7:] # 最後の発話(friendのもの)を正解とする

 prompt = f"""<s>[INST] Please impersonate friend to complete the following chat.\n\nChat:\n{context}\nfriend\t [/INST] {completion} </s>"""

 cleaned_chats_for_instruct_model.append(prompt)

def create_csv_from_list(data_list, file_path):
   # CSVファイルのカラム名
   fieldnames = ["text"]

   # CSVファイルを書き込む
   with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
       writer = csv.writer(csvfile)

       # カラム名を書き込む
       writer.writerow(fieldnames)

       # データをリストとしてCSVファイルに書き込む
       for item in data_list:
           writer.writerow([item])
           
create_csv_from_list(cleaned_chats_for_instruct_model, "/data/tagged_chattalk_data_for_instruct_model.csv" )

訓練

train.ipynb
import pandas as pd
df=pd.read_csv("/data/tagged_chattalk_data_for_instruct_model.csv", index_col=0)

# Split
from datasets import load_dataset, DatasetDict,Dataset
from sklearn.model_selection import train_test_split
import pandas as pd

train, buffer = train_test_split(df, test_size = 0.1,random_state=1)
test, val = train_test_split(buffer, test_size = 0.5,random_state=1)
dataset = DatasetDict({
    'train': Dataset.from_pandas(train),
    'test': Dataset.from_pandas(test),
    'valid': Dataset.from_pandas(val)})

import torch
torch.cuda.empty_cache()

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer


model_name = "elyza/ELYZA-japanese-Llama-2-7b-instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name, 
    quantization_config=bnb_config, 
    
    trust_remote_code=True
)
model_4bit.config.use_cache = False
model = model_4bit 

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.unk_token

# This is the fix for fp16 training
tokenizer.padding_side = "right"

import torch
from peft import LoraConfig


# LoRAパラメータの準備
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    modules_to_save = ["lm_head", "embed_tokens"] # when additional tokens
)

# パラメータの設定
run_name="impersonater_elyza_llama2_7b_instruct"
output_dir  =  "output/impersonater_elyza_llama2_7b_instruct"
do_eval=True
evaluation_strategy="steps"
eval_steps = 1
per_device_train_batch_size = 8
per_device_eval_batch_size=1
gradient_accumulation_steps = 4
optim = "paged_adamw_32bit"
save_steps = 5
logging_steps = 1
learning_rate = 5e-4
max_grad_norm = 0.3
# max_steps = 3000
max_steps = 600
# warmup_ratio = 0.03
warmup_ratio = 0
# lr_scheduler_type = "constant"
# lr_scheduler_type = "linear"
lr_scheduler_type = "cosine"
# lr_scheduler_type = "cosine_with_restarts"
# num_cycles=8
gradient_checkpointing = True
report_to="wandb"
fp16 = True
bf16 = False
weight_decay = 0.001
max_seq_length = None #1000
metric_for_best_model="eval_loss"
load_best_model_at_end=True
packing=False
group_by_length = True

import wandb
wandb.login()

%env WANDB_LOG_MODEL='end'
%env WANDB_PROJECT="impersonater_elyza_llama2_7b_instruct"
%env WANDB_SILENT=True

from transformers import TrainingArguments
from trl import SFTTrainer

training_arguments = TrainingArguments(
    do_eval=do_eval,
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    evaluation_strategy=evaluation_strategy,
    eval_steps = eval_steps,
    learning_rate=learning_rate,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to=report_to,
    metric_for_best_model= metric_for_best_model,
    load_best_model_at_end=load_best_model_at_end,
    run_name = run_name,
    save_total_limit=3   # これを指定しないと結局save_stepsごとに作成されるCheckpointが全部保存されてしまう
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["valid"],
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)


# trainer.train(resume_from_checkpoint=True)
trainer.train()
trainer.model.save_pretrained(output_dir)

import shutil
print(shutil.make_archive('impersonater_llama2_7b_elyza_tagged_checkpoint30', 'zip', root_dir='output/impersonater_elyza_llama2_7b_tagged/checkpoint-30'))

学習曲線

W&B Chart 2024_4_2 21_05_37.png
W&B Chart 2024_4_2 21_05_19.png

Run summary:
eval/loss 1.20005
eval/runtime 5.5884
eval/samples_per_second 5.547
eval/steps_per_second 0.716
train/epoch 2.78
train/global_step 48
train/learning_rate 0.00049
train/loss 1.0227

訓練したモデルで推論

infer.ipynb
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer

model_name = "elyza/ELYZA-japanese-Llama-2-7b-instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name, 
    quantization_config=bnb_config, 
    trust_remote_code=True
)
model_4bit.config.use_cache = False
model = model_4bit 

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token =  tokenizer.unk_token 

# This is the fix for fp16 training
tokenizer.padding_side = "right"


checkpoint = "output/impersonater_elyza_llama2_7b_instruct/checkpoint-30"

from peft import PeftModel
import torch
trainedmodel = PeftModel.from_pretrained(
    model,
    checkpoint,
    torch_dtype=torch.float16,   
     device_map={'':0}
)

if torch.cuda.is_available():
    trainedmodel = trainedmodel.to("cuda")


def gen(text):
    token_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")
    token_ids.to("cuda")

    with torch.no_grad():
        output_ids = trainedmodel.generate(
            inputs = token_ids,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=128,
            temperature = 0.6,
        )
    output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1) :], skip_special_tokens=True)
    print(output)


text = "me\t修論飽きておじゃる丸とか見てた\nfriend\t"
gen(text)

学習済みモデルで、先ほどと同じチャットの続きを生成させてみる。

まずはdo_sample=Falseとして greedy decodingの結果を見てみる。

おじゃる丸すき
me	[スタンプ]
me	[URL]
me	これすき
me	[URL]
me	これすき
me	[スタンプ]
me	[スタンプ]
me	[スタンプ]
me	[スタンプ]

学習の結果、「おじゃる丸いいよ」が「おじゃる丸すき」になった。
些細な違いだが、これはすごくfriend氏"ぽい"。
ていうかその後に続けて生成されるやり取りも、[スタンプ]と[URL]の繰り返しになっているものの、非常に我々のトーク履歴っぽくなっている。

generateの引数にno_repeat_ngram_size=8を設定し、繰り返しを回避してみると

おじゃる丸すき
me	[スタンプ]
me	おじゃる丸すごい
friend	すごい
me	毎日見てるとおじゃる丸の中だけ進んでる感がすごい
me	[スタンプ]

若干無意味かつオウム返し的やり取りがみられるものの、個人的に「既視感」のあるチャットが繰り広げられている。なんならオウム返しはたまにやってるので、LLMが学習した結果なのかどうかすらわからない。

特に自分の口癖である「~感がすごい」が再現されていて背筋が冷えた。
(ちなみに「~感がすごい」は、教師データに3回登場していた。)

次の例。

text = "me\tプライムビデオでコララインとボタンの魔女って映画みたんだけどめっちゃおもしろかった\nfriend\t"

>>
すごい
me	幼女のキャラクターがめっちゃかわいい
friend	すごい

自分、めちゃくちゃ「幼女のキャラクターがめっちゃかわいい」とか言ってそうでまた背筋が凍る(別にこんなセリフは教師データには無かった)。

さっきから自分の再現度に驚いているが、肝心のfriend氏は「すごい」しか言ってなくね?
教師データを確認すると、friend氏による「すごい」は30回も登場したので確かに口癖なのだが(自分による「すごい」は2回のみ。)、ちょっと多発しすぎ感。

色々遊んでみる

ではそろそろ、do_sample=Trueにしてモデルを暴れさせてみるぞ。

temperature = 0.6

text = "me\t修論飽きておじゃる丸とか見てた\nfriend\t"
gen(text)

>>
おじゃマル
me 謎の文字入力してて草
text = "me\tプライムビデオでコララインとボタンの魔女って映画みたんだけどめっちゃおもしろかった\nfriend\t"

>>
me嬉しそう
friend	最近見れる映画少ない
me	最近できた映画館いって見ようかな
me	毎日暑い

ちゃんと私のあだ名も呼んでくれた。

temperature = 1.0

text = "me\t大谷とプリキュアとドンキーコングで誰が最強だと思う\nfriend\t"
gen(text)

>>
大谷が一番強い
me 意外とわからないな
friend 理論上はドンキーコングが最も強い
text = "me\tビール飲みてー\nfriend\t"
gen(text)

>>
[スタンプ]
me	ぬわああああああああああああああああああああああああああああん

恥ずかしくなってくるレベルに自分と友人のチャットが再現できていて全貌をここに記録するのは憚られた()

本当に友人とLINEしてるみたいで楽しすぎて時間が溶けたので、gradioでアプリ立てた

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer

model_name = "elyza/ELYZA-japanese-Llama-2-7b-instruct"
# model_name = "elyza/ELYZA-japanese-Llama-2-7b"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name, 
    quantization_config=bnb_config, 
    trust_remote_code=True
)
model_4bit.config.use_cache = False
model = model_4bit 


# tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "right"


# Add special tokens
# special_tokens = ["[R_START]", "[R_END]"]

tokenizer = AutoTokenizer.from_pretrained(model_name)

# tokenizer.add_tokens(special_tokens, special_tokens=True)
# this will make new learnable parameters for specialized tokens
# model.resize_token_embeddings(len(tokenizer))

# tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token =  tokenizer.unk_token 

# This is the fix for fp16 training
tokenizer.padding_side = "right"

checkpoint = "output/impersonater_elyza_llama2_7b_tagged/checkpoint-30"

from peft import PeftModel
import torch

trainedmodel = PeftModel.from_pretrained(
    model,
    checkpoint,
    torch_dtype=torch.float16,   
     device_map={'':0}
)

if torch.cuda.is_available():
    trainedmodel = trainedmodel.to("cuda")

def gen(text):
    # GreedySearch.
#     prompt = f"<s>[INST] Please impersonate friend to complete the following chat.\n\nChat:\n{text} [/INST]"
#     token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    token_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")
    token_ids.to("cuda")


    with torch.no_grad():
        output_ids = trainedmodel.generate(
        inputs=token_ids,
        do_sample=True,
        pad_token_id = tokenizer.eos_token_id,
        eos_token_id = tokenizer.eos_token_id,
        max_new_tokens=128,
        top_k=100,
        top_p=0.9,
        repetition_penalty=1.2
        )
        output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1) :], skip_special_tokens=False)
    
    
    return output

import gradio as gr


history = []

chathistory="" # これをsessionstateに格納する

def add_text(history, text, state):
    history = history + [(text, None)]
    return history, ""


def infer(question, state):
    state["chathistory"] += f"Me\t{question}\nfriend\t"
    response = gen(state["chathistory"])
    continuous_response=[]
    for r in response.split("\n")[1:]:
        if r.startswith("friend"):
            continuous_response.append(r.split("\t")[1])
        else:
            break
    return response.split("\n")[0], continuous_response

def bot(history, state):    
    res, continuous_response = infer(history[-1][0], state)
    if len(continuous_response)!=0:
        for r in continuous_response:
            res = res + "\n" + r
    
    state["chathistory"] += res
    history[-1][1] = res
    
#     if len(continuous_response)!=0:
#         for r in continuous_response:
#             history.append(["", r])


    
#     return [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
    return history, state




with gr.Blocks() as demo:
    state = gr.State({
        "chathistory": ""
    })
    chatbot = gr.Chatbot([], elem_id="chatbot")
    question = gr.Textbox(label="Question", placeholder="Type your message and hit Enter ")
    submit_btn = gr.Button("送信")
    clear_engine_btn = gr.Button("エンジンをクリア")
    
    submit_btn.click(
        add_text,
        inputs=[chatbot, question],
        outputs=[chatbot, question],
        queue=True
    ).then(
        bot,
        inputs=[chatbot, state],
        outputs=[chatbot, state]
    )

demo.queue(concurrency_count=3)

demo.launch(share=True, debug=True)
2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?