Summary
- BERTopicはトピックモデルの実装としてとても使いやすく,OpenAIのAPIとも連携することができるが,APIの時間当たりの利用制限のためそのままでは動かない
- が,BERTopicのコードの一部を改修するだけで動くようになる
トピックモデル BERTopic
トピックモデルは自然言語処理のタスクの一つで,テキストデータから「トピック」を抽出します.トピック$=$日本語では「主題」ですが,トピックモデルは入力のテキストから主題を検出するものであるとも言えます.
従来のトピックモデルには,
- LDA (Latent Dirichlet Allocation)
- LSI (Latent Semantic Indexing)
- PLSI (Probabilistic Latent Semantic Indexing)
といった手法がありますが,Transformerが出現して以降のメジャーな手法としてBERTopicがあります.
トピックモデルでは,入力したテキストを数値ベクトルに変換する必要がありますが,BERTopicではその工程にBERT(Sentence-BERT)使用しています.
動かし方も簡単で,以下のようにscikit-learnの文法で簡単にモデルを構築することができます.
from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups
docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
topic_model = BERTopic()
topics, probs = topic_model.fit_transform(docs)
また,Embeddingに使用するモデルはSentence-BERT以外にも色々選択することができ,OpenAIのGPTシリーズをAPIで使用することも可能です.
Input Code
client = openai.OpenAI(api_key="API-KEY")
model = "gpt-4"
tokenizer = tiktoken.encoding_for_model(model)
representation_model = OpenAI(
client,
model=model,
chat=True,
delay_in_seconds=60.0,
doc_length=7500,
tokenizer=tokenizer
)
topic_model = BERTopic(representation_model=representation_model, verbose=True)
topics, probs = topic_model.fit_transform(docs)
後述する理由により,tiktoken
のTokenizerを渡しています.
またOpenAIのAPI利用制限は1分当たりで決められているので,delay_in_seconds=60.0
としています.
そのままでは動かない
時間あたりのDelayも設定しましたし上記のコードで動いてほしいのですが,OpenAIのAPIに限ってはAPIの時間あたりの利用制限をクリアすることができず,実は上記のコードはエラーで落ちます.
具体的には,
RateLimitError: Error code: 429 - {'error': {'message': 'Request too large for gpt-4 in organization xxx on tokens_usage_based per min: Limit 10000, Requested 30663. Visit https://platform.openai.com/account/rate-limits to learn more.', 'type': 'tokens_usage_based', 'param': None, 'code': 'rate_limit_exceeded'}}
のように1000/mの制限を超えていると怒られることになります.
delay_in_seconds
設定したのに...
エラーを追いかける
では,具体的にどこが原因で落ちているのかを確認します.
エラーメッセージによると,エラーの発生場所はbertopic/representation/_openai.py
の217行目とのことでした.
前後のコードを確認してみると,以下のようになっており,217行目でOpenAIのクライアントを使ってリクエストを投げていることがわかります.
今回は入力したトークンが多すぎることが問題なので,211行目で設定されているprompt
の中身が怪しいですね.
Code-1
# 208:230
if self.chat:
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt} # 211
]
kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
if self.exponential_backoff:
response = chat_completions_with_backoff(self.client, **kwargs)
else:
response = self.client.chat.completions.create(**kwargs) # 217
# Check whether content was actually generated
# Adresses #1570 for potential issues with OpenAI's content filter
if hasattr(response.choices[0].message, "content"):
label = response.choices[0].message.content.strip().replace("topic: ", "")
else:
label = "No label returned"
else:
if self.exponential_backoff:
response = completions_with_backoff(self.client, model=self.model, prompt=prompt, **self.generator_kwargs)
else:
response = self.client.completions.create(model=self.model, prompt=prompt, **self.generator_kwargs)
label = response.choices[0].message.content.strip()
そこで,prompt
が設定されている箇所を見に行くと,すぐ上に
Code-2
# 199:206
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
prompt = self._create_prompt(truncated_docs, topic, topics)
self.prompts_.append(prompt)
# Delay
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)
という箇所が見つかりました.self._create_prompt
は
Code-3
def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]
# Use the Default Chat Prompt
if self.prompt == DEFAULT_CHAT_PROMPT or self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = self._replace_documents(prompt, docs)
# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)
return prompt
のように定義されているので,入力のながさはtruncated_docs
に依存しそうです.
そこで,まずtruncate_document()
を確認します.
こちらは bertopic.representation._utils
で定義されています.
Code-4
def truncate_document(topic_model, doc_length, tokenizer, document: str):
if doc_length is not None:
if tokenizer == "char":
truncated_document = document[:doc_length]
elif tokenizer == "whitespace":
truncated_document = " ".join(document.split()[:doc_length])
elif tokenizer == "vectorizer":
tokenizer = topic_model.vectorizer_model.build_tokenizer()
truncated_document = " ".join(tokenizer(document)[:doc_length])
elif hasattr(tokenizer, 'encode') and hasattr(tokenizer, 'decode'):
encoded_document = tokenizer.encode(document)
truncated_document = tokenizer.decode(encoded_document[:doc_length])
return truncated_document
return document
今回はモデルのパラメータとしてtiktoken
のTokenizerを渡しているので,最後のelif
で入力テキストをdoc_length
分のトークン列に切り出しているようです.
よって,入力テキストをループしている部分で各ループごとに渡されているdoc
の長さとself.doc_length
でAPIへの入力プロンプトの長さが決まりそうです.
そこで,さらに調べてみたところ,このループはバッチ化されており,今回の実験では入力テキスト数が911であったところ,ループ回数は21回でした.したがって,ループ1回当たり43程度のテキストが含まれていることになります.
(そりゃAPIを限界突破してしまうわけです)
そこで,APIの入力制限に収まるようにコードを書き換えます.
まず,GPT-4に入力できるトークン数は8192なので,それに収まるようにself.doc_length
の値を設定します.
今回のコードでは,出力も考慮して,入力トークン数の最大値を7500
と設定しています.
そして,Code-2を以下のように書き換えます.
Code-5
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
- truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
+ truncated_docs = [truncate_document(topic_model, int(self.doc_length / len(docs)), self.tokenizer, doc) for doc in docs]
prompt = self._create_prompt(truncated_docs, topic, topics)
self.prompts_.append(prompt)
全体が最大トークン数に収まるように,入力した文章一つ当たりから切り出す長さをバッチあたりの文章量で割って調整しました. → int(self.doc_length / len(docs))
長い文章の場合は入力の最初の方しか拾えなくなるため,精度には多少の影響がありそうですが,ひとまずこれで動かすことができるようになります.
また,今回は入力文章の長さだけを変更しましたが,バッチサイズを小さくすることで文章あたりの処理量を増やすことができるので,精度がイマイチでもっと長いコンテキストが必要な場合にはさらにバッチサイズを調整するようにコードを修正してみると良いかもしれません.
BERTopic × GPT-4を動かす方法まとめ
- モデル構築時に
tiktoken
のTokenizerを渡す - doc_lengthをAPIの利用制限内に設定する
- delay_in_secondsは1分に設定しておく
- APIに投げる1回当たりのトークン数を制御するために,コードを1行変更する
- truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
+ truncated_docs = [truncate_document(topic_model, int(self.doc_length / len(docs)), self.tokenizer, doc) for doc in docs]