As part of learning LangChain, I wanted to walk through the whole flow from running a prompt to evaluating the results, so I tried it on a simple reasoning task.
| Component | Version / model |
|---|---|
| LangChain | v0.1.1 |
| LLM | Azure OpenAI gpt-4-32k (0613) |
For the experiment I use the task that the Prompt Engineering Guide gives as its chain-of-thought (CoT) prompting example:
```
The odd numbers in this group add up to an even number: 4, 8, 9, 15, 12, 2, 1.
A: Adding all the odd numbers (9, 15, 1) gives 25. The answer is False.
The odd numbers in this group add up to an even number: 17, 10, 19, 4, 8, 12, 24.
A: Adding all the odd numbers (17, 19) gives 36. The answer is True.
The odd numbers in this group add up to an even number: 16, 11, 14, 4, 8, 13, 24.
A: Adding all the odd numbers (11, 13) gives 24. The answer is True.
The odd numbers in this group add up to an even number: 17, 9, 10, 12, 13, 4, 2.
A: Adding all the odd numbers (17, 9, 13) gives 39. The answer is False.
The odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1.
```
Source: https://www.promptingguide.ai/techniques/cot
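The label of every question follows a fixed rule: collect the odd numbers, add them up, and check whether the sum is even. Here is a minimal sketch of that rule, applied to the last, unanswered question in the prompt. The function name `odd_sum_is_even` is my own and does not come from the guide.

```python
def odd_sum_is_even(numbers):
    """Return the odd numbers, their sum, and whether that sum is even."""
    odds = [n for n in numbers if n % 2 != 0]
    total = sum(odds)
    return odds, total, total % 2 == 0

# The final question in the guide's CoT prompt
print(odd_sum_is_even([15, 32, 5, 13, 82, 7, 1]))  # ([15, 5, 13, 7, 1], 41, False)
```

Each odd number flips the parity of the running sum. The sum is therefore even exactly when the count of odd numbers is even. Here there are five, so the answer is `False`.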
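The experiment proceeds in three steps:

- Preparing the dataset
- Implementation and evaluation (simple few-shot)
- Implementation and evaluation (few-shot CoT)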
First I measure accuracy with a plain few-shot prompt that gives no reasoning steps. I then check whether a few-shot CoT prompt improves it.
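Each data item is a list of seven random integers from 1 to 100. It is stored together with its odd numbers, their sum, and the correct answer (whether that sum is even). The first 5 items serve as few-shot examples (training data) and the remaining 10 as test data.

```python
import pandas as pd
import random

def create_data():
    # Seven random integers in 1..100 (drawn with replacement)
    random_numbers = random.choices(range(1, 101), k=7)
    # Intermediate values, later reused as the CoT reasoning in the examples
    odd_numbers = [x for x in random_numbers if x % 2 != 0]
    sum_of_odds = sum(odd_numbers)
    # Ground-truth label: True if the odd numbers add up to an even number
    is_even = sum_of_odds % 2 == 0
    return {"input": random_numbers,
            "cot_odd_numbers": odd_numbers,
            "cot_sum_of_odds": sum_of_odds,
            "answer": is_even}

train_size = 5
test_size = 10
dataset = [create_data() for _ in range(train_size + test_size)]
train_data = dataset[:train_size]
test_data = dataset[train_size:]
```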
Training data (`train_data`):

| | input | cot_odd_numbers | cot_sum_of_odds | answer |
|---|---|---|---|---|
| 0 | [85, 76, 43, 26, 52, 41, 79] | [85, 43, 41, 79] | 248 | True |
| 1 | [31, 48, 59, 91, 51, 29, 76] | [31, 59, 91, 51, 29] | 261 | False |
| 2 | [62, 26, 91, 99, 82, 91, 32] | [91, 99, 91] | 281 | False |
| 3 | [73, 90, 69, 48, 11, 44, 62] | [73, 69, 11] | 153 | False |
| 4 | [92, 97, 48, 87, 27, 81, 55] | [97, 87, 27, 81, 55] | 347 | False |
Test data (`test_data`):

| | input | cot_odd_numbers | cot_sum_of_odds | answer |
|---|---|---|---|---|
| 0 | [2, 72, 40, 83, 67, 1, 50] | [83, 67, 1] | 151 | False |
| 1 | [87, 25, 33, 88, 20, 57, 24] | [87, 25, 33, 57] | 202 | True |
| 2 | [97, 81, 45, 9, 33, 51, 94] | [97, 81, 45, 9, 33, 51] | 316 | True |
| 3 | [11, 56, 71, 55, 82, 55, 97] | [11, 71, 55, 55, 97] | 289 | False |
| 4 | [61, 59, 45, 60, 39, 58, 30] | [61, 59, 45, 39] | 204 | True |
| 5 | [19, 19, 62, 66, 48, 9, 76] | [19, 19, 9] | 47 | False |
| 6 | [88, 93, 85, 90, 93, 55, 40] | [93, 85, 93, 55] | 326 | True |
| 7 | [71, 28, 82, 85, 90, 59, 95] | [71, 85, 59, 95] | 310 | True |
| 8 | [58, 46, 67, 100, 92, 80, 9] | [67, 9] | 76 | True |
| 9 | [62, 49, 64, 85, 25, 74, 12] | [49, 85, 25] | 159 | False |
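`random.choices` draws from the module-level generator, so rerunning the code above yields different numbers from those in the tables. If you want a fixed split, one option is to pass a seeded `random.Random` into the generator. This is a minimal sketch; `create_data_with_rng`, `make_dataset`, and the seed value are hypothetical additions, not part of the original code.

```python
import random

def create_data_with_rng(rng):
    """Same fields as create_data, but drawn from the given random.Random."""
    random_numbers = rng.choices(range(1, 101), k=7)
    odd_numbers = [x for x in random_numbers if x % 2 != 0]
    sum_of_odds = sum(odd_numbers)
    return {"input": random_numbers,
            "cot_odd_numbers": odd_numbers,
            "cot_sum_of_odds": sum_of_odds,
            "answer": sum_of_odds % 2 == 0}

def make_dataset(seed, train_size=5, test_size=10):
    # A dedicated generator makes the train/test split reproducible for a given seed
    rng = random.Random(seed)
    dataset = [create_data_with_rng(rng) for _ in range(train_size + test_size)]
    return dataset[:train_size], dataset[train_size:]

train_a, test_a = make_dataset(seed=0)
train_b, test_b = make_dataset(seed=0)
print(train_a == train_b and test_a == test_b)  # True
```

The seed only makes future runs repeatable. It will not regenerate the particular numbers shown in the tables.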
LangChain provides a few-shot prompt template (`FewShotPromptTemplate`), so I use it to build the few-shot prompt.
First, `PromptTemplate` defines the template for a single few-shot example:
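```python
from langchain.prompts.prompt import PromptTemplate

# Template for one demonstration: the question followed by the answer only
template_example = """The odd numbers in this group add up to an even number: {input}.
A: The answer is {answer}"""
example_prompt = PromptTemplate(
    input_variables=["input", "answer"], template=template_example
)
```

Applied to the first training example, it renders as: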
```
The odd numbers in this group add up to an even number: [85, 76, 43, 26, 52, 41, 79].
A: The answer is True
```
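Next, `FewShotPromptTemplate` builds the whole prompt. The few-shot examples come from the template above and the training data. `suffix` holds the template for the question the LLM should answer.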
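```python
from langchain.prompts.few_shot import FewShotPromptTemplate

template_question = "The odd numbers in this group add up to an even number: {input}.\nA: "
prompt_fewshot = FewShotPromptTemplate(
    examples=train_data,            # the five training items
    example_prompt=example_prompt,  # how each item is rendered
    suffix=template_question,       # the question to be answered
    input_variables=["input"],
)
```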
The five training items appear as few-shot examples, followed by the test item to solve. The LLM is asked to continue from there. For `test_data[0]`, the prompt is:
```
The odd numbers in this group add up to an even number: [85, 76, 43, 26, 52, 41, 79].
A: The answer is True

The odd numbers in this group add up to an even number: [31, 48, 59, 91, 51, 29, 76].
A: The answer is False

The odd numbers in this group add up to an even number: [62, 26, 91, 99, 82, 91, 32].
A: The answer is False

The odd numbers in this group add up to an even number: [73, 90, 69, 48, 11, 44, 62].
A: The answer is False

The odd numbers in this group add up to an even number: [92, 97, 48, 87, 27, 81, 55].
A: The answer is False

The odd numbers in this group add up to an even number: [2, 72, 40, 83, 67, 1, 50].
A:
```
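Here is a plain-Python approximation of what the template assembles. Each example is rendered with the example template, the question is rendered with the suffix, and the pieces are joined by a blank line. This is a sketch, not LangChain's implementation. `build_few_shot_prompt` is hypothetical, and the `"\n\n"` separator is meant to mirror `FewShotPromptTemplate`'s default `example_separator`.

```python
EXAMPLE_TEMPLATE = ("The odd numbers in this group add up to an even number: {input}.\n"
                    "A: The answer is {answer}")
SUFFIX = "The odd numbers in this group add up to an even number: {input}.\nA: "

def build_few_shot_prompt(examples, question, separator="\n\n"):
    # Render each demonstration, then append the question with an open "A: "
    parts = [EXAMPLE_TEMPLATE.format(**ex) for ex in examples]
    parts.append(SUFFIX.format(**question))
    return separator.join(parts)

train = [
    {"input": [85, 76, 43, 26, 52, 41, 79], "answer": True},
    {"input": [31, 48, 59, 91, 51, 29, 76], "answer": False},
]
prompt = build_few_shot_prompt(train, {"input": [2, 72, 40, 83, 67, 1, 50]})
print(prompt)
```

Python's `str.format` renders a list as `[85, 76, ...]` and a boolean as `True`/`False`. This is why the raw data dicts can be passed straight into the templates.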
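Next, I chain the prompt, the model, and an output parser with the `|` operator (LangChain Expression Language) and run the chain on the first test item.

```python
from langchain_openai import AzureChatOpenAI
from langchain_core.output_parsers import StrOutputParser

# Assumes AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT and OPENAI_API_VERSION
# are set as environment variables (adjust to your own Azure setup)
model = AzureChatOpenAI(deployment_name="gpt-4-32k-0613", temperature=0)
output_parser = StrOutputParser()  # returns the model reply as a plain string
chain = prompt_fewshot | model | output_parser
output = chain.invoke(test_data[0])
output
```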
```
'The answer is True'
```
For evaluation, only the `True` or `False` part is needed, so I extract it with a regular expression.
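```python
import re

RE_SIMPLE = re.compile(r'The answer is (True|False)')

def parse_output(output, matcher):
    # matcher.match anchors at the start of the string; return None if nothing matches
    m = matcher.match(output)
    return m.group(1) if m else None

parse_output(output, RE_SIMPLE)
```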
Then I run the chain on all ten test items.

```python
from tqdm import tqdm

outputs = []
for i in tqdm(range(len(test_data)), ncols=0):
    o = chain.invoke(test_data[i])
    outputs.append(o)
outputs
```
```
100% 10/10 [00:10<00:00, 1.05s/it]
```

```
['The answer is True',
 'The answer is True',
 'The answer is True',
 'The answer is True',
 'The answer is True',
 'The answer is True',
 'The answer is True',
 'The answer is True',
 'The answer is True',
 'The answer is True']
```
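I then attach the predictions to the test data and check each one against the label.

```python
predicts = [parse_output(o, RE_SIMPLE) for o in outputs]

def merge_result(data, predicts):
    # One row per test item: original fields + prediction + whether it matches the label
    return pd.DataFrame([dict(d, predict=p, is_correct=(p == str(d['answer'])))
                         for d, p in zip(data, predicts)])

results = merge_result(test_data, predicts)
results
```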
| | input | cot_odd_numbers | cot_sum_of_odds | answer | predict | is_correct |
|---|---|---|---|---|---|---|
| 0 | [2, 72, 40, 83, 67, 1, 50] | [83, 67, 1] | 151 | False | True | False |
| 1 | [87, 25, 33, 88, 20, 57, 24] | [87, 25, 33, 57] | 202 | True | True | True |
| 2 | [97, 81, 45, 9, 33, 51, 94] | [97, 81, 45, 9, 33, 51] | 316 | True | True | True |
| 3 | [11, 56, 71, 55, 82, 55, 97] | [11, 71, 55, 55, 97] | 289 | False | True | False |
| 4 | [61, 59, 45, 60, 39, 58, 30] | [61, 59, 45, 39] | 204 | True | True | True |
| 5 | [19, 19, 62, 66, 48, 9, 76] | [19, 19, 9] | 47 | False | True | False |
| 6 | [88, 93, 85, 90, 93, 55, 40] | [93, 85, 93, 55] | 326 | True | True | True |
| 7 | [71, 28, 82, 85, 90, 59, 95] | [71, 85, 59, 95] | 310 | True | True | True |
| 8 | [58, 46, 67, 100, 92, 80, 9] | [67, 9] | 76 | True | True | True |
| 9 | [62, 49, 64, 85, 25, 74, 12] | [49, 85, 25] | 159 | False | True | False |
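```python
def report_accuracy(result_table):
    res_all = result_table['is_correct']
    res_correct = [x for x in res_all if x]
    count_all = len(res_all)
    count_correct = len(res_correct)
    print("Accuracy: {} ({}/{})".format(count_correct / count_all, count_correct, count_all))

report_accuracy(results)
```

```
Accuracy: 0.6 (6/10)
```

The model answered `True` for every test item, so this 0.6 equals the share of `True` labels in the test set.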
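To double-check that figure without pandas, here is a minimal stdlib sketch. It recomputes the figure from the labels in the table and a constant `True` prediction. The `accuracy` helper is hypothetical and not part of the code above.

```python
def accuracy(labels, predictions):
    """Fraction of predictions ("True"/"False" strings) that match boolean labels."""
    correct = sum(p == str(y) for y, p in zip(labels, predictions))
    return correct / len(labels), correct, len(labels)

# Labels of the ten test items, copied from the answer column above
labels = [False, True, True, False, True, False, True, True, True, False]
# The simple few-shot prompt answered "True" every time
predictions = ["True"] * 10
print("Accuracy: {} ({}/{})".format(*accuracy(labels, predictions)))  # Accuracy: 0.6 (6/10)
```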
Implementation and evaluation (few-shot CoT)
Next, following the CoT prompt in the Prompt Engineering Guide, the few-shot examples also include the reasoning steps.
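Concretely, each example answer now begins with `Adding all the odd numbers {cot_odd_numbers} gives {cot_sum_of_odds}.`, filled in from the `cot_odd_numbers` and `cot_sum_of_odds` fields of the training data.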
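```python
# Same question; the answer now spells out the reasoning before the verdict
template_example = """The odd numbers in this group add up to an even number: {input}.
A: Adding all the odd numbers {cot_odd_numbers} gives {cot_sum_of_odds}. The answer is {answer}"""
example_prompt = PromptTemplate(
    input_variables=["input", "cot_odd_numbers", "cot_sum_of_odds", "answer"],
    template=template_example,
)

prompt_cot = FewShotPromptTemplate(
    examples=train_data,
    example_prompt=example_prompt,
    suffix="The odd numbers in this group add up to an even number: {input}.\nA: ",
    input_variables=["input"],
)
```

For `test_data[0]`, the prompt now reads: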
```
The odd numbers in this group add up to an even number: [85, 76, 43, 26, 52, 41, 79].
A: Adding all the odd numbers [85, 43, 41, 79] gives 248. The answer is True

The odd numbers in this group add up to an even number: [31, 48, 59, 91, 51, 29, 76].
A: Adding all the odd numbers [31, 59, 91, 51, 29] gives 261. The answer is False

The odd numbers in this group add up to an even number: [62, 26, 91, 99, 82, 91, 32].
A: Adding all the odd numbers [91, 99, 91] gives 281. The answer is False

The odd numbers in this group add up to an even number: [73, 90, 69, 48, 11, 44, 62].
A: Adding all the odd numbers [73, 69, 11] gives 153. The answer is False

The odd numbers in this group add up to an even number: [92, 97, 48, 87, 27, 81, 55].
A: Adding all the odd numbers [97, 87, 27, 81, 55] gives 347. The answer is False

The odd numbers in this group add up to an even number: [2, 72, 40, 83, 67, 1, 50].
A:
```
```python
chain = prompt_cot | model | output_parser

outputs = []
for i in tqdm(range(len(test_data)), ncols=0):
    o = chain.invoke(test_data[i])
    outputs.append(o)
outputs
```
```
100% 10/10 [00:30<00:00, 3.07s/it]
```

```
['Adding all the odd numbers [83, 67, 1] gives 151. The answer is False.',
 'Adding all the odd numbers [87, 25, 33, 57] gives 202. The answer is True.',
 'Adding all the odd numbers [97, 81, 45, 9, 33, 51] gives 316. The answer is True.',
 'Adding all the odd numbers [11, 71, 55, 55, 97] gives 289. The answer is False.',
 'Adding all the odd numbers [61, 59, 45, 39] gives 204. The answer is True.',
 'Adding all the odd numbers [19, 19, 9] gives 47. The answer is False.',
 'Adding all the odd numbers [93, 85, 93, 55] gives 326. The answer is True.',
 'Adding all the odd numbers [71, 85, 59, 95] gives 310. The answer is True.',
 'Adding all the odd numbers [67, 9] gives 76. The answer is True.',
 'Adding all the odd numbers [49, 85, 25] gives 159. The answer is False.']
```
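The outputs now state the reasoning before the verdict, so the regular expression is extended accordingly.

```python
# The period after the sum is escaped so it matches a literal '.'
RE_COT = re.compile(r'Adding all the odd numbers .* gives .*\. The answer is (True|False)')
predicts = [parse_output(o, RE_COT) for o in outputs]
results = merge_result(test_data, predicts)
results
```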
| | input | cot_odd_numbers | cot_sum_of_odds | answer | predict | is_correct |
|---|---|---|---|---|---|---|
| 0 | [2, 72, 40, 83, 67, 1, 50] | [83, 67, 1] | 151 | False | False | True |
| 1 | [87, 25, 33, 88, 20, 57, 24] | [87, 25, 33, 57] | 202 | True | True | True |
| 2 | [97, 81, 45, 9, 33, 51, 94] | [97, 81, 45, 9, 33, 51] | 316 | True | True | True |
| 3 | [11, 56, 71, 55, 82, 55, 97] | [11, 71, 55, 55, 97] | 289 | False | False | True |
| 4 | [61, 59, 45, 60, 39, 58, 30] | [61, 59, 45, 39] | 204 | True | True | True |
| 5 | [19, 19, 62, 66, 48, 9, 76] | [19, 19, 9] | 47 | False | False | True |
| 6 | [88, 93, 85, 90, 93, 55, 40] | [93, 85, 93, 55] | 326 | True | True | True |
| 7 | [71, 28, 82, 85, 90, 59, 95] | [71, 85, 59, 95] | 310 | True | True | True |
| 8 | [58, 46, 67, 100, 92, 80, 9] | [67, 9] | 76 | True | True | True |
| 9 | [62, 49, 64, 85, 25, 74, 12] | [49, 85, 25] | 159 | False | False | True |
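```python
report_accuracy(results)
```

```
Accuracy: 1.0 (10/10)
```

On this 10-item test set, adding the reasoning steps to the examples raised accuracy from 0.6 to 1.0. For every item, the odd numbers and sum the model wrote out also match the `cot_odd_numbers` and `cot_sum_of_odds` columns.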
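Accuracy scores only the final `True`/`False`. The CoT outputs also expose the intermediate steps, so those can be checked programmatically too, since a correct verdict could in principle rest on wrong arithmetic. This is a hypothetical sketch: the names `RE_COT_DETAIL` and `check_reasoning` are my own, and this check was not part of the experiment above.

```python
import re

# Captures the listed odd numbers, the stated sum, and the verdict
RE_COT_DETAIL = re.compile(
    r"Adding all the odd numbers \[([\d, ]*)\] gives (\d+)\. The answer is (True|False)"
)

def check_reasoning(numbers, output):
    """Check the model's odd-number list, its sum, and the final answer separately."""
    m = RE_COT_DETAIL.match(output)
    if m is None:
        return None  # output not in the expected CoT format
    listed = [int(x) for x in m.group(1).split(",") if x.strip()]
    stated_sum = int(m.group(2))
    stated_answer = m.group(3) == "True"
    true_odds = [n for n in numbers if n % 2 != 0]
    return {
        "odds_ok": listed == true_odds,
        "sum_ok": stated_sum == sum(true_odds),
        "answer_ok": stated_answer == (sum(true_odds) % 2 == 0),
    }

print(check_reasoning([2, 72, 40, 83, 67, 1, 50],
                      "Adding all the odd numbers [83, 67, 1] gives 151. The answer is False."))
# {'odds_ok': True, 'sum_ok': True, 'answer_ok': True}
```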
This confirmed the flow from running prompts to evaluating them with LangChain.
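I wanted to keep the reference prompt mostly unchanged, so I parsed the output with regular expressions here. LangChain, however, offers a range of Output Parsers that can, for example, have the model answer in JSON.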
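Here is a rough illustration of the JSON route. It assumes the prompt asks the model for a JSON object; the instruction text and key names are hypothetical. It uses the standard `json` module rather than any LangChain parser.

```python
import json

# Hypothetical instruction that would be appended to the question prompt
FORMAT_INSTRUCTION = (
    'Respond only with JSON of the form '
    '{"odd_numbers": [...], "sum_of_odds": <int>, "answer": true|false}.'
)

def parse_json_answer(text):
    """Parse a JSON reply and return the boolean answer, or None if it is malformed."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return None
    answer = data.get("answer") if isinstance(data, dict) else None
    return answer if isinstance(answer, bool) else None

reply = '{"odd_numbers": [83, 67, 1], "sum_of_odds": 151, "answer": false}'
print(parse_json_answer(reply))  # False
```

Compared with the regular expressions, a structured reply separates the intermediate values from the verdict without depending on exact sentence wording. Malformed replies surface as `None` instead of raising.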
LangChain also provides Evaluators for evaluation. Real-world tasks will likely call for more work, such as making use of these features.