はじめに
生成AIで、いろんな種類のプロンプトとか質問を試したい時、非同期で呼んで並行処理したいときってありますよね。本記事ではIBM watsonx.ai APIをPython使って非同期で呼ぶ方法を紹介します。
他のLLM
OpenAIでは、AsyncOpenAI
というものがあるみたいなので、それを呼べばいいみたいです。
公式SDKにない!?
ドキュメント見てたんですが、2023/12/24現在、まだなさそうです😢
1. SDKの更新を待つ
watsonx.aiにはTech Previewがあって、そっちのSDKのドキュメントを見ると、非同期で呼ぶ方法がありました。なので、しばらく待てばwatsonx.aiのSDKでも非同期で呼ぶ方法が実装されると思います。
2. LangChainを使う
LangChainを使うなら、Async APIが準備されているので、そちらを使えば良さそうです。
Langchain使うとき、どうするんやろと調べてたら、watsonx.aiもしれっと追加されてることに気づきました😲 知らなんだ。
3. asyncioとaiohttpを使う
非同期でAPI callするパッケージを利用して呼べます。
以下にコードを示します。色々、微妙なところあるかもですが、検証用と思って見てもらったらいいかと思います。
import asyncio
import aiohttp
import requests
class watsonx:
def __init__(self):
self.access_token = self.get_access_token()
@staticmethod
def get_access_token():
url = "https://iam.cloud.ibm.com/identity/token"
headers = {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
data = {
'grant_type': 'urn:ibm:params:oauth:grant-type:apikey',
'apikey': 'xxx'
}
response = requests.post(url, headers=headers, data=data)
response_data = response.json()
access_token = response_data.get('access_token')
return access_token
@staticmethod
async def generate(session, prompt):
url = 'https://us-south.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=2023-05-29'
json = {
"model_id": "meta-llama/llama-2-70b-chat",
"input": prompt,
"parameters": {
"decoding_method": "greedy",
"max_new_tokens": 1500,
"min_new_tokens": 1,
"stop_sequences": []
},
"project_id": "xxxxxx"
}
try:
async with session.post(url=url, json=json) as response:
response.raise_for_status() # HTTPエラーをチェック
return await response.json()
except aiohttp.ClientError as e:
print(f"HTTPクライアントエラー: {e}")
except Exception as e:
print(f"予期せぬエラー: {e}")
async def generate_answers(self, prompts):
headers={
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': 'Bearer {}'.format(self.access_token)
}
async with aiohttp.ClientSession(headers=headers) as session:
tasks = [self.generate(session, prompt) for prompt in prompts]
responses = await asyncio.gather(*tasks, return_exceptions=True)
results = []
for response in responses:
if isinstance(response, Exception):
print(f"エラー: {response}")
else:
if response is None:
results.append('')
else:
result = response['results'][0]['generated_text']
results.append(result)
return results
コンストラクタでget_access_token()
を実行してapi keyからaccess tokenを生成してます。
access tokenは期限があるので、そこらはちゃんとハンドリング必要です。が、上記コードではサボってます。
generate()
で非同期でAPIを呼んでいます。
generate_answers()
で、並行して呼び出してます。
少し細かいですが、generate()
の方でエラーをprintだけして握りつぶしてるので、generate_answers()
の方で、isinstance(response, Exception)
のチェックしてるの意味ない気がします。エラーハンドリングしたいお気持ち表明で残してます。
asyncio.gather(*tasks, return_exceptions=True)
のところで、結果が全部かえってくるのを待ってます。例外が発生しても我慢して全部の処理結果を受け取ってます。
response
がNone
(generate()
がエラーのときprintだけしてreturnしてない)だったら、空文字入れてます。
こんな感じで、非同期で呼べます。Pythoの非同期処理、TypeScriptと似てるなーと思った次第です。
おわりに
ひとまず、やりたいことができたのでよかったです。後から知ったLangChain使う方法もそのうち試してみたいなと思います。