6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ChatGPTのリクエストを並列化して時短しよう!

Last updated at Posted at 2023-07-30

はじめに

最近、OpenAIのAPIを使って開発を行うことが多いので、今まで使ってきて便利だったものをまとめようと思います。!また、役立つ記事も添付させていただいてます。ぜひ読んでいってください:eyes:

こんな方におすすめ

  • テストケースを1件ずつ実行している方
  • トークン数やリクエスト数の上限を超えて怒られた経験のある方

例として、簡単なQ&AをLangChain実装します。

1. 初期設定 & 準備

以下のようなテーブル形式を想定します。質問に対して各データを参照して答えるプログラムを作ります!

下のプログラムでデータの読み込み、モデル選択、max_tokenの初期設定を行います。

import pandas as pd
import openai
from tqdm.notebook import tqdm
import os
import os.path as osp
import re
import asyncio
import time
import tiktoken
from langchain import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# initialize
tqdm.pandas()
DATADIR = '../data/'
openai.api_key = os.environ['OPENAI_KEY']
model = 'gpt-3.5-turbo-0613'

# dataload
df = pd.read_csv(osp.join(DATADIR, 'test.csv'))
document = df['data'].tolist() # 並列化用にリスト化
question = df['question'].tolist() # 並列化用にリスト化

# model token
token_keys = ['gpt-3.5-turbo', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-16k-0613', 'gpt-4', 'gpt-4-0613']
token_once = [4000, 4000, 16000, 16000, 16000, 8000, 8000] # 一度のリクエストで送れるtoken数
token_min = [90000, 90000, 180000, 180000, 40000, 40000] # 1分間あたりに送れるtoken数

once_token = {key: value for key, value in zip(token_keys, token_once)}
min_token = {key: value for key, value in zip(token_keys, token_min)}

1分間あたりに送れるトークン数はデフォルトだともっと少ないと思うので、下の記事を参照して設定してください!他に役立つ記事添付します。

token / min、request / min(1分あたりに送れるトークン数およびリクエスト数)

token / request(一度のリクエストで送れるトークン数)

モデル一覧

2. トークン数を事前にカウント

tiktokenというライブラリを用いてトークン数をカウントします。
なお、今回dataカラムのトークン数がquestionカラムのトークン数より多い(data token nums >> question token nums)と仮定し、トークン数が上限を超える場合はdataカラムのテキストを一部削除するようにプログラムを作成します。

encoding = tiktoken.encoding_for_model(model) # encoder定義

# 各レコードのトークン数をカウント
df['token'] = df.apply(lambda x: len(encoding.encode(x['data']))+len(encoding.encode(x['question'])), axis=1)
max_token = df['token'].max()

# 一度に送れるトークン数がオーバーしている場合
if max_token > once_token[model]:
    
    # 何トークン減らすか(-1で初期化)
    df['diff'] = -1
    df['diff'] = df['token'].apply(lambda x: -1 if (x-max_token) <= 0 else (x-max_token))
    
    # 一度エンコードして一部削除したあとデコード
    df['encode_data'] = df['data'].apply(lambda x:encoding.encode(x))
    df['re_data'] = df.apply(lambda x:encoding.decode(x['encode_data'][:-x['diff']]), axis=1)
    
    # 再度初期化
    data = df['re_data'].tolist()
    max_token = once_token[model]

# 一度に送るリクエスト数を計算
step = 1
if 'gpt-3.5' in model:
    step = int(min_token[model]/max_token) if min_token[model]/max_token < 3500 else 3490
elif 'gpt-4' in model:
    step = int(min_token[model]/max_token) if min_token[model]/max_token < 200 else 190

先程と同様、1分あたりに送れるリクエスト数はデフォルトだともっと少ないので下の記事を参照して設定してください!

request / min

tiktokenについて下の記事が分かりやすいです

3. 並列処理

LangChainを用いることで簡単に実装できます。
{}を用いることで変数を代入できます。(get_template関数の{data}, {question}部分に該当)

def get_template():
    """テンプレート関数"""

    # machine
    machine_template = """
    Briefly answer the following questions from the information given.

    {data}

    NOTES : 
    * Response must be Japanese.
    * Please return null for non-existent information.
    * If you want to output the table, return only markdown format.
    """
    # human
    human_template = 'Question:{question}'
    # setting
    system_message_prompt = SystemMessagePromptTemplate.from_template(machine_template)
    human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
    chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])    

    return chat_prompt

async def get_answer(chain, _data, _question):
    """質問応答関数"""
    try: 
        resp = await asyncio.wait_for(chain.arun({'data': _data, 'question':_question}))
        resp = re.sub(r'\s', '', resp)
    except Exception as e: # 念のための例外処理
        print(e)
        resp = None
        
    return resp

async def generate_concurrently(data_list, question_list, model):
    """非同期関数"""
    # モデル定義
    chat = ChatOpenAI(temperature=0, model_name=model, request_timeout=240)
    
    # プロンプト設定
    prompt = which_template(model, want_key=want_key)
    chain = LLMChain(llm=chat, prompt=prompt)

    # 並列処理
    tasks = [get_answer(_data, _question, want_key) for _data, _question in zip(data_list, question_list)]
    return await asyncio.gather(*tasks)

# ここで関数を実行させている
gpt_generate = []
for i in tqdm(range(0, len(df), step)): # 一度に送るリクエスト数
    start = time.time()
    tmp_list = await generate_concurrently(data[i:i+step], question[i:i+step], model)
    end = time.time()
    diff = end - start
    gpt_generate.extend(tmp_list)
    # もし1分未満で終わっていたら上限を超えてしまうので待機する
    if diff < 60:
        time.sleep(60-int(diff))

# 最後にdataframeに格納する
df['answer'] = gpt_generate
# df.to_csv(osp.join(DATADIR, 'output.csv'), index=False)

参考にした記事を添付します。それぞれサンプルコードを一度試すことをおすすめします!

並列処理については下の記事へ

プロンプトについては下の記事へ

プロンプトエンジニアリングについては下の記事へ

4. おわりに

今回はChatGPTの並列処理について記事にしました。おすすめの記事や技術がある方は、ぜひコメントしていただけると嬉しいです!

6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?