1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

LangChainのReActを使って数学の問題を解いてみた

Posted at

はじめに

今回はLangChainに実装されているReActを使用して、中学校数学の連立方程式の文章題をどれくらいの精度で解くことができるか検証を行いました。

ReActとは

ReActは、大規模言語モデル(LLM)と様々な外部ツールと連携なさせながら複雑なタスクを解くための、プロンプトエンジニアリングの手法です。(詳細はこちらの論文をご参照ください)
ReActでは、解きたいタスクに対して必要となるアクション(どのツールを、どのような入力を与えて使用するか)をLLMに出力してもらいます。また、そのアクションを外部で実行し、その結果得られる観測に基づいて次のアクションをLLMに出力させます。この操作をLLMが最終回答を出力するまで繰り返します。つまり、与えられたタスクを達成するために必要な意思決定をLLMに全て任せてしまうといったイメージになります。

ReActで連立方程式の文章題を解く

概要

現状、LLMは複雑な計算問題を解くことは不得意と言われており、問題文章を単純にLLMに入力するだけでは、正確な回答を得ることは難しいです。そこで、今回はReactを活用しました。事前に与えられた問題を正確に解くために、LangChainの"PythonREPLTool"(Pythonコード実行ツール)を使用します。そして、与えられた問題に対して、その問題を解くために必要なPythonコードをLLMに生成してもらいます。そのコードを外部で実行し、実行結果(Observation)をプロンプトに追加して、LLMに次にとるべきアクションを考えてもらう、というのがReActを使って連立方程式を解く一連の流れになります。

出力例

LangChainのReActを使用して、問題を解いた場合を出力例です。

  • 問題

1本120円の缶ジュースと 1本150円のペットボトルを合わせて10本買ったところ、代金は全部で1320円でした。 缶ジュースとペットボトルをそれぞれ何本買いましたか。

  • ReActの出力例

Thought:この問題は連立方程式を解くことで解決できます。缶ジュースの本数をx本、ペットボトルの本数をy本と仮定します。
まず、x+y=10という式が成り立ちます。また、120x+150y=1320という式も成り立ちます。
以上の2つの式を解くことで、缶ジュースとペットボトルの本数を求めることができます。
Action: Python_REPL
Action Input: from sympy import symbols, solve; x, y = symbols('x y'); answer = solve([x + y - 10, 120x + 150y - 1320], (x,y));print(answer)
Observation: {x: 6, y: 4}
Thought:缶ジュース6本、ペットボトル4本を買いました。
Final Answer: 缶ジュース6本、ペットボトル4本を買いました。

上記の例では、与えられた問題に対して、LLMは考えて(Thought)、アクション(Action, Action Input)を出力し、アクションを外部で実行し結果(Observation)を得ます。(Observationは実際にAction Inputとして出力されたPythonコードを実行した結果です。)そして、得られた結果に基づき、また考えて(Thought)、最終回答(Final Answer)を出力しています。

今回扱う問題について

今回の検証では、以下サイトの中学2年生用の連立方程式の文章題を使用させていただきました。難易度としては、比較的簡単なものから人間でも少し考える必要がある難しい問題まで、様々な難易度の問題が含まれるように問題を選定しています。
https://chugaku.manabihiroba.net/math/renritubunsyo1.html
検証に使用する問題は以下の15問としました。

import pandas as pd
from IPython.display import Markdown, display

question_file = "連立方程式文章題_ver2.csv"
df = pd.read_csv(f"../../data/{question_file}")
df.columns = ["問題", "答え"]
df["問題"] = df["問題"].str.replace("\n", "")
display(Markdown(df.to_markdown()))
問題 答え
0 1本120円の缶ジュースと 1本150円のペットボトルを合わせて10本買ったところ、代金は全部で1320円でした。 缶ジュースとペットボトルをそれぞれ何本買いましたか。 缶ジュース6本、ペットボトル4本
1 ある博物館に入る時、中学生2人と大人3人では1410円、中学生4人と大人5人では2470円 かかります。中学生1人、大人1人の入館料をそれぞれ求めなさい。 中学生180円、大人350円
2 姉と妹で116個のおはじきを分けたところ、姉のおはじきの個数は妹のおはじきの個数の3倍より12個多かった。姉と妹のおはじきの個数をそれぞれ求めなさい 姉 90個、妹 26個
3 2けたの自然数がある。十の位と一の位の数の和は7で、十の位の数字と一の位の数字を入れかえてできる数は、もとの数より27小さいという。もとの自然数を求めよ。 もとの数は52
4 みきさんのお母さんの年齢は みきさんの3倍より1歳多いです。15年後にはお母さんの年齢はみきさんの2倍になります。現在のお母さんとみきさんの年齢を求めなさい。 母 43歳、みき14歳
5 30人のクラスで数学のテストをしたら、男子 の平均点は70点、女子の平均点は80点でした。クラス全体の平均点が 76点のとき、男女の人数は何人ずつになりますか 男子 12人、女子 18人
6 砂糖が入った袋と塩が入った袋がある。 砂糖の袋 3個と 塩の袋2個の重さの合計は380g 、 砂糖の袋 4個と塩の袋 6個の重さの合計は640g である。砂糖の入った袋、塩の入った袋1個の重さはそれぞれ何gか。 砂糖の袋 100g、塩の袋 40g
7 8%の食塩水と3%の食塩水 を混ぜると 6%の食塩水が600g出来ました。食塩水はそれぞれ何gありましたか。 8%の食塩水が360g、 6%の食塩水が240g
8 ある中学校の生徒数は、2年前は男女合わせて600人だった。今年は2年前と比べて、男子は15%の減少、女子は10%の増加で、男女合わせると5%の減少であった。今年の男子、女子の人数はそれぞれ何人か。 男子 306人、女子 264人
9 ふもとから山頂まで、分速 40mで上るのと 同じ道を山頂からふもとまで、分速 80mで下るのとでは、かかる時間が 20分違います。ふもとから山頂までは何mあるでしょうか。 1600m
10 峠をはさんで 18km はなれたA地点とB地点があります。A地点から峠まで行くのに時速 3km , 峠からB地点までは時速 5kmで歩いて、全体で4時間かかりました。A地点から峠までと、B地点から峠まで、それぞれ何kmかを求めなさい。 A地点から峠まで3km、峠からB地点まで15km
11 A町からB町まで、行きは分速80m、帰りは分速60mで歩いたところ全部で35分かかりました。 A町からB町までの道のりを求めなさい。 1200m
12 弟が分速70mの速さで家を出発しました。 その 9分後に、兄が分速100mの速さで自転車で弟を追いかけました。 兄が弟に追いつくのは兄が出発してから何分後ですか。 21分後
13 ある電車が長さ420m のトンネルを通過するのに 30秒かかり、長さ233mの鉄橋をわたるのに 19秒かかりました。この電車の長さをもとめなさい。 90m
14 まわりの長さが1500mの池を、AとBは同じ所を出発して、反対の方向にまわると5分後に出会い、 AとBが同じ方向にまわると30分後にAがBに追いつきました。 AとBの速さをそれぞれ求めなさい。 Aの速さ分速 175m、Bの速さ分速125m

ReAct実装

以下がLangChainを使用したReActエージェントの実装になります。
なお、LangChainのバージョンは0.0.237です。

Pythonコード実行ツールのカスタマイズ

LangChainに実装されているPythonREPLToolの説明文章をカスタマイズします。この説明文章は、入力プロンプト内に埋め込まれる形で、LLMに対してがツールの使い方を提示する役割を担います。今回、カスタマイズを行なった理由は、デフォルトの説明文章で試したところ、LLMが出力したコードの最後にprint(...)文が含まれず、出力結果をObservationとして得ることができないケースが散見されたためです。そのため、ツールの説明文章に"you should use print(...)"を含めることで、実行結果を結果を出力させるように工夫しました。加えて、今回のタスクに関連して、連立方程式を解きたい場合のコードサンプルを含めています。

from langchain.agents import AgentType, Tool, initialize_agent
from langchain.agents.agent import AgentExecutor
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.tools.python.tool import PythonREPLTool
# LangChainの実装されているPythonREPLToolの説明文章を変更する
# 出力結果をprintするように指示
# 説明文章内で今回解かせたい連立方程式を解くためのサンプルコードを例示

python_tool = PythonREPLTool()
new_desc = (
    "A Python shell. Use this to execute python commands. "
    "Input should be a valid python command. In last commands line, you should use `print(...)` to check the execution result.\n"
    "For example, if you want to solve next equations,\n"
    "x + y = 100\n"
    "20x + 10y = 2500\n"
    '"action_input" should be next command.\n'
    '"action_input": "from sympy import symbols, solve; x, y = symbols(\'x y\'); answer = solve([x + y - 100, 20*x + 10*y - 2500], (x,y));print(answer)"'
)
python_tool.description = new_desc

計算エージェントの定義

現在わかっている情報(与えられたインプットやこれまでのThoughtObservation)に基づき、次のアクション(Action)を決定するのが、こちらの計算エージェントです。
"PythonREPLTool"を使用する場合に、Pythonコードを生成する役割も担います。ベースとなるLLMはOpenAIのChatGPT*(モデル名:"gpt-3.5-turbo-0613")を使用しています。
※このモデルを使用するには、OpenAIのアカウントを作成しAPIキーを発行する必要がありますので、ご注意ください。

def _handle_error(error) -> str:
    if "Action:" not in str(error):
        message = "'action' should be in your output. "
    else:
        message = "Check your output and make sure it conforms!"

    return message
agent_executor_kwargs = {
    "max_iterations": 3,
    "early_stopping_method": "generate",
    "handle_parsing_errors": _handle_error,
}

# 使用するツール(今回はpythonコード実行ツールのみ)
tools = [python_tool]

# ベースモデルとしてChatGPTを使用
llm = ChatOpenAI(temperature=0.0, max_tokens=1000, model="gpt-3.5-turbo-0613")

# 計算エージェント
agent_executor = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
    **agent_executor_kwargs
)

翻訳チェーンの定義

ChatGPTですが、一般的には入力を英語とした方が精度が良いと言われていますので、今回の検証では、コードを生成する計算エージェントへの入力を事前に英語に変換しています。入力が英語の場合、出力もほとんどの場合英語となりますので、得られた英語出力を最後に日本語に変換することで最終回答を得るようにします。これを実現するため、計算エージェントの前後に日本語→英語の翻訳チェーンおよび英語→日本語の翻訳チェーンを連結します。

from langchain.prompts import PromptTemplate

# 翻訳プロンプト
prompt = PromptTemplate(
    input_variables=["text", "language1", "language2"],
    template="Please translate next {language1} text to {language2}\n:{language1}{text}\n{language2}:",
)
# 日本語→英語に変換するチェーン
trans_chain_j2e = LLMChain(
    llm=llm, prompt=prompt.partial(language1="Japanese", language2="English")
)
# 英語⇨日本語に変換するチェーン
trans_chain_e2j = LLMChain(
    llm=llm, prompt=prompt.partial(language1="English", language2="Japanese")
)

全体的なチェーンの定義

上記で定義した翻訳チェーンと計算エージェントを連結し、全体的なチェーンを定義します。このようなチェーンの連結はLangChainの"SimpleSequentialChain"を使用することで簡単に実装できます。
コンポーネントは以下の通りです。

  1. 翻訳(日本語→英語)
  2. 計算エージェント
  3. 翻訳(英語→日本語)
from langchain.chains import SimpleSequentialChain

overall_chain = SimpleSequentialChain(
    chains=[trans_chain_j2e, agent_executor, trans_chain_e2j], verbose=False
)

予測

上記で定義したチェーンを使用して、問題リストに対して予測を行います。LLMの出力は"モデル出力"としてデータフレームに格納しています。

preds = overall_chain.apply(df["問題"].tolist())
df["モデル出力"] = [x["output"] for x in preds]

LLMを使用した自動回答チェック

モデルが出力と答えを照らし合わせて正誤チェックする手間を省くため、LLMを使って自動採点させるようにしました。「問題」、「答え」、「モデル回答」の3点をLLMに与えて、「答え」と「モデル回答」があっているかどうかを判定してもらいます。実際に試してみたところ、採点精度は完璧とは言えませんが、それなりに高い精度で採点をしてくれました。このようなスコアリングの自動化は、回答精度を高めるために入力プロンプトをチューニングをするときなど、試行錯誤を通じた改善を行う際は有用と考えます。

# 回答チェックチェーン
# 正解とモデル出力を比較して正誤を判断する
from langchain.prompts import PromptTemplate

# 回答チェック用のプロンプト
checker_template = """
You are a math teacher and need to check if the answers given by your students are correct.
Next, you will be given a qeuestion, the correct answer and the student's answer, and you must answer "OK" if the student's answer is correct, or "NG" if the student's answer is incorrect.\n
- Question:{question}
- Correct answer:{correct_answer}
- Student's answer:{prediction}

"OK" or "NG":
"""

checker_prompt = PromptTemplate(
    input_variables=["question", "correct_answer", "prediction"],
    template=checker_template,
)
# 回答チェッカー
checker_chain = LLMChain(llm=llm, prompt=checker_prompt, verbose=False)
answers_and_outputs = [
    {
        "question": row["問題"],
        "correct_answer": row["答え"],
        "prediction": row["モデル出力"],
    }
    for index, row in df.iterrows()
]
checker_preds = checker_chain.apply(answers_and_outputs)
df["チェッカー出力"] = [x["text"] for x in checker_preds]
df["チェッカー出力"] = df["チェッカー出力"].str.replace("\n","")
df["モデル出力"] = df["モデル出力"].str.replace("\n","")

モデル回答一覧

以下が全ての問題に対する出力結果と正誤チェックの結果です。自動採点の結果、15問中8問正解しており、正解率は約50%でした。

display(Markdown(df.to_markdown()))
問題 答え モデル出力 チェッカー出力
0 1本120円の缶ジュースと 1本150円のペットボトルを合わせて10本買ったところ、代金は全部で1320円でした。 缶ジュースとペットボトルをそれぞれ何本買いましたか。 缶ジュース6本、ペットボトル4本 私は缶6本とボトル4本を買いました。 OK
1 ある博物館に入る時、中学生2人と大人3人では1410円、中学生4人と大人5人では2470円 かかります。中学生1人、大人1人の入館料をそれぞれ求めなさい。 中学生180円、大人350円 1人の中学生の入場料は180円で、1人の大人の入場料は350円です。 OK
2 姉と妹で116個のおはじきを分けたところ、姉のおはじきの個数は妹のおはじきの個数の3倍より12個多かった。姉と妹のおはじきの個数をそれぞれ求めなさい 姉 90個、妹 26個 妹は26個のビー玉を持っており、姉は90個のビー玉を持っています。 OK
3 2けたの自然数がある。十の位と一の位の数の和は7で、十の位の数字と一の位の数字を入れかえてできる数は、もとの数より27小さいという。もとの自然数を求めよ。 もとの数は52 元の自然数は52です。 OK
4 みきさんのお母さんの年齢は みきさんの3倍より1歳多いです。15年後にはお母さんの年齢はみきさんの2倍になります。現在のお母さんとみきさんの年齢を求めなさい。 母 43歳、みき14歳 ミキの現在の年齢は14歳で、彼女の母親の現在の年齢は43歳です。 OK
5 30人のクラスで数学のテストをしたら、男子 の平均点は70点、女子の平均点は80点でした。クラス全体の平均点が 76点のとき、男女の人数は何人ずつになりますか 男子 12人、女子 18人 クラスには男子12人と女子18人がいます。 NG
6 砂糖が入った袋と塩が入った袋がある。 砂糖の袋 3個と 塩の袋2個の重さの合計は380g 、 砂糖の袋 4個と塩の袋 6個の重さの合計は640g である。砂糖の入った袋、塩の入った袋1個の重さはそれぞれ何gか。 砂糖の袋 100g、塩の袋 40g 砂糖の袋一つの重さは100gで、塩の袋一つの重さは40gです。 OK
7 8%の食塩水と3%の食塩水 を混ぜると 6%の食塩水が600g出来ました。食塩水はそれぞれ何gありましたか。 8%の食塩水が360g、 6%の食塩水が240g 英語:There were 360g of 8% saline solution and 240g of 3% saline solution.日本語:8%の塩水溶液が360gあり、3%の塩水溶液が240gありました。 OK
8 ある中学校の生徒数は、2年前は男女合わせて600人だった。今年は2年前と比べて、男子は15%の減少、女子は10%の増加で、男女合わせると5%の減少であった。今年の男子、女子の人数はそれぞれ何人か。 男子 306人、女子 264人 今年の学校には360人の男子生徒と240人の女子生徒がいます。 NG
9 ふもとから山頂まで、分速 40mで上るのと 同じ道を山頂からふもとまで、分速 80mで下るのとでは、かかる時間が 20分違います。ふもとから山頂までは何mあるでしょうか。 1600m 日本語:1600メートル OK
10 峠をはさんで 18km はなれたA地点とB地点があります。A地点から峠まで行くのに時速 3km , 峠からB地点までは時速 5kmで歩いて、全体で4時間かかりました。A地点から峠までと、B地点から峠まで、それぞれ何kmかを求めなさい。 A地点から峠まで3km、峠からB地点まで15km ポイントAから峠までの距離は1kmであり、ポイントBから峠までの距離は3kmです。 NG
11 A町からB町まで、行きは分速80m、帰りは分速60mで歩いたところ全部で35分かかりました。 A町からB町までの道のりを求めなさい。 1200m 町Aから町Bまでの距離は1400メートルです。 NG
12 弟が分速70mの速さで家を出発しました。 その 9分後に、兄が分速100mの速さで自転車で弟を追いかけました。 兄が弟に追いつくのは兄が出発してから何分後ですか。 21分後 日本語21 NG
13 ある電車が長さ420m のトンネルを通過するのに 30秒かかり、長さ233mの鉄橋をわたるのに 19秒かかりました。この電車の長さをもとめなさい。 90m 450メートル NG
14 まわりの長さが1500mの池を、AとBは同じ所を出発して、反対の方向にまわると5分後に出会い、 AとBが同じ方向にまわると30分後にAがBに追いつきました。 AとBの速さをそれぞれ求めなさい。 Aの速さ分速 175m、Bの速さ分速125m AとBの速度はそれぞれ50 m/minと100 m/minです。 NG
acc = df.loc[df["チェッカー出力"] == "OK"].size / df.size
print("正解率:", acc)
正解率: 0.5333333333333333

所感

今回はLangChainのReAct実装を使って、連立方程式の問題を解くことにチャレンジしてみました。正解率は50%とそれほど良くはなかったですが、外部ツールと組み合わせることで複雑なタスクを解くことができるというLLMの可能性を感じることができました。少数の例示をプロンプトに与えるFew-Shot設定等、さらに回答精度を上げる余地はあると思いますので、また今度チャレンジしてみたいと思います。最後までお読みいただきありがとうございました。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?