LoginSignup
5
2

BERTopicでGPT-4を使う方法

Last updated at Posted at 2023-12-16

Summary

  • BERTopicはトピックモデルの実装としてとても使いやすく,OpenAIのAPIとも連携することができるが,APIの時間当たりの利用制限のためそのままでは動かない
  • が,BERTopicのコードの一部を改修するだけで動くようになる

トピックモデル BERTopic

トピックモデルは自然言語処理のタスクの一つで,テキストデータから「トピック」を抽出します.トピック$=$日本語では「主題」ですが,トピックモデルは入力のテキストから主題を検出するものであるとも言えます.
従来のトピックモデルには,

  • LDA (Latent Dirichlet Allocation)
  • LSI (Latent Semantic Indexing)
  • PLSI (Probabilistic Latent Semantic Indexing)

といった手法がありますが,Transformerが出現して以降のメジャーな手法としてBERTopicがあります.

トピックモデルでは,入力したテキストを数値ベクトルに変換する必要がありますが,BERTopicではその工程にBERT(Sentence-BERT)使用しています.
動かし方も簡単で,以下のようにscikit-learnの文法で簡単にモデルを構築することができます.

from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups

docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))['data']

topic_model = BERTopic()
topics, probs = topic_model.fit_transform(docs)

また,Embeddingに使用するモデルはSentence-BERT以外にも色々選択することができ,OpenAIのGPTシリーズをAPIで使用することも可能です.

Input Code
client = openai.OpenAI(api_key="API-KEY")
model = "gpt-4"
tokenizer = tiktoken.encoding_for_model(model)
representation_model = OpenAI(
    client,
    model=model,
    chat=True,
    delay_in_seconds=60.0,
    doc_length=7500,
    tokenizer=tokenizer
)
topic_model = BERTopic(representation_model=representation_model, verbose=True)

topics, probs = topic_model.fit_transform(docs)

後述する理由により,tiktokenのTokenizerを渡しています.
またOpenAIのAPI利用制限は1分当たりで決められているので,delay_in_seconds=60.0としています.

そのままでは動かない

時間あたりのDelayも設定しましたし上記のコードで動いてほしいのですが,OpenAIのAPIに限ってはAPIの時間あたりの利用制限をクリアすることができず,実は上記のコードはエラーで落ちます.

具体的には,

RateLimitError: Error code: 429 - {'error': {'message': 'Request too large for gpt-4 in organization xxx on tokens_usage_based per min: Limit 10000, Requested 30663. Visit https://platform.openai.com/account/rate-limits to learn more.', 'type': 'tokens_usage_based', 'param': None, 'code': 'rate_limit_exceeded'}}

のように1000/mの制限を超えていると怒られることになります.
delay_in_seconds設定したのに...

エラーを追いかける

では,具体的にどこが原因で落ちているのかを確認します.
エラーメッセージによると,エラーの発生場所はbertopic/representation/_openai.pyの217行目とのことでした.
前後のコードを確認してみると,以下のようになっており,217行目でOpenAIのクライアントを使ってリクエストを投げていることがわかります.
今回は入力したトークンが多すぎることが問題なので,211行目で設定されているpromptの中身が怪しいですね.

Code-1
           # 208:230
           if self.chat:
                messages = [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt} # 211
                ]
                kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
                if self.exponential_backoff:
                    response = chat_completions_with_backoff(self.client, **kwargs)
                else:
                    response = self.client.chat.completions.create(**kwargs) # 217

                # Check whether content was actually generated
                # Adresses #1570 for potential issues with OpenAI's content filter
                if hasattr(response.choices[0].message, "content"):
                    label = response.choices[0].message.content.strip().replace("topic: ", "")
                else:
                    label = "No label returned"
            else:
                if self.exponential_backoff:
                    response = completions_with_backoff(self.client, model=self.model, prompt=prompt, **self.generator_kwargs)
                else:
                    response = self.client.completions.create(model=self.model, prompt=prompt, **self.generator_kwargs)
                label = response.choices[0].message.content.strip()

そこで,promptが設定されている箇所を見に行くと,すぐ上に

Code-2
        # 199:206
        for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
            truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
            prompt = self._create_prompt(truncated_docs, topic, topics)
            self.prompts_.append(prompt)

            # Delay
            if self.delay_in_seconds:
                time.sleep(self.delay_in_seconds)

という箇所が見つかりました.self._create_prompt

Code-3
    def _create_prompt(self, docs, topic, topics):
        keywords = list(zip(*topics[topic]))[0]

        # Use the Default Chat Prompt
        if self.prompt == DEFAULT_CHAT_PROMPT or self.prompt == DEFAULT_PROMPT:
            prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
            prompt = self._replace_documents(prompt, docs)

        # Use a custom prompt that leverages keywords, documents or both using
        # custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
        else:
            prompt = self.prompt
            if "[KEYWORDS]" in prompt:
                prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
            if "[DOCUMENTS]" in prompt:
                prompt = self._replace_documents(prompt, docs)

        return prompt

のように定義されているので,入力のながさはtruncated_docsに依存しそうです.
そこで,まずtruncate_document()を確認します.
こちらは bertopic.representation._utilsで定義されています.

Code-4
def truncate_document(topic_model, doc_length, tokenizer, document: str):
    if doc_length is not None:
        if tokenizer == "char":
            truncated_document = document[:doc_length]
        elif tokenizer == "whitespace":
            truncated_document = " ".join(document.split()[:doc_length])
        elif tokenizer == "vectorizer":
            tokenizer = topic_model.vectorizer_model.build_tokenizer()
            truncated_document = " ".join(tokenizer(document)[:doc_length])
        elif hasattr(tokenizer, 'encode') and hasattr(tokenizer, 'decode'):
            encoded_document = tokenizer.encode(document)
            truncated_document = tokenizer.decode(encoded_document[:doc_length])
        return truncated_document
    return document

今回はモデルのパラメータとしてtiktokenのTokenizerを渡しているので,最後のelifで入力テキストをdoc_length分のトークン列に切り出しているようです.
よって,入力テキストをループしている部分で各ループごとに渡されているdocの長さとself.doc_lengthでAPIへの入力プロンプトの長さが決まりそうです.

そこで,さらに調べてみたところ,このループはバッチ化されており,今回の実験では入力テキスト数が911であったところ,ループ回数は21回でした.したがって,ループ1回当たり43程度のテキストが含まれていることになります.
(そりゃAPIを限界突破してしまうわけです)
そこで,APIの入力制限に収まるようにコードを書き換えます.
まず,GPT-4に入力できるトークン数は8192なので,それに収まるようにself.doc_lengthの値を設定します.
今回のコードでは,出力も考慮して,入力トークン数の最大値を7500と設定しています.
そして,Code-2を以下のように書き換えます.

Code-5
        for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
            - truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
            + truncated_docs = [truncate_document(topic_model, int(self.doc_length / len(docs)), self.tokenizer, doc) for doc in docs]
            prompt = self._create_prompt(truncated_docs, topic, topics)
            self.prompts_.append(prompt)

全体が最大トークン数に収まるように,入力した文章一つ当たりから切り出す長さをバッチあたりの文章量で割って調整しました. → int(self.doc_length / len(docs))

長い文章の場合は入力の最初の方しか拾えなくなるため,精度には多少の影響がありそうですが,ひとまずこれで動かすことができるようになります.

また,今回は入力文章の長さだけを変更しましたが,バッチサイズを小さくすることで文章あたりの処理量を増やすことができるので,精度がイマイチでもっと長いコンテキストが必要な場合にはさらにバッチサイズを調整するようにコードを修正してみると良いかもしれません.

BERTopic × GPT-4を動かす方法まとめ

  • モデル構築時にtiktokenのTokenizerを渡す
  • doc_lengthをAPIの利用制限内に設定する
  • delay_in_secondsは1分に設定しておく
  • APIに投げる1回当たりのトークン数を制御するために,コードを1行変更する
    • - truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
    • + truncated_docs = [truncate_document(topic_model, int(self.doc_length / len(docs)), self.tokenizer, doc) for doc in docs]
5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2