はじめに
LangChain の勉強として、プロンプトの実行から評価までの流れを確認したかったので、簡単な推論タスクで試してみました。
今回の実験に使用した主要なコンポーネントのバージョンは以下です。
コンポーネント | バージョン/モデル |
---|---|
LangChain | v0.1.1 |
LLM | Azure OpenAI gpt-4-32k (0613) |
実験計画
採用課題
Prompt Engineering Guide に CoTプロンプトとして例示されているタスクを実験課題として採用します。
このタスクは与えられた7つの数値に含まれる奇数の数値の合計が偶数になるか判定するタスクです。
The odd numbers in this group add up to an even number: 4, 8, 9, 15, 12, 2, 1.
A: Adding all the odd numbers (9, 15, 1) gives 25. The answer is False.
The odd numbers in this group add up to an even number: 17, 10, 19, 4, 8, 12, 24.
A: Adding all the odd numbers (17, 19) gives 36. The answer is True.
The odd numbers in this group add up to an even number: 16, 11, 14, 4, 8, 13, 24.
A: Adding all the odd numbers (11, 13) gives 24. The answer is True.
The odd numbers in this group add up to an even number: 17, 9, 10, 12, 13, 4, 2.
A: Adding all the odd numbers (17, 9, 13) gives 39. The answer is False.
The odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1.
A:
出典: https://www.promptingguide.ai/techniques/cot
作業手順
以下の順序で実験します。
- データセットの準備
- 実装と評価(単純なFew-shot)
- 実装と評価(Few-shot CoT)
まずは思考の過程を与えない単純なFew-shotプロンプトで正解率を確認し、その後Few-shot CoTプロンプトで正解率が改善するか確認してみます。
データセットの準備
学習データとテストデータを用意します。今回のタスクのデータは数値的に算出可能な内容なので、スクリプトを用いて生成することとします。
のちのCoTプロンプトで使用するため、計算過程の値も一緒に出力しておきます。
import pandas as pd
import random
def create_data():
random_numbers = random.choices(range(1, 101), k=7)
odd_numbers = [x for x in random_numbers if x % 2 != 0]
sum_of_odds = sum(odd_numbers)
is_even = sum_of_odds % 2 == 0
return {"input": random_numbers,
"cot_odd_numbers": odd_numbers,
"cot_sum_of_odds": sum_of_odds,
"answer": is_even}
Few-shot学習に使うデータは5件、評価のためのテストデータは10件としてデータを生成します。
train_size = 5
test_size = 10
random.seed(0)
dataset = [create_data() for i in range(train_size+test_size)]
train_data = dataset[:train_size]
test_data = dataset[train_size:]
Few-shot学習用データ
pd.DataFrame(train_data)
input | cot_odd_numbers | cot_sum_of_odds | answer | |
---|---|---|---|---|
0 | [85, 76, 43, 26, 52, 41, 79] | [85, 43, 41, 79] | 248 | True |
1 | [31, 48, 59, 91, 51, 29, 76] | [31, 59, 91, 51, 29] | 261 | False |
2 | [62, 26, 91, 99, 82, 91, 32] | [91, 99, 91] | 281 | False |
3 | [73, 90, 69, 48, 11, 44, 62] | [73, 69, 11] | 153 | False |
4 | [92, 97, 48, 87, 27, 81, 55] | [97, 87, 27, 81, 55] | 347 | False |
テスト用データ
pd.DataFrame(test_data)
input | cot_odd_numbers | cot_sum_of_odds | answer | |
---|---|---|---|---|
0 | [2, 72, 40, 83, 67, 1, 50] | [83, 67, 1] | 151 | False |
1 | [87, 25, 33, 88, 20, 57, 24] | [87, 25, 33, 57] | 202 | True |
2 | [97, 81, 45, 9, 33, 51, 94] | [97, 81, 45, 9, 33, 51] | 316 | True |
3 | [11, 56, 71, 55, 82, 55, 97] | [11, 71, 55, 55, 97] | 289 | False |
4 | [61, 59, 45, 60, 39, 58, 30] | [61, 59, 45, 39] | 204 | True |
5 | [19, 19, 62, 66, 48, 9, 76] | [19, 19, 9] | 47 | False |
6 | [88, 93, 85, 90, 93, 55, 40] | [93, 85, 93, 55] | 326 | True |
7 | [71, 28, 82, 85, 90, 59, 95] | [71, 85, 59, 95] | 310 | True |
8 | [58, 46, 67, 100, 92, 80, 9] | [67, 9] | 76 | True |
9 | [62, 49, 64, 85, 25, 74, 12] | [49, 85, 25] | 159 | False |
実装と評価(単純なFew-shot)
はじめに思考の過程を与えない、単純なFew-shotプロンプトで実験します。
プロンプトの実装
LangChain には Few shot prompt template があるのでこれを利用してFew-shotプロンプトを作成します。
まず PromptTemplate
を用いて先ほど生成した学習データをFew-shot例の文章になるようテンプレート化します。
from langchain.prompts.prompt import PromptTemplate
template_example = """The odd numbers in this group add up to an even number: {input}.
A: The answer is {answer}"""
example_prompt = PromptTemplate(
input_variables=["input", "answer"], template=template_example
)
学習データの1件目を使ってプロンプトを出力すると以下のようになります。
print(example_prompt.format(**train_data[0]))
The odd numbers in this group add up to an even number: [85, 76, 43, 26, 52, 41, 79].
A: The answer is True
次にプロンプト全体を作成するため FewShotPromptTemplate
を使用します。
Few-shot の例示には先ほど作成したプロンプトテンプレートと学習データを使用します。 suffix
にテンプレートを指定して末尾にテストデータ用のプロンプトを挿入します。
from langchain.prompts.few_shot import FewShotPromptTemplate
template_question = "The odd numbers in this group add up to an even number: {input}.\nA: "
prompt_fewshot = FewShotPromptTemplate(
examples=train_data,
example_prompt=example_prompt,
suffix=template_question,
input_variables=["input"],
)
テストデータの1件目を使ってプロンプトを出力すると以下のようになります。
学習データ5件が Few-shot として例示され、最後に解かせたいテストデータの内容が出力されています。LLMにはこの続きを推論させます。
print(prompt_fewshot.format(**test_data[0]))
The odd numbers in this group add up to an even number: [85, 76, 43, 26, 52, 41, 79].
A: The answer is True
The odd numbers in this group add up to an even number: [31, 48, 59, 91, 51, 29, 76].
A: The answer is False
The odd numbers in this group add up to an even number: [62, 26, 91, 99, 82, 91, 32].
A: The answer is False
The odd numbers in this group add up to an even number: [73, 90, 69, 48, 11, 44, 62].
A: The answer is False
The odd numbers in this group add up to an even number: [92, 97, 48, 87, 27, 81, 55].
A: The answer is False
The odd numbers in this group add up to an even number: [2, 72, 40, 83, 67, 1, 50].
A:
チェーンの実装
プロンプトとLLMを結合して実行可能なチェーンとして実装します。
from langchain_openai import AzureChatOpenAI
from langchain_core.output_parsers import StrOutputParser
model = AzureChatOpenAI(deployment_name="gpt-4-32k-0613", temperature=0)
output_parser = StrOutputParser()
chain = prompt_fewshot | model | output_parser
実行できるか確認してみます。
output = chain.invoke(test_data[0])
output
'The answer is True'
推論結果が返ってきました。
この後の評価では True
か False
かの結果だけ欲しいので、正規表現で抽出できるようにしておきます。
import re
RE_SIMPLE = re.compile(r'The answer is (True|False)')
def parse_output(output, matcher):
return matcher.match(output).group(1)
parse_output(output, RE_SIMPLE)
'True'
テストと結果評価
テストデータに対してチェーンを実行します。
from tqdm import tqdm
outputs = []
for i in tqdm(range(len(test_data)), ncols=0):
o = chain.invoke(test_data[i])
outputs.append(o)
outputs
100% 10/10 [00:10<00:00, 1.05s/it]
['The answer is True',
'The answer is True',
'The answer is True',
'The answer is True',
'The answer is True',
'The answer is True',
'The answer is True',
'The answer is True',
'The answer is True',
'The answer is True']
predicts = [parse_output(o, RE_SIMPLE) for o in outputs]
def merge_result(data, predicts):
return pd.DataFrame([dict(d, predict=p, is_correct=(p == str(d['answer']))) for d, p in zip(data, predicts)])
results = merge_result(test_data, predicts)
results
input | cot_odd_numbers | cot_sum_of_odds | answer | predict | is_correct | |
---|---|---|---|---|---|---|
0 | [2, 72, 40, 83, 67, 1, 50] | [83, 67, 1] | 151 | False | True | False |
1 | [87, 25, 33, 88, 20, 57, 24] | [87, 25, 33, 57] | 202 | True | True | True |
2 | [97, 81, 45, 9, 33, 51, 94] | [97, 81, 45, 9, 33, 51] | 316 | True | True | True |
3 | [11, 56, 71, 55, 82, 55, 97] | [11, 71, 55, 55, 97] | 289 | False | True | False |
4 | [61, 59, 45, 60, 39, 58, 30] | [61, 59, 45, 39] | 204 | True | True | True |
5 | [19, 19, 62, 66, 48, 9, 76] | [19, 19, 9] | 47 | False | True | False |
6 | [88, 93, 85, 90, 93, 55, 40] | [93, 85, 93, 55] | 326 | True | True | True |
7 | [71, 28, 82, 85, 90, 59, 95] | [71, 85, 59, 95] | 310 | True | True | True |
8 | [58, 46, 67, 100, 92, 80, 9] | [67, 9] | 76 | True | True | True |
9 | [62, 49, 64, 85, 25, 74, 12] | [49, 85, 25] | 159 | False | True | False |
def report_accuracy(result_table):
res_all = result_table['is_correct']
res_correct = [x for x in res_all if x == True]
count_all = len(res_all)
count_correct = len(res_correct)
print("Accuracy: {} ({}/{})".format(count_correct / count_all, count_correct, count_all))
report_accuracy(results)
Accuracy: 0.6 (6/10)
2択問題で推論値はすべて同じ回答だったので、正答率は0.6でした。
実装と評価(Few-shot CoT)
プロンプトの実装
今度は Prompt Engineering Guide で紹介されているCoTプロンプトのように、Few-shot example に思考の過程も含めるようにします。
Few-shot用のテンプレートに以下の内容を思考の過程として追加します。
Adding all the odd numbers {cot_odd_numbers} gives {cot_sum_of_odds}.
template_example = """The odd numbers in this group add up to an even number: {input}.
A: Adding all the odd numbers {cot_odd_numbers} gives {cot_sum_of_odds}. The answer is {answer}"""
example_prompt = PromptTemplate(
input_variables=["input", "cot_odd_numbers", "cot_sum_of_odds", "answer"],
template=template_example,
)
prompt_cot = FewShotPromptTemplate(
examples=train_data,
example_prompt=example_prompt,
suffix="The odd numbers in this group add up to an even number: {input}.\nA: ",
input_variables=["input"],
)
print(prompt_cot.format(**test_data[0]))
The odd numbers in this group add up to an even number: [85, 76, 43, 26, 52, 41, 79].
A: Adding all the odd numbers [85, 43, 41, 79] gives 248. The answer is True
The odd numbers in this group add up to an even number: [31, 48, 59, 91, 51, 29, 76].
A: Adding all the odd numbers [31, 59, 91, 51, 29] gives 261. The answer is False
The odd numbers in this group add up to an even number: [62, 26, 91, 99, 82, 91, 32].
A: Adding all the odd numbers [91, 99, 91] gives 281. The answer is False
The odd numbers in this group add up to an even number: [73, 90, 69, 48, 11, 44, 62].
A: Adding all the odd numbers [73, 69, 11] gives 153. The answer is False
The odd numbers in this group add up to an even number: [92, 97, 48, 87, 27, 81, 55].
A: Adding all the odd numbers [97, 87, 27, 81, 55] gives 347. The answer is False
The odd numbers in this group add up to an even number: [2, 72, 40, 83, 67, 1, 50].
A:
チェーンの実装
プロンプトは再作成したものを利用し、それ以外は変更せずそのまま利用します。
chain = prompt_cot | model | output_parser
テストと結果評価
テストデータに対してチェーンを実行します。
from tqdm import tqdm
outputs = []
for i in tqdm(range(len(test_data)), ncols=0):
o = chain.invoke(test_data[i])
outputs.append(o)
outputs
100% 10/10 [00:30<00:00, 3.07s/it]
['Adding all the odd numbers [83, 67, 1] gives 151. The answer is False.',
'Adding all the odd numbers [87, 25, 33, 57] gives 202. The answer is True.',
'Adding all the odd numbers [97, 81, 45, 9, 33, 51] gives 316. The answer is True.',
'Adding all the odd numbers [11, 71, 55, 55, 97] gives 289. The answer is False.',
'Adding all the odd numbers [61, 59, 45, 39] gives 204. The answer is True.',
'Adding all the odd numbers [19, 19, 9] gives 47. The answer is False.',
'Adding all the odd numbers [93, 85, 93, 55] gives 326. The answer is True.',
'Adding all the odd numbers [71, 85, 59, 95] gives 310. The answer is True.',
'Adding all the odd numbers [67, 9] gives 76. The answer is True.',
'Adding all the odd numbers [49, 85, 25] gives 159. The answer is False.']
RE_COT = re.compile(r'Adding all the odd numbers .* gives .*. The answer is (True|False)')
predicts = [parse_output(o, RE_COT) for o in outputs]
results = merge_result(test_data, predicts)
results
input | cot_odd_numbers | cot_sum_of_odds | answer | predict | is_correct | |
---|---|---|---|---|---|---|
0 | [2, 72, 40, 83, 67, 1, 50] | [83, 67, 1] | 151 | False | False | True |
1 | [87, 25, 33, 88, 20, 57, 24] | [87, 25, 33, 57] | 202 | True | True | True |
2 | [97, 81, 45, 9, 33, 51, 94] | [97, 81, 45, 9, 33, 51] | 316 | True | True | True |
3 | [11, 56, 71, 55, 82, 55, 97] | [11, 71, 55, 55, 97] | 289 | False | False | True |
4 | [61, 59, 45, 60, 39, 58, 30] | [61, 59, 45, 39] | 204 | True | True | True |
5 | [19, 19, 62, 66, 48, 9, 76] | [19, 19, 9] | 47 | False | False | True |
6 | [88, 93, 85, 90, 93, 55, 40] | [93, 85, 93, 55] | 326 | True | True | True |
7 | [71, 28, 82, 85, 90, 59, 95] | [71, 85, 59, 95] | 310 | True | True | True |
8 | [58, 46, 67, 100, 92, 80, 9] | [67, 9] | 76 | True | True | True |
9 | [62, 49, 64, 85, 25, 74, 12] | [49, 85, 25] | 159 | False | False | True |
report_accuracy(results)
Accuracy: 1.0 (10/10)
CoTプロンプトに変更したところ、全問正答となりました。
おわりに
LangChain を用いたプロンプトの実行から評価までの流れを確認できました。
今回はひとまず参考にしたプロンプトをあまり変えずに実験してみたかったので、出力を正規表現でパースしましたが、LangChain には各種 Output Parser があり、JSON で応答してもらうようにするなどできるようです。
また評価についても、今回のタスクは単純な2択判定であったため簡単に正答率を算出できましたが、テキスト生成などのタスクではこのような評価はできません。
LangChain には評価に利用できる Evaluator がありますが、これらの機能を使っていくなど、現実世界のタスクでは更なる工夫が必要そうです。