はじめに
LLMは確率的なモデルであるため、同じ入力でも異なる出力を生成することがあります。この性質を活用して複数回推論を行い、最良の結果を選ぶことで、推論性能を向上させることができます。
この記事では、LLMが出力したテキストのリストから最も良いテキストを選択する方法について紹介します。
出力テキストが特定のカテゴリー候補からサンプリングされる場合とされない場合について、それぞれ有力なテキスト選択方法を紹介します。
出力テキストが特定のカテゴリー候補からサンプリングされる場合
出力テキストが特定のカテゴリー候補からサンプリングされる場合は、「Self-Consistency」が有力です。これは、同一のタスクに対して複数回推論を行い、最も頻度が多い出力を最終的な出力とするものです。
アルゴリズムの詳細については、下記の論文"Self-Consistency Improves Chain of Thought Reasoning in Language Models"をご確認ください。
Self-Consistencyは、数学オリンピックの問題をAIに解かせるコンペである『Kaggle AI Mathematical Olympiad - Progress Prize 1』のソリューションに採用されるなど、LLMの推論品質を向上させるための方法として広く知られています。
出力テキストが特定のカテゴリー候補からサンプリング"されない"場合
こちらが本記事の本題です。多くのケースにおいて、出力テキストが特定のカテゴリー候補からサンプリング"されない"場合があります。例えば、GitHubリポジトリのIssueをAIに解決させることを競う『Kaggle Konwinski Prize』では、LLMの出力は修正パッチ(修正プログラム)であり、任意のコードが出力されるため、Self-Consistencyで最頻値を採用するという手段を取ることが難しいです(全て1件出現で最頻値が1つに定まらないため)。
そこで、私達のチームの15位ソリューションでは、『編集距離選択』というアルゴリズムを考案しました。総編集距離が最も小さいテキストを選択することで、最も代表的で平均化されたテキストが選択され、外れ値が効果的に除外されるようになりました。
具体的な例で説明します。下記のような5つのユニークなテキストリストから、テキストを1つ選択するという問題設定を考えます。『編集距離選択』によって、最も平均的なI love cats.が選択されます。
=== Edit Distance Selection Check ===
Candidate Texts:
1: I like cats.
2: I love cats.
3: I love animals.
4: Animals are great.
5: Dogs are wonderful so much.
Selected (Most Average) Text: I love cats.
実装の詳細については、下記のコードをご確認ください。
プロダクトコード
"""src/konwinski_selector.py"""
from typing import Literal
class SelectionAlgorithm:
"""
候補テキストのリストから適切なテキストを選択するアルゴリズムを提供するクラス。
"""
def __init__(
self,
selection_type: Literal["index", "length", "median", "edit_distance"],
index: int = 0,
min_required_candidates: int = 1,
is_dynamic_update_min_required_candidates: bool = False,
):
"""
選択アルゴリズムを設定。
Args:
selection_type (Literal["index", "length", "median", "edit_distance"]): 選択アルゴリズムの種類
index (int): インデックス指定(index, length の場合に使用)
min_required_candidates (int): この数以上の候補がないとスキップ
"""
self.selection_type = selection_type
self.index = index
self.min_required_candidates = min_required_candidates
self.is_dynamic_update_min_required_candidates = (
is_dynamic_update_min_required_candidates
)
def select(self, candidates: list[str]) -> str:
"""
候補テキストのリストから適切なテキストを選択する。
Args:
candidates (list[str]): 候補テキストのリスト
Returns:
str: 選択されたテキスト
"""
if len(candidates) < self.min_required_candidates:
print(
f"Not enough candidates to select from. len(candidates): {len(candidates)} < min_required_candidates: {self.min_required_candidates}"
)
return None
else:
print(f"Number of candidates: {len(candidates)}")
if self.is_dynamic_update_min_required_candidates:
print("Dynamic update min_required_candidates")
print(
f"{self.min_required_candidates} -> {max(self.min_required_candidates, len(candidates))}"
)
self.min_required_candidates = max(
self.min_required_candidates, len(candidates)
)
if self.selection_type == "index":
return self.select_by_index(candidates)
elif self.selection_type == "length":
return self.select_by_length(candidates)
elif self.selection_type == "median":
return self.select_by_median(candidates)
elif self.selection_type == "edit_distance":
return self.select_by_edit_distance(candidates)
else:
print("Invalid selection type")
return self.select_by_index(candidates)
def select_by_index(self, candidates: list[str]) -> str:
"""インデックス指定でテキストを選択"""
try:
return candidates[self.index]
except Exception:
print("Invalid index, selecting first candidate")
try:
print("No candidates to select from")
return candidates[0]
except Exception:
print("No candidates to select from")
return None
def select_by_length(self, candidates: list[str]) -> str:
"""問題文の長さでソートし、インデックス指定で選択"""
sorted_candidates = sorted(candidates, key=len)
try:
return sorted_candidates[self.index]
except Exception:
print("Invalid index, selecting first candidate")
try:
print("No candidates to select from")
return sorted_candidates[0]
except Exception:
print("No candidates to select from")
return None
def select_by_median(self, candidates: list[str]) -> str:
"""問題文の長さでソートし、中央値を選択"""
sorted_candidates = sorted(candidates, key=len)
try:
return sorted_candidates[len(sorted_candidates) // 2]
except Exception:
try:
print("No candidates to select from")
return sorted_candidates[0]
except Exception:
print("No candidates to select from")
return None
def select_by_edit_distance(
self,
candidates: list[str],
num_candidates_limit: int = 10,
max_word_limit: int = 1000,
) -> str:
"""編集距離を計算し、最も平均的なテキストを選択"""
for candidate in candidates:
if len(candidate.split()) > max_word_limit:
print(f"Word limit exceeded: {max_word_limit}")
print("median selection is used instead")
return self.select_by_median(candidates)
if len(candidates) > num_candidates_limit:
# 問題文を長さでソートし、上位 num_candidates_limit 件のみ使用
candidates = sorted(candidates, key=len)[:num_candidates_limit]
min_distance = float("inf")
best_candidate = candidates[0]
for candidate in candidates:
total_distance = sum(
self.calculate_edit_distance(candidate, other)
for other in candidates
if candidate != other
)
print(f"Total edit distance for '{candidate}': {total_distance}")
if total_distance < min_distance:
min_distance = total_distance
best_candidate = candidate
return best_candidate
@staticmethod
def calculate_edit_distance(text1: str, text2: str) -> int:
"""
単語単位の Levenshtein 距離(編集距離)を求める。
Args:
text1 (str): 変換元の文字列
text2 (str): 変換先の文字列
Returns:
int: Levenshtein 距離(編集回数)
"""
words1 = text1.split() # 単語リストに分割
words2 = text2.split() # 単語リストに分割
len_w1, len_w2 = len(words1), len(words2)
if len_w1 < len_w2:
words1, words2 = words2, words1 # 長い方を words1 にする
previous_row = list(range(len_w2 + 1))
for i, w1 in enumerate(words1, 1):
current_row = [i]
for j, w2 in enumerate(words2, 1):
cost = 0 if w1 == w2 else 1
current_row.append(
min(
previous_row[j] + 1, # 削除
current_row[j - 1] + 1, # 挿入
previous_row[j - 1] + cost, # 置換
)
)
previous_row = current_row
return previous_row[-1]
if __name__ == "__main__":
candidates = [
"I like cats.",
"I love cats.",
"I love animals.",
"Animals are great.",
"Dogs are wonderful so much.",
]
# edit distance selection
selector = SelectionAlgorithm(selection_type="edit_distance")
selected_text = selector.select(candidates)
print("=== Edit Distance Selection Check ===")
print("Candidate Texts:")
for i, text in enumerate(candidates):
print(f"{i + 1}: {text}")
print(f"\nSelected (Most Average) Text: {selected_text}")
# median selection
selector = SelectionAlgorithm(selection_type="median")
selected_text = selector.select(candidates)
print("\n=== Median Selection Check ===")
print("Candidate Texts:")
for i, text in enumerate(candidates):
print(f"{i + 1}: {text}")
print(f"\nSelected (Median) Text: {selected_text}")
テストコード
"""tests/test_selector.py"""
import pytest
from src.konwinski_selector import SelectionAlgorithm
@pytest.fixture
def english_candidates():
return [
"This is a test of the selection algorithm.", # English
"Short sentence.", # English (shortest)
"This sentence is significantly longer than the others, making it stand out clearly.", # English (longest)
"A medium-length sentence.", # English
]
@pytest.fixture
def multilingual_candidates():
return [
"Hello, this is an English sentence.",
"こんにちは、これは日本語の文章です。",
"你好,这是一段中文文本。",
"This is another example in English.",
"これはもう一つの日本語の例です。",
"这是另一个中文示例。",
]
# ✅ 英語のみの通常テスト
def test_select_by_index(english_candidates):
"""Test selection by index (English only)"""
selector = SelectionAlgorithm(selection_type="index", index=0)
assert selector.select(english_candidates) == english_candidates[0]
selector = SelectionAlgorithm(selection_type="index", index=-1)
assert selector.select(english_candidates) == english_candidates[-1]
selector = SelectionAlgorithm(selection_type="index", index=100) # Out of range
assert (
selector.select(english_candidates) == english_candidates[0]
) # Default to first element
def test_select_by_length(english_candidates):
"""Test selection by sorted length (English only)"""
selector = SelectionAlgorithm(selection_type="length", index=0)
assert selector.select(english_candidates) == "Short sentence." # Shortest
selector = SelectionAlgorithm(selection_type="length", index=-1)
assert (
selector.select(english_candidates)
== "This sentence is significantly longer than the others, making it stand out clearly."
) # Longest
selector = SelectionAlgorithm(selection_type="length", index=100) # Out of range
assert (
selector.select(english_candidates) == "Short sentence."
) # Default to shortest
def test_select_by_median(english_candidates):
"""Test selection by median length (English only)"""
selector = SelectionAlgorithm(selection_type="median")
selected = selector.select(english_candidates)
assert selected == english_candidates[0] # This is the median length
def test_select_by_edit_distance(english_candidates):
"""Test selection by edit distance (choosing the most average text in English)"""
selector = SelectionAlgorithm(selection_type="edit_distance")
selected = selector.select(english_candidates)
assert selected == english_candidates[1] # Should select one of the candidates
def test_invalid_selection_type(english_candidates):
"""Test invalid selection type handling in English"""
selector = SelectionAlgorithm(selection_type="invalid_type")
selected = selector.select(english_candidates)
assert selected == english_candidates[0] # Default to first element
def test_empty_candidates():
"""Test handling of empty candidate list"""
selector = SelectionAlgorithm(selection_type="index", index=0)
assert selector.select([]) is None
def test_multilingual_support(multilingual_candidates):
"""Test multilingual selection (English, Japanese, Chinese)"""
selector = SelectionAlgorithm(selection_type="index", index=2)
assert (
selector.select(multilingual_candidates) == "你好,这是一段中文文本。"
) # Correctly selects Chinese sentence
selector = SelectionAlgorithm(selection_type="length", index=0)
shortest_text = min(multilingual_candidates, key=len)
assert (
selector.select(multilingual_candidates) == shortest_text
) # Shortest sentence
selector = SelectionAlgorithm(selection_type="edit_distance")
selected = selector.select(multilingual_candidates)
assert selected == multilingual_candidates[1] # Should select one of the candidates
def test_min_required_candidates():
"""min_required_candidates を満たさない場合に None を返すかテスト"""
selector = SelectionAlgorithm(selection_type="index", min_required_candidates=3)
# 2つしか候補がないので、None を返すはず
assert selector.select(["Option 1", "Option 2"]) is None
# 3つ以上ある場合、通常通り動作するはず
assert selector.select(["Option 1", "Option 2", "Option 3"]) == "Option 1"
def test_min_required_candidates_with_length():
"""min_required_candidates を length ベースでテスト"""
selector = SelectionAlgorithm(selection_type="length", min_required_candidates=2)
# 1つしかないので None を返す
assert selector.select(["Short text."]) is None
# 2つ以上あるので通常通り動作する
assert selector.select(["Short text.", "Longer text here."]) == "Short text."
def test_min_required_candidates_with_edit_distance():
"""min_required_candidates を edit_distance ベースでテスト"""
selector = SelectionAlgorithm(
selection_type="edit_distance", min_required_candidates=4
)
candidates = [
"A simple sentence.",
"Another sentence with some changes.",
"This one is slightly different.",
]
# 3つしかないので None を返す
assert selector.select(candidates) is None
# 4つ以上ある場合に通常通り動作する
candidates.append("A fourth sentence to test.")
assert selector.select(candidates) in candidates
def test_edit_distance_with_word_limit():
"""max_word_limit を超えた場合に median 選択が使われるかテスト"""
selector = SelectionAlgorithm(selection_type="edit_distance")
candidates = [
"Short text.",
"This is a medium length sentence.",
" ".join(["longword"] * 1001), # 1001 単語の長文 (max_word_limit を超える)
]
selected = selector.select_by_edit_distance(candidates, max_word_limit=1000)
assert selected == selector.select_by_median(candidates) # median が選ばれるはず
def test_edit_distance_with_candidate_limit():
"""num_candidates_limit を超えた場合に候補数が適切に制限されるかテスト"""
selector = SelectionAlgorithm(selection_type="edit_distance")
candidates = [
f"Sentence {i}" for i in range(50)
] # 50個の候補 (num_candidates_limit=10 を超える)
selected = selector.select_by_edit_distance(candidates, num_candidates_limit=10)
assert len(candidates) == 50 # 元のリストは変更されていない
assert selected in candidates[:10] # 上位10個の中から選ばれるはず
def test_edit_distance_no_limits():
"""num_candidates_limit や max_word_limit を引数に指定しない場合の動作確認"""
selector = SelectionAlgorithm(selection_type="edit_distance")
candidates = [
"A simple test sentence for testing.",
"Another example longer test.",
"Slightly longer and more descriptive sentence for testing.",
]
selected = selector.select_by_edit_distance(candidates)
assert selected == candidates[0]
まとめ
LLMの複数出力から最良のテキストを選ぶには、
候補が「カテゴリに依存するか・しないか」で適切な手法が異なります。
-
カテゴリ候補からサンプリングされる場合
→ Self-Consistency を使い、「最も多く出現した出力」を採用するのが有力 -
カテゴリが存在しない自由生成(コード・長文など)の場合
→ 編集距離選択が有効で、
全候補との距離が最も小さい「最も平均的・代表的なテキスト」を選択
これにより、外れ値を排除しつつ、高品質な出力選択が可能になります。
LLMの出力性能を向上させたい場合は、ぜひ試してみてください。
