はじめに
LLMで「コンテキストに確実に含まれる文章」を抽出する方法を検討しました。
背景
LLMにコンテキストと指示を与えて、コンテキストの一部を抜き出すタスクをさせると、普通はLLMは該当箇所を文章生成することで抽出するので、ハルシネーションが起こり、コンテキスト中の表現と厳密に一致しない文章を生成するおそれがあります。
そこで、抽出対象の文章を生成させるのではなく、コンテキストを適当な長さに区切った文章リストをIDとともに与え、LLMにはIDのみを生成させるというアイデアを試してみます。IDを文章に戻す作業はルールベースで実施することで、コンテキストに確実に含まれる文章を返すことができます。
関連研究・ライブラリ
文章ではなくIDを生成させるアイデアはRankGPTの実装から着想しています(RankGPTが起源というわけでもないと思いますが)。
RankGPTはLLMを使って検索結果のリランクをするための実装です。検索結果のリストを与えると、内部的に各結果にIDを与えて、LLMに関連度順のID列を生成させ、それを検索結果文章になおして返す、という処理になっています。
つい最近登場したlangextractというライブラリでは、長文コンテキストからの抽出とソースグラウンディング(抽出文がソース中のどの位置にあるかを特定すること)の機能があるとされています。また試してみたいです。
LLMではなくQAタスクに特化したBERTベースのモデルなどを使えば、抽出対象の出現位置を直接予測することが可能です。ただしLLMより指示の柔軟性は下がります。
実装
いろんな書き方があると思いますが、例えば以下のようなコードで、期待する結果が得られました。
- contextを文章に分割(split_to_sentences)
- idとともにLLMに与え、ExtractedSentenceIdsスキーマでidを得る
- idを文章になおして返す
import litellm
from pydantic import BaseModel, Field
import re
class ExtractedSentenceIds(BaseModel):
"""Extracted Sentence IDs"""
ids: list[int] = Field(..., description="IDs of the extracted sentences.")
class Span(BaseModel):
"""Span of text"""
start: int = Field(..., description="Start index of the span.")
end: int = Field(..., description="End index of the span.")
text: str = Field(..., description="Text content of the span.")
def split_to_sentences(text: str, separator: list[str] = ["。", "\n"]) -> list[str]:
"""Split text into sentences."""
# 各セパレータをエスケープして正規表現パターンを作成
sep_pattern = '|'.join([re.escape(sep) for sep in separator])
# セパレータが1つ以上連続した部分で区切る
pattern = f'(.*?(?:{sep_pattern})+)'
sentences = re.findall(pattern, text, flags=re.DOTALL)
# 末尾にセパレータがない場合の残り
last = re.sub(f'({sep_pattern})+$', '', text)
if last and not any([last.endswith(sep) for sep in separator]) and last not in sentences:
sentences.append(last)
return sentences
def extract_text(question: str, context:str) -> list[Span]:
"""Extract text from the context based on the question."""
# Split context into sentences
sentences = split_to_sentences(context)
sentence_spans = []
start_index = 0
for sentence in sentences:
end_index = start_index + len(sentence)
sentence_spans.append(Span(start=start_index, end=end_index, text=sentence))
start_index = end_index + 1
sentence_with_ids = "\n".join([f"{i}: {sentence}" for i, sentence in enumerate(sentences)])
prompt = f"""
あなたは質問に対して、与えられたコンテキストから関連する文を抽出するタスクを行います。
質問: {question}
コンテキスト:
{sentence_with_ids}
関連するテキストのIDをすべて返してください。
"""
messages = [
{"role": "system", "content": prompt}
]
response = litellm.completion(
messages=messages,
model="gpt-4.1-mini",
response_format=ExtractedSentenceIds
)
extracted_ids = ExtractedSentenceIds.model_validate_json(response['choices'][0]['message']['content']).ids
extracted_spans = [sentence_spans[int(id)] for id in extracted_ids]
return extracted_spans
if __name__ == "__main__":
context = """日本の山の高さトップ5について説明します。
1位は富士山で、標高3,776メートルと日本で最も高い山です。
2位は北岳で、標高3,193メートルを誇ります。
3位は奥穂高岳で、標高3,190メートルです。
4位は間ノ岳で、標高3,189メートルとなっています。
5位は槍ヶ岳で、標高3,180メートルです。
これらの山々は日本アルプスに位置し、登山者に人気があります。"""
question = "日本の山の高さトップ5を教えてください。"
extracted_spans = extract_text(question, context)
print(question, extracted_spans)
question = "日本の1番目と3番目に高い山は?"
extracted_spans = extract_text(question, context)
print(question, extracted_spans)
% uv run litellm_span.py
日本の山の高さトップ5を教えてください。 [Span(start=23, end=55, text='1位は富士山で、標高3,776メートルと日本で最も高い山です。\n'), Span(start=56, end=81, text='2位は北岳で、標高3,193メートルを誇ります。\n'), Span(start=82, end=106, text='3位は奥穂高岳で、標高3,190メートルです。\n'), Span(start=107, end=135, text='4位は間ノ岳で、標高3,189メートルとなっています。\n'), Span(start=136, end=159, text='5位は槍ヶ岳で、標高3,180メートルです。\n')]
日本の1番目と3番目に高い山は? [Span(start=23, end=55, text='1位は富士山で、標高3,776メートルと日本で最も高い山です。\n'), Span(start=82, end=106, text='3位は奥穂高岳で、標高3,190メートルです。\n')]