導入
LangGraphのサンプルを眺めていたら、Self-Discoverを実践するノートブックがありましたので、若干魔改造してウォークスルーしてみます。
Self-Discoverって?
手前みそですが、以下の記事を参照ください。
この記事ではSGLangを使ってシンプルに実現してみたものを、LangGraph使ったらどうなるかという内容となります。
※ 結論から言うと、結構冗長だなあという印象です。
実装はDatabricks on AWS上で行いました。
DBRは14.3ML、クラスタタイプはg5.xlarge(GPUクラスタ)です。
Step1. パッケージインストール
LangGraphのサンプルでは推論部分はOpenAIのサービスを利用していましたが、今回は(も)SGLangでローカルLLMを利用することで実施します。ただ、SGLangの固有機能は利用しないため、OpenAIのままでも、ollamaやvLLMなど好きな推論エンジンを使ってローカルLLMを利用するでも問題ないかと思います。
ノートブックを作成し、SGLangに必要なパッケージをインストール。
# torch, xformers
# pytorchのリポジトリから直接インストールする場合
# %pip install -U https://download.pytorch.org/whl/cu118/torch-2.1.2%2Bcu118-cp310-cp310-linux_x86_64.whl
# %pip install -U https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/torch-2.1.2+cu118-cp310-cp310-linux_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/xformers-0.0.23.post1+cu118-cp310-cp310-manylinux2014_x86_64.whl
# vLLM
# GithubからvLLMを直接インストールする場合
# %pip install https://github.com/vllm-project/vllm/releases/download/v0.3.2/vllm-0.3.2+cu118-cp310-cp310-manylinux1_x86_64.whl
%pip install /Volumes/training/llm/tmp/vllm-0.3.2+cu118-cp310-cp310-manylinux1_x86_64.whl
# sglang==0.1.12の場合、outlines>0.030だとエラーが出るためoutlinesのバージョンを固定
%pip install "outlines>=0.0.27,<=0.0.30"
%pip install "sglang[srt]>=0.1.12"
%pip install "triton>=2.2.0"
# LangGraphなど必要なパッケージをインストール
%pip install -U langchain langchain-openai langgraph langchainhub
dbutils.library.restartPython()
合わせてpytorchのmultiprocessing設定を変更。
import torch
torch.multiprocessing.set_start_method('spawn', force=True)
Step2. モデルのロード
ユーザ指示に対する生成と批評、両方を行うLLMをSGLangでロードします。
今回はAWQフォーマットで量子化されたQwen1.5 14Bを利用しました。
モデルは事前にダウンロードしたものを利用しています。
from sglang import function, system, user, assistant, gen, set_default_backend, Runtime
from sglang.lang.chat_template import (
get_chat_template,
register_chat_template,
ChatTemplate,
)
# チャット用のプロンプトテンプレート
register_chat_template(
ChatTemplate(
name="qwen",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
)
)
model_path = (
"/Volumes/training/llm/model_snapshots/models--Qwen--Qwen1.5-14B-Chat-AWQ"
)
runtime = Runtime(model_path, mem_fraction_static=0.8)
runtime.endpoint.chat_template = get_chat_template("qwen")
set_default_backend(runtime)
ここまではSGLang単体でのSelf-Discover実践とほとんど同じコードです。
Step3. プロンプトテンプレートの準備
SELECT/ADAPT/IMPLEMENT(サンプルノートブックではSTRUCTUREと表現されていました)のプロンプトテンプレートを定義します。
LangSmithのhubに公開されているSelf-Discover用のプロンプトテンプレートを取得して利用します。
まずはSELECT部分。取得と合わせて中身も確認します。
from langchain import hub
from langchain_core.prompts import PromptTemplate
select_prompt = hub.pull("hwchase17/self-discovery-select")
select_prompt.pretty_print()
Select several reasoning modules that are crucial to utilize in order to solve the given task:
All reasoning module descriptions:
{reasoning_modules}
Task: {task_description}
Select several modules are crucial for solving the task above:
reasoning_modules
の中からいくつかタスクの解決に決定的なものを選ぶようなプロンプトですね。
次にADAPT部分。
adapt_prompt = hub.pull("hwchase17/self-discovery-adapt")
adapt_prompt.pretty_print()
Rephrase and specify each reasoning module so that it better helps solving the task:
SELECTED module descriptions:
{selected_modules}
Task: {task_description}
Adapt each reasoning module description to better solve the task:
selected_modules
(SELECT部の選択結果が挿入されます)をタスク解決により適した形に変更する指示内容となっています。
次にIMPLEMENT(STRUCTURE)部分。
structured_prompt = hub.pull("hwchase17/self-discovery-structure")
structured_prompt.pretty_print()
Operationalize the reasoning modules into a step-by-step reasoning plan in JSON format:
Here's an example:
Example task:
If you follow these instructions, do you return to the starting point? Always face forward. Take 1 step backward. Take 9 steps left. Take 2 steps backward. Take 6 steps forward. Take 4 steps forward. Take 4 steps backward. Take 3 steps right.
Example reasoning structure:
{
"Position after instruction 1":
"Position after instruction 2":
"Position after instruction n":
"Is final position the same as starting position":
}
Adapted module description:
{adapted_modules}
Task: {task_description}
Implement a reasoning structure for solvers to follow step-by-step and arrive at correct answer.
Note: do NOT actually arrive at a conclusion in this pass. Your job is to generate a PLAN so that in the future you can fill it out and arrive at the correct conclusion for tasks like this
adapted_modules
(ADAPT部の結果が挿入されます)をタスク解決のためのプランとして構成し直す指示が出ています。
最後に、IMPLEMENT部の結果を使って実際の推論実行をするためのプロンプト。タスクを解決するように指示されています。
reasoning_prompt = hub.pull("hwchase17/self-discovery-reasoning")
reasoning_prompt.pretty_print()
Follow the step-by-step reasoning plan in JSON to correctly solve the task. Fill in the values following the keys by reasoning specifically about the task given. Do not simply rephrase the keys.
Reasoning Structure:
{reasoning_structure}
Task: {task_description}
Step4. グラフの構成
LangGraphを使ってSelf-Discoverを実行するためのグラフ構造を定義します。
グラフ状態クラスの定義
まずはグラフ内で使う状態クラスの定義。
シンプルに各ステップの結果を格納するようになっています。
from typing import TypedDict, Optional
class SelfDiscoverState(TypedDict):
reasoning_modules: str
task_description: str
selected_modules: Optional[str]
adapted_modules: Optional[str]
reasoning_structure: Optional[str]
answer: Optional[str]
グラフノードの定義
各ノードを定義します。
最初に、SGLangで推論するための関数を定義。
その上で、LangChainのLCELでその処理を利用するためにRunnableLambda
でラップします。
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
@function
def simple_question(s, question):
s += user(question)
s += assistant(gen("answer", max_tokens=1024))
def call_llm(input):
state = simple_question.run(
question=input.text,
temperature=0,
)
return state["answer"]
# SGLangの推論実行をRunnableLambdaでラップ。これでLCELのChain定義に利用できる
model = RunnableLambda(call_llm)
次に各ノード用の関数を定義。
SELECT部/ADAPT部/IMPLEMENT(STRUCT)部、そして推論(reasoning)のノードを定義しています。
各関数はStep3で作成したプロンプトテンプレートに則って推論するChainを作成・事項しているだけのシンプルな内容です。
def select(inputs):
select_chain = select_prompt | model | StrOutputParser()
return {"selected_modules": select_chain.invoke(inputs)}
def adapt(inputs):
adapt_chain = adapt_prompt | model | StrOutputParser()
return {"adapted_modules": adapt_chain.invoke(inputs)}
def structure(inputs):
structure_chain = structured_prompt | model | StrOutputParser()
return {"reasoning_structure": structure_chain.invoke(inputs)}
def reason(inputs):
reasoning_chain = reasoning_prompt | model | StrOutputParser()
return {"answer": reasoning_chain.invoke(inputs)}
グラフ構築
最後に、グラフ定義です。各ノードを直接につなげただけのシンプルなグラフになります。
特に条件分岐等はありません。
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional
graph = StateGraph(SelfDiscoverState)
graph.add_node("select", select)
graph.add_node("adapt", adapt)
graph.add_node("structure", structure)
graph.add_node("reason", reason)
graph.add_edge("select", "adapt")
graph.add_edge("adapt", "structure")
graph.add_edge("structure", "reason")
graph.add_edge("reason", END)
graph.set_entry_point("select")
app = graph.compile()
Step5. 推論の実行
それでは、構築したグラフを使ってタスクを解いてみます。
こちらでもやってみた、指定したSVG pathがどの図形を描画しているかを当てる問題です。
まずは推論モジュールとタスクを定義。
reasoning_modules = [
"1. How could I devise an experiment to help solve that problem?",
"2. Make a list of ideas for solving this problem, and apply them one by one to the problem to see if any progress can be made.",
# "3. How could I measure progress on this problem?",
"4. How can I simplify the problem so that it is easier to solve?",
"5. What are the key assumptions underlying this problem?",
"6. What are the potential risks and drawbacks of each solution?",
"7. What are the alternative perspectives or viewpoints on this problem?",
"8. What are the long-term implications of this problem and its solutions?",
"9. How can I break down this problem into smaller, more manageable parts?",
"10. Critical Thinking: This style involves analyzing the problem from different perspectives, questioning assumptions, and evaluating the evidence or information available. It focuses on logical reasoning, evidence-based decision-making, and identifying potential biases or flaws in thinking.",
"11. Try creative thinking, generate innovative and out-of-the-box ideas to solve the problem. Explore unconventional solutions, thinking beyond traditional boundaries, and encouraging imagination and originality.",
# "12. Seek input and collaboration from others to solve the problem. Emphasize teamwork, open communication, and leveraging the diverse perspectives and expertise of a group to come up with effective solutions.",
"13. Use systems thinking: Consider the problem as part of a larger system and understanding the interconnectedness of various elements. Focuses on identifying the underlying causes, feedback loops, and interdependencies that influence the problem, and developing holistic solutions that address the system as a whole.",
"14. Use Risk Analysis: Evaluate potential risks, uncertainties, and tradeoffs associated with different solutions or approaches to a problem. Emphasize assessing the potential consequences and likelihood of success or failure, and making informed decisions based on a balanced analysis of risks and benefits.",
# "15. Use Reflective Thinking: Step back from the problem, take the time for introspection and self-reflection. Examine personal biases, assumptions, and mental models that may influence problem-solving, and being open to learning from past experiences to improve future approaches.",
"16. What is the core issue or problem that needs to be addressed?",
"17. What are the underlying causes or factors contributing to the problem?",
"18. Are there any potential solutions or strategies that have been tried before? If yes, what were the outcomes and lessons learned?",
"19. What are the potential obstacles or challenges that might arise in solving this problem?",
"20. Are there any relevant data or information that can provide insights into the problem? If yes, what data sources are available, and how can they be analyzed?",
"21. Are there any stakeholders or individuals who are directly affected by the problem? What are their perspectives and needs?",
"22. What resources (financial, human, technological, etc.) are needed to tackle the problem effectively?",
"23. How can progress or success in solving the problem be measured or evaluated?",
"24. What indicators or metrics can be used?",
"25. Is the problem a technical or practical one that requires a specific expertise or skill set? Or is it more of a conceptual or theoretical problem?",
"26. Does the problem involve a physical constraint, such as limited resources, infrastructure, or space?",
"27. Is the problem related to human behavior, such as a social, cultural, or psychological issue?",
"28. Does the problem involve decision-making or planning, where choices need to be made under uncertainty or with competing objectives?",
"29. Is the problem an analytical one that requires data analysis, modeling, or optimization techniques?",
"30. Is the problem a design challenge that requires creative solutions and innovation?",
"31. Does the problem require addressing systemic or structural issues rather than just individual instances?",
"32. Is the problem time-sensitive or urgent, requiring immediate attention and action?",
"33. What kinds of solution typically are produced for this kind of problem specification?",
"34. Given the problem specification and the current best solution, have a guess about other possible solutions."
"35. Let’s imagine the current best solution is totally wrong, what other ways are there to think about the problem specification?"
"36. What is the best way to modify this current best solution, given what you know about these kinds of problem specification?"
"37. Ignoring the current best solution, create an entirely new solution to the problem."
# "38. Let’s think step by step."
"39. Let’s make a step by step plan and implement it with good notation and explanation.",
]
reasoning_modules_str = "\n".join(reasoning_modules)
task_example = """This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L
45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a:
(A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon(H) rectangle (I) sector (J) triangle"""
では、上記で設定した内容を使ってグラフを実行。
ret = app.invoke({"task_description": task_example, "reasoning_modules": reasoning_modules_str})
結果を出力。
各ノードの結果を表示しているので長いですが、最後に回答が出力されています。
from pprint import pprint
print("Selected Modules:")
pprint(ret["selected_modules"], width=150)
print()
print("Adapted Modules:")
pprint(ret["adapted_modules"], width=150)
print()
print("Reasoning Structure:")
pprint(ret["reasoning_structure"], width=150)
print()
print("Answer:")
pprint(ret["answer"], width=150)
Selected Modules:
('1. Understanding the geometric shapes: This is crucial to recognize the pattern and features of the given path element.\n'
'2. Analyzing the path data: The description of the path (M, L commands) indicates straight line segments, which suggests a polygon.\n'
'3. Breaking down the path: Divide the path into its individual line segments to understand the sequence of vertices.\n'
'4. Counting vertices and sides: Determine the number of vertices and sides to classify the shape.\n'
'5. Comparing with known shapes: Compare the described path with the characteristics of the given options (A to J) to identify the correct shape.\n'
"6. Geometric reasoning: Apply knowledge of geometry to deduce the shape based on the path's movement.\n"
'7. Pattern recognition: Recognize the sequence of vertices as forming a regular or irregular polygon.\n'
'8. Decision-making: Choose the shape that best fits the path data.\n'
'\n'
'These modules are essential for solving the task, as they involve understanding the shape, analyzing the path, and making a logical decision based '
'on the given information.\n')
Adapted Modules:
('1. Path interpretation: Begin by examining the SVG path data, which uses M (move to) and L (line to) commands. This helps to understand the '
'starting and ending points of each line segment.\n'
'2. Geometric decomposition: Break down the path into its constituent line segments, which represent the sides of the shape. This step is vital for '
'identifying the pattern and number of sides.\n'
'3. Vertex count analysis: Count the total number of vertices in the path, as this corresponds to the number of corners in the shape. This '
'information is crucial for distinguishing between polygons with different sides.\n'
'4. Polygon type determination: Based on the number of vertices, classify the shape as a regular or irregular polygon. Regular polygons have equal '
'sides and equal angles.\n'
'5. Shape identification: Compare the observed pattern and vertex count with the characteristics of the provided options (A to J). This involves '
'checking if the path data matches the properties of each shape.\n'
'6. Geometric reasoning: Apply geometric principles to deduce the shape, considering the sequence of line segments and whether they form a closed '
'or open figure.\n'
'7. Regularity check: Assess whether the path indicates a regular polygon, such as a hexagon or heptagon, or an irregular one like a kite or '
'pentagon.\n'
'8. Final decision: After analyzing all the reasoning steps, choose the shape (A to J) that best fits the description of the SVG path element.\n')
Reasoning Structure:
('{\n'
' "Step 1: Path interpretation":\n'
' {\n'
' "Initial point": "M 55.57,80.69",\n'
' "First line segment": "L 57.38,65.80",\n'
' "Second line segment": "M 57.38,65.80 L 48.90,57.46",\n'
' "Third line segment": "M 48.90,57.46 L 45.58,47.78",\n'
' "Fourth line segment": "M 45.58,47.78 L 53.25,36.07",\n'
' "Fifth line segment": "L 66.29,48.90",\n'
' "Sixth line segment": "L 78.69,61.09",\n'
' "Seventh line segment": "L 55.57,80.69 Z" (Note: The Z command closes the path)\n'
' },\n'
'\n'
' "Step 2: Geometric decomposition":\n'
' {\n'
' "Line segments": ["55.57,80.69 to 57.38,65.80", "57.38,65.80 to 48.90,57.46", "48.90,57.46 to 45.58,47.78", "45.58,47.78 to 53.25,36.07", '
'"53.25,36.07 to 66.29,48.90", "66.29,48.90 to 78.69,61.09", "78.69,61.09 to 55.57,80.69"]\n'
' },\n'
'\n'
' "Step 3: Vertex count analysis":\n'
' {\n'
' "Total vertices": 7\n'
' },\n'
'\n'
' "Step 4: Polygon type determination":\n'
' {\n'
' "Based on 7 vertices, it\'s a polygon with 7 sides"\n'
' },\n'
'\n'
' "Step 5: Shape identification":\n'
' {\n'
' "Comparing with options: Heptagon has 7 sides"\n'
' },\n'
'\n'
' "Step 6: Geometric reasoning":\n'
' {\n'
' "The path forms a closed shape with 7 sides, indicating a heptagon"\n'
' },\n'
'\n'
' "Step 7: Regularity check":\n'
' {\n'
' "The heptagon is not specified as regular, so it\'s an irregular heptagon"\n'
' },\n'
'\n'
' "Step 8: Final decision":\n'
' {\n'
' "Shape: (B) Heptagon"\n'
' }\n'
'}\n')
Answer:
('{\n'
' "Step 1: Path interpretation": {\n'
' "Initial point": "M 55.57,80.69",\n'
' "First line segment": "L 57.38,65.80",\n'
' "Second line segment": "M 57.38,65.80 L 48.90,57.46",\n'
' "Third line segment": "M 48.90,57.46 L 45.58,47.78",\n'
' "Fourth line segment": "M 45.58,47.78 L 53.25,36.07",\n'
' "Fifth line segment": "L 66.29,48.90",\n'
' "Sixth line segment": "L 78.69,61.09",\n'
' "Seventh line segment": "L 55.57,80.69 Z"\n'
' },\n'
'\n'
' "Step 2: Geometric decomposition": {\n'
' "Line segments": ["55.57,80.69 to 57.38,65.80", "57.38,65.80 to 48.90,57.46", "48.90,57.46 to 45.58,47.78", "45.58,47.78 to 53.25,36.07", '
'"53.25,36.07 to 66.29,48.90", "66.29,48.90 to 78.69,61.09", "78.69,61.09 to 55.57,80.69"]\n'
' },\n'
'\n'
' "Step 3: Vertex count analysis": {\n'
' "Total vertices": 7\n'
' },\n'
'\n'
' "Step 4: Polygon type determination": {\n'
' "Based on 7 vertices, it\'s a polygon with 7 sides"\n'
' },\n'
'\n'
' "Step 5: Shape identification": {\n'
' "Comparing with options: Heptagon has 7 sides"\n'
' },\n'
'\n'
' "Step 6: Geometric reasoning": {\n'
' "The path forms a closed shape with 7 sides, indicating a heptagon"\n'
' },\n'
'\n'
' "Step 7: Regularity check": {\n'
' "The heptagon is not specified as regular, so it\'s an irregular heptagon"\n'
' },\n'
'\n'
' "Step 8: Final decision": {\n'
' "Shape: (B) Heptagon"\n'
' }\n'
'}\n')
ただし答えはHeptagoということで、正解です。
ただ、Reasoning Structureの時点で回答を作っていたので、Self-Discoverの挙動としてはちょっとイマイチ。使うモデルやプロンプトチューニングで改善するとは思います。
まとめ
Self-DiscoverをLangGraphでもやってみました。
最初にも記載しましたが、正直冗長になりすぎるのでLangGraphで構築するのはあまりオススメできない気はします。
LCELのChainとして作成するか、SGLangのような仕組でもっとシンプルに実装するほうが個人的にはよいかなあと。
とはいえ、いろんなパターンを知っておくのはよいですね。LangGraph楽しいし。
他にもサンプルノートブックなど、掘っていきたいと思います。