概要
前回はTanuki-8Bを利用して1ターンの質問応答を作成しました.
次の項目を満たすような修正を追加したいと思います.
- 会話の履歴を引き継いで会話を継続できる.
- exitで会話終了できるようにする.
- 会話履歴の初期化を導入して,入力トークン数をコントロールできるようにする.
1. 会話ループ
返答生成の関数とwhile文によって,会話のループが簡単に実装できます.このコードでは,会話の履歴を引きつかず1ターン毎に会話が途切れてしまいます.promptの部分をうまく修正して【基本形】のコードを修正していきます.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
tokenizer = AutoTokenizer.from_pretrained("weblab-GENIAC/Tanuki-8B-dpo-v1.0")
model = AutoModelForCausalLM.from_pretrained("weblab-GENIAC/Tanuki-8B-dpo-v1.0")
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 返答生成の関数
def generate_response(user_message):
prompt = [{"role": "system", "content": "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"},
{"role": "user", "content": user_message}]
input_ids = tokenizer.apply_chat_template(prompt, add_generation_prompt=True, return_tensors="pt").to(model.device)
model.eval()
with torch.no_grad():
output = model.generate(input_ids, max_new_tokens=500, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, do_sample=False, streamer=streamer)
del input_ids
torch.cuda.empty_cache()
# 会話ループ
while True:
user_message = input("User: ")
generate_response(user_message)
2. 会話履歴の利用
会話の履歴を利用して意味のある会話が続くようにしたい!簡単に思いつく方法として,過去の会話に新しい話題を追加する形で情報を付加していく方法が考えられます.会話の履歴リストのようなものを作成し,対話が進むごとに利用されたデータをリストに追加していくという形となります.
具体例
最初(1期目)に,ユーザーが「こんにちは!」と入力した状態を考えます.Tanuki-8Bはプロンプトが日本語Aplaca形式に従っているので,入力されるデータは次のような形になっています.
conversation_history = [
{"role": "system", "content": "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"},
{"role": "user", "content": "こんにちは!"}
]
返答が,「こんにちは!何かお手伝いできることはありますか?」であるとします.この返答を追加して,次の会話に向けたプロンプトを
conversation_history = [
{'role': 'system', 'content': '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。'},
{'role': 'user', 'content': 'こんにちは!'},
{'role': 'assistant', 'content': '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n\n
### 指示:\nこんにちは!\n\n
### 応答:\nこんにちは!何かお手伝いできることはありますか?'}
]
となるようにします.1期目の対話情報を入力データとして持ちつつ,2期目の対話が開始されることになります.この段階で,次のユーザーメッセージを入力できるようにします.ユーザーが「今日は暖かいですね.」と入力したとします.
conversation_history = [
{'role': 'system', 'content': '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。'},
{'role': 'user', 'content': 'こんにちは!'},
{'role': 'assistant', 'content': '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n\n
### 指示:\nこんにちは!\n\n
### 応答:\nこんにちは!何かお手伝いできることはありますか?'},
{'role': 'user', 'content': '今日は暖かいですね.'}
]
この converstation_history をトークナイズしてID化すれば,2期目の入力データが完成します.トークナイズしてID化する関数がapply_chat_template() です.
3. トークナイズ
apply_chat_template()
は,Hugging Face Transformersライブラリのトークナイザーが提供するメソッドで,入力されたテキストをトークナイズ,IDテンソルに変換する働きを持っています.
先程の会話履歴converstation_history
を apply_chat_template()
を使って次のコードで整形してみます.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("weblab-GENIAC/Tanuki-8B-dpo-v1.0")
# 会話履歴をトークナイザーのテンプレートに適用
input_ids = tokenizer.apply_chat_template(
conversation_history,
add_generation_prompt=True,
return_tensors="pt"
)
print(input_ids)
例1の状態(ユーザーが「こんにちは!」まで入力した状態)をID化したものが次の形になります.
tensor([[ 1, 272, 1411, 276, 269, 12311, 273, 680, 530, 297,
3994, 298, 270, 381, 2023, 273, 950, 1219, 274, 4074,
354, 1045, 2850, 273, 1530, 3314, 270, 8, 8, 48979,
272, 3994, 327, 8, 3124, 309, 8, 8, 48979, 272,
1045, 2850, 327, 8]])
例2の状態(ユーザーが「今日は暖かいですね.」まで入力した状態)をID化したものが次の形になります.入力された文がすべてID化したテンソルになっていることが確認できます.長いID化したテンソルの続きを生成することで,1期目の対話情報を考慮した2期目の返答と解釈します.
tensor([[ 1, 272, 1411, 276, 269, 12311, 273, 680, 530, 297,
3994, 298, 270, 381, 2023, 273, 950, 1219, 274, 4074,
354, 1045, 2850, 273, 1530, 3314, 270, 8, 8, 48979,
272, 3994, 327, 8, 3124, 309, 8, 8, 48979, 272,
1045, 2850, 327, 8, 1411, 276, 269, 12311, 273, 680,
530, 297, 3994, 298, 270, 381, 2023, 273, 950, 1219,
274, 4074, 354, 1045, 2850, 273, 1530, 3314, 270, 8,
8, 48979, 272, 3994, 327, 8, 3124, 309, 8, 8,
48979, 272, 1045, 2850, 327, 8, 3124, 309, 314, 1785,
1498, 329, 806, 314, 3462, 282, 3925, 2287, 930, 1881,
270, 1320, 1335, 302, 394, 739, 446, 1770, 301, 7883,
316, 329, 2, 272, 8, 8, 48979, 272, 3994, 327,
8, 620, 276, 4790, 2268, 284, 8, 8, 48979, 272,
1045, 2850, 327, 8]])
-
apply_chat_template()
の第一引数は会話履歴 (リスト形式) 1 -
add_generation_prompt=True
は応答生成のためのプロンプトを追加するかどうかの指定 -
return_tensors="pt"
はPyTorchのテンソル
4. 会話のループ
converstation_history
には情報が蓄積され,会話が進むたびに入力されるトークンが肥大化していきます.会話履歴の初期化や会話の終了を追加します.
- ユーザーがCLEARと入力するとconverstation_historyを初期化する仕組み
- exitと入力して会話ループから抜けるような仕組み
2点を追加したいと思います.
初期化のコードは,初期のプロンプトをつくり,初期化する関数を用意します.あとは,if文を利用してCLEARでconverstation_historyを初期化するようにします.
# 初期システムプロンプト
INITIAL_SYSTEM_PROMPT = {"role": "system", "content": "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"}
# 会話履歴を初期化する関数
def initialize_conversation():
return [INITIAL_SYSTEM_PROMPT]
会話ループから抜ける仕組みはwhile文中でif文を利用してexitでwhileループから抜けるようにします.
# 会話ループ
while True:
user_input = input("User: ")
# 入力した文字列が exit なら while ループから抜ける
if user_input == "exit":
break
# こちらがわが通常の返答
response, clear_message = generate_response(user_input)
if clear_message:
print("System:", clear_message)
1の初期化と2のexitを【基本形】のコードに加えて完成形になります.
コード全体
高速化のために8bit量子化モデルを利用します.このあたりは前回を参考にしてください.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, BitsAndBytesConfig
import torch
# 8bit量子化
base_model = "weblab-GENIAC/Tanuki-8B-dpo-v1.0"
tokenizer = AutoTokenizer.from_pretrained(base_model)
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map="auto",
quantization_config=bnb_config,
torch_dtype=torch.float16
)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 初期システムプロンプト
INITIAL_SYSTEM_PROMPT = {"role": "system", "content": "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"}
# 会話履歴を初期化する関数
def initialize_conversation():
return [INITIAL_SYSTEM_PROMPT]
# 会話履歴を初期化
conversation_history = initialize_conversation()
# 返答生成の関数
# apply_chat_template()関数を使って履歴を整形
# model.generate()関数で返答生成
def generate_response(user_input):
global conversation_history
# CLEARコマンドの処理
# CLEARの時,"会話履歴をクリアしました!"と表示させる
if user_input == "CLEAR":
conversation_history = initialize_conversation()
return None, "会話履歴をクリアしました!"
# conversation_historyに会話履歴を追加
# apply_chat_template()によって generate()関数に入力するデータに変換
conversation_history.append({"role": "user", "content": user_input})
input_ids = tokenizer.apply_chat_template(conversation_history, add_generation_prompt=True, return_tensors="pt").to(model.device)
# 最大500文字まで生成
model.eval()
with torch.no_grad():
output = model.generate(
input_ids,
max_new_tokens=500,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=False,
streamer=streamer
)
# 生成されたテキストをデコード
# converstation_historyに会話履歴を追加
response = tokenizer.decode(output[0], skip_special_tokens=True)
conversation_history.append({"role": "assistant", "content": response})
del input_ids
torch.cuda.empty_cache()
return response, None
# 会話ループ
while True:
user_input = input("User: ")
if user_input == "exit":
break
response, clear_message = generate_response(user_input)
if clear_message:
print("System:", clear_message)
チャットのサンプル
User: コーヒー豆の産地について一つ紹介してください.
コーヒー豆の産地として特に有名なのはエチオピアです。エチオピアは「コーヒー発祥の地」として知られ、その豊かな風味と独特の香りで世界中のコーヒー愛好家に愛されています。特にシダモ地方やシダマラ地方は高品質なアラビカ種の豆を生産しており、そのクリアでフルーティーな味わいが特徴です。また、エチオピアのコーヒーは標高差が大きく、異なる気候条件が豆の風味に多様性をもたらしています。このため、エチオピア産のコーヒーはシングルオリジンとして非常に人気があります。
User: 器具について表形式で要点を解説してください.
もちろんです。以下にコーヒー豆を焙煎するための一般的な器具について、表形式で要点を解説します。
| 器具名 | 役割 | 主な機能 | 使用方法 | 注意点 |
|--------|------|----------|----------|--------|
| コーヒーミル | コーヒー豆を粉砕する機械 | 豆を細かく挽くことで抽出効率を向上させる | 豆を挽く際に均一な粒度を保つことが重要 | 手動と電動があり、手挽きと電動の違いがある |
| ドリッパー | コーヒーを抽出するためのフィルター付きカップ | コーヒー粉を均等に配置し、お湯を注ぐことで抽出を行う | ペーパーフィルターや金属フィルターなど種類が豊富 | サイズや形状によって抽出速度や風味が変わる |
| ケトル | お湯を沸かすための器具 | 適切な温度のお湯を供給するために使用 | ステンレス製やセラミック製が一般的 | 沸騰時間や温度管理が重要 |
| フレンチプレス | コーヒー粉とお湯を一緒にプレスして抽出する器具 | 金属製のフィルターを使用し、コーヒー粉とお湯を一緒に浸す | 均一な抽出が可能で、特に濃いコーヒーが好きな人に適している | 洗浄がやや手間だが、一度セットすれば簡単 |
| エスプレッソマシン | 高圧で短時間でコーヒーを抽出する装置 | 高圧の水と専用のグラインダーを使用して微細なコーヒー粉から抽出 | 操作が複雑で、専用の技術が必要 | 高品質なコーヒーを短時間で楽しめるが、初期投資が高い |
これらの器具は、それぞれ異なる方法でコーヒーを抽出するため、目的や好みに応じて使い分けることが重要です。例えば、手軽にコーヒーを楽しみたい場合はフレンチプレスやドリップコーヒーが適しており、特別な一杯を求めるならエスプレッソマシンが最適です。
User: CLEAR
System: 会話履歴をクリアしました!
2番目の入力.「器具について」の器具がコーヒーに関連する器具であることが理解されています.履歴が利用されない場合は,コーヒーに無関係な様々な器具の説明が表示されます.
- 前回の内容
注
-
apply_chat_template()
は、モデルごとに異なるテンプレートを使用するらしく,使用するモデルに適切なテンプレートが設定されていない場合は、エラーが発生するようです. ↩