はじめに
LLMを学習していると、たいていはフレームワークに用意されているAPIを使用してLLMと接続します。
ところが、とあるフレームワークでAPIが用意されていなくて自分で作る必要があって、特にLangchainのカスタムChatModelを作成できないかと試行錯誤したという経緯がありました。
へ~こんな風に内部では動いているのか~と感じてもらったり、Custom ChatModelを作成いる際の参考にしてもらえれば幸いです。
Ollama APIのしくみ
まずは手始めにローカルPCでollamaを使用して、API Callしてみたいと思います。
今回は既存のものは使用せずにurlを呼び出しています。(既存のAPIがないことを前提にしているため)
前提としてollamaにはllama3.2をいれています。
Ollamaでは大きく分けるとgenerateとchatの二つがあります。
-
generate
単発の回答。プロンプトを投げて、その回答を受け取るというものです。
localhost:11434/api/generate が基本的なURLです -
chat
対話形式。これまでのメッセージリストを投げて、最新の回答を受け取るというものです。
localhost:11434/api/chat が基本的なURLです
Ollama API -- generate--
まずはサンプルコード
import json
import asyncio
import aiohttp
# generate は単発の質問
# promptで質問を送ります
# APIサーバーのURL
API_SERVER_URL = "http://localhost:11434/api/generate"
# 非同期でメッセージを送信する関数
async def send_message_async(prompt):
headers = {
'Content-Type': 'application/json' # リクエストのヘッダー
}
json_data = {
"model": "llama3.2", # 使用するモデル
"prompt": prompt, # ユーザーからのプロンプト内容
"stream": True # ストリーミングを有効にする
}
# 非同期でHTTPセッションを作成
async with aiohttp.ClientSession() as session:
# 非同期でPOSTリクエストを送信
async with session.post(API_SERVER_URL, headers=headers, json=json_data) as response:
# レスポンスの内容を非同期で1行ずつ処理
async for line in response.content:
yield line.decode('utf-8') # レスポンスの行をデコードして返す
# 非同期でレスポンスをフォーマットする関数
async def format_response_async(prompt):
# 非同期でメッセージを送信し、レスポンスを受け取る
async for response_text in send_message_async(prompt):
response_json = json.loads(response_text) # レスポンスをJSONとして読み込む
# レスポンスの内容を1文字ずつ出力
for char in response_json['response']:
print(char, end='', flush=True) # 文字を出力してフラッシュする
if __name__ == "__main__":
# 送信するプロンプト
prompt = "アジャイルついて教えてください。"
# 非同期関数を実行
asyncio.run(format_response_async(prompt))
注目はjson_dataの"prompt"です。llmに送信するものは基本的にプロンプト一つとなります。
Ollama API -- chat--
次はchatの場合のサンプルコード
import json
import asyncio
import aiohttp
# chatは文脈を踏まえた対話
# messages listで複数のメッセージ(履歴)を送ることで対話を行います
# APIサーバーのURL
API_SERVER_URL = "http://localhost:11434/api/chat"
# 非同期でメッセージを送信する関数
async def send_message_async(message):
headers = {
'Content-Type': 'application/json' # リクエストのヘッダー
}
json_data = {
"model": "llama3.2", # 使用するモデル
"messages": [{
"role": "user", # メッセージの役割(ユーザー)
"content": message # ユーザーからのメッセージ内容
}]
}
# 非同期でHTTPセッションを作成
async with aiohttp.ClientSession() as session:
# 非同期でPOSTリクエストを送信
async with session.post(API_SERVER_URL, headers=headers, json=json_data) as response:
# レスポンスの内容を非同期で1行ずつ処理
async for line in response.content:
yield line.decode('utf-8') # レスポンスの行をデコードして返す
# 非同期でレスポンスをフォーマットする関数
async def format_response_async(message):
# 非同期でメッセージを送信し、レスポンスを受け取る
async for response_text in send_message_async(message):
response_json = json.loads(response_text) # レスポンスをJSONとして読み込む
# レスポンスの内容を1文字ずつ出力
for char in response_json['message']['content']:
print(char, end='', flush=True) # 文字を出力してフラッシュする
if __name__ == "__main__":
# 送信するメッセージ
message = "アジャイルついて教えてください。"
# 非同期関数を実行
asyncio.run(format_response_async(message))
json_dataのmessagesがリストになっていますね。ここにこれまでの対話をリストにして入れておくわけです。つまり、chatというのは、サーバー側はステートレスなので、状態はクライアント側で記録しているということです。
LangChainのCustom ChatModelを作ってみる
さて、ここからが本題です。
ollamaのchat APIを使用して、LangChainのChatModelを作ってみたいと思います。
from langchain_community.llms.ollama を見れば本家の実装パターンが分かりますが、これを追っていくのは大変です。ということで、さきほどのollama chatを利用してシンプルにしたものがこちらになります。
import asyncio
import aiohttp
import requests
import json
import time
from typing import Any, Dict, Iterator, List, Optional, AsyncIterator
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field
class OllamaChatModel(BaseChatModel):
"""A custom chat model that echoes the first `parrot_buffer_length` characters
of the input.
When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.
Example:
.. code-block:: python
model = OllamaChatModel(parrot_buffer_length=2, model="bird-brain-001")
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""
model_name: str = Field(alias="model")
end_point: str = Field(alias="end_point")
"""The name of the model"""
temperature: Optional[float] = None
max_tokens: Optional[int] = None
timeout: Optional[int] = None
stop: Optional[List[str]] = None
max_retries: int = 2
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Override the _generate method to implement the chat model logic.
This can be a call to an API, a call to a local model, or any other
implementation that generates a response to the input prompt.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
print("Generating")
# Replace this with actual logic to generate a response from a list of messages.
last_message = messages[-1]
message_contents = []
for msg in messages:
if isinstance(msg, HumanMessage):
message_contents.append({"role": "user", "content": msg.content})
elif isinstance(msg, AIMessage):
message_contents.append({"role": "assistant", "content": msg.content})
headers = {
'Content-Type': 'application/json'
}
json_data = {
"model": self.model_name,
"messages": message_contents,
"stream": False
}
response = requests.post(self.end_point, headers=headers, json=json_data)
messages = []
for line in response.text.splitlines():
response_json = json.loads(line)
messages.append(response_json['message']['content'])
combined_message = ''.join(messages)
meta_data = response.text.splitlines()[-1]
meta_json = json.loads(meta_data)
ct_input_tokens = meta_json.get("prompt_eval_count", 0)
ct_output_tokens = meta_json.get("eval_count", 0)
message = AIMessage(
content=combined_message,
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
},
usage_metadata={
"input_tokens": ct_input_tokens,
"output_tokens": ct_output_tokens,
"total_tokens": ct_input_tokens + ct_output_tokens,
},
)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model.
This method should be implemented if the model can generate output
in a streaming fashion. If the model does not support streaming,
do not implement it. In that case streaming requests will be automatically
handled by the _generate method.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
print("Streaming")
headers = {
'Content-Type': 'application/json' # リクエストのヘッダー
}
message_contents = [
{"role": "user", "content": msg.content} if isinstance(msg, HumanMessage) else {"role": "assistant", "content": msg.content}
for msg in messages
]
json_data = {
"model": self.model_name,
"messages": message_contents,
"stream": True
}
response = requests.post(self.end_point, headers=headers, json=json_data)
for line in response.iter_lines(decode_unicode=True):
response_json = json.loads(line)
content = response_json['message']['content']
time.sleep(0.01)
usage_metadata = UsageMetadata(
{
"input_tokens": len(content),
"output_tokens": 1,
"total_tokens": len(content) + 1,
}
)
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=content,
usage_metadata=usage_metadata)
)
if run_manager:
run_manager.on_llm_new_token(content, chunk=chunk)
yield chunk
# async def _astream(
# self,
# messages: List[BaseMessage],
# stop: Optional[List[str]] = None,
# run_manager: Optional[CallbackManagerForLLMRun] = None,
# **kwargs: Any,
# ) -> AsyncIterator[ChatGenerationChunk]:
# """Async stream the output of the model.
# This method should be implemented if the model can generate output
# in a streaming fashion. If the model does not support streaming,
# do not implement it. In that case streaming requests will be automatically
# handled by the _generate method.
# Args:
# messages: the prompt composed of a list of messages.
# stop: a list of strings on which the model should stop generating.
# If generation stops due to a stop token, the stop token itself
# SHOULD BE INCLUDED as part of the output. This is not enforced
# across models right now, but it's a good practice to follow since
# it makes it much easier to parse the output of the model
# downstream and understand why generation stopped.
# run_manager: A run manager with callbacks for the LLM.
# """
# print("Async Streaming")
# async def send_message_async():
# headers = {
# 'Content-Type': 'application/json' # リクエストのヘッダー
# }
# message_contents = [
# {"role": "user", "content": msg.content} if isinstance(msg, HumanMessage) else {"role": "assistant", "content": msg.content}
# for msg in messages
# ]
# json_data = {
# "model": self.model_name,
# "messages": message_contents,
# "stream": True
# }
# async with aiohttp.ClientSession() as session:
# async with session.post(self.end_point, headers=headers, json=json_data) as response:
# async for line in response.content:
# yield line.decode('utf-8')
# async for response_text in send_message_async():
# response_json = json.loads(response_text)
# # レスポンスの内容を1文字ずつ出力
# content = response_json['message']['content']
# usage_metadata = UsageMetadata(
# {
# "input_tokens": len(content),
# "output_tokens": 1,
# "total_tokens": len(content) + 1,
# }
# )
# chunk = ChatGenerationChunk(
# message=AIMessageChunk(content=content, usage_metadata=usage_metadata)
# )
# if run_manager:
# await run_manager.on_llm_new_token(content, chunk=chunk)
# yield chunk
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return "echoing-chat-model-advanced"
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters.
This information is used by the LangChain callback system, which
is used for tracing purposes make it possible to monitor LLMs.
"""
return {
"model_name": self.model_name,
"end_point": self.end_point,
}
async def main():
API_CHAT = "http://localhost:11434/api/chat"
model = OllamaChatModel(model="llama3.2", end_point=API_CHAT)
messages = [
HumanMessage(content="こんにちは"),
AIMessage(content="こんにちは!何が聞きたいですか?"),
HumanMessage(content="アジャイルについて300文字で教えてください"),
]
print("*"*10, "generateを呼び出す", "*"*10)
response = model.invoke(messages)
print(response.content, end='', flush=True)
print("*"*10, "streamを呼び出す", "*"*10)
for chunk in model.stream(messages):
print(chunk.content, end='', flush=True)
print("*"*10, "astreamを呼び出す", "*"*10)
async for chunk in model.astream(messages):
print(chunk.content, end='', flush=True)
if __name__ == "__main__":
asyncio.run(main())
参考としたのは以下のサイト。
これを参考にgenerate, stream, astreamの3パターンを作成したのがさきほどのサンプルコードとなります。
async def main():を見ていただけたいのですが、
- OllamaChatModelを作成
- 会話を入力
- generate, stream, astremをそれぞれ呼び出す
というフローとなっています。
LangChainのChatModelはBaseChatModelを継承して、必要なメソッドを自分で作ってねというものになっているので、それに応じて作成しました。
async def _astreamがコメントになって無効になっていますが、LangChainのサイトにも説明があるのですが、実装していない場合はstreamが使用されると書かれています。
なので、async def _astreamを有効にした場合は、これが使われますがstreamの実装だけで良いというのはありがたいです。
model.invokeだと _generate が使用され、model.stream, model.astreamだと _stream が使用されるようになっています。
LangChainのサイトだと、LLMの呼び出しが書かれていないので、そこは自分で書いてねとあって、その書き方教えてほしいんだけどと叫びたくなります。
なので、それを既存のAPIを使用せずにスクラッチで実際にやってみたのが今回の記事となります。
補足
こういった知らないことを学習するのに時間がかかるのは当然なのですが、今回はGithub Copilotを使用してサンプルコードの作成を手伝ってもらいました。自分で最終的な動作などを見るのは必要ですが、驚愕するほどCopilotは学習コスト削減になります。本当にAIの力を感じる今日この頃です💦