はじめに
OpenAI が AI Agents SDK(しかも OSS!(^^))を発表し、いよいよ AI エージェントの実サービスへの活用が本格化しそうな気配が感じられます。
その一方で、これからの AI エージェントや RAG(Retrieval-Augmented Generation)をうまく使いこなすためには、RAG の基本をしっかりと整理し、RAG や AI エージェントで「できること」「組み合わせて可能になること」をあらためて考えることが重要だと感じています。
OpenAI Agents SDK
今回は、自分自身の理解を整理する目的も兼ねて、LlamaIndex のコードを読み解きながら、基本的な使い方やその裏側の実装がどうなっているのか をまとめてみました。
対象とする LlamaIndex のバージョンは、2025年3月7日時点の最新バージョン「0.12.23」です。
今回は第2弾として、QueryEngineによる生成AI検索(主にindex.as_query_engine)する処理を解説します。
※第1弾のRetrieverはよろしかったら下記を参照してください
コードの解説
全体の概要は以下になります。
(前提として、事前にVector Indexは生成されているものとしています)
実装コードサンプル
QueryEngineを用いたLLM (GPT-4o)検索のコードは以下になります。
大きくは、以下の流れになります。
-
RetrieverQueryEngine
モジュールを生成する。このモジュールはRAGからデータを検索(Retriever
)し、プロンプトに埋め込み生成AIに問い合わせを行い結果を返却する(Response Synthesizer
) ※下記③ - クエリからindexをサーチしプロンプトに埋め込むデータ(ノード)を取得する ※下記④前半
- 取得したノードとクエリをプロンプトに埋め込み、生成AIに問い合わせを行い結果を返却する ※下記④後半
1 # APIキーやLLMの設定
2 os.environ["OPENAI_API_KEY"]="xxxxxxxxxx"
3 Settings.llm = OpenAI(model="gpt-4o", temperature=0.0, )
4
5 # ⭐️ ①保存したVectorStoreIndexからStorageContextを復元
6 storage_context = StorageContext.from_defaults(persist_dir=LLAMAINDEX_PERSIST_FOLDER) #セーブしたフォルダ名
7 # ⭐️ ②StorageContextからindex(VectorStoreIndex)を復元
8 index = load_index_from_storage(storage_context)
9 # ⭐️ ③indexをQueryEngine として生成
10 query_engine = index.as_query_engine()
11 # ⭐️ ④クエリを実行
12 response = query_engine.query("動物の死骸がどうにありましたがどうすればいいですか?")
13
10 # 結果を表示
11 print(response)
以降、それぞれの処理でどのような処理が実際に行われているのか、順次コードを追って見ていきます。なお、「①保存したVectorStoreIndexからStorageContextを復元」および「②StorageContextからindex(VectorStoreIndex)を復元」は、Retrieverと同じですので、第1弾の記事 を参照してください。
3. indexをQueryEngine として生成
② query_engine = index.as_query_engine()
ここでは、まずRetrieverモジュールを生成します(374行目)。続いてllmモジュールの生成を行います(375行目。LLMモジュールの詳細はAppendixを参照)。381行目でこの2つのモジュール使ってRetrieverQueryEngineモジュールを生成しますが、この際、Response Synthesizerモジュールの生成を行い、(指定されていれば)node_postprocessorsと合わせて組み込まれます。
・・・
25 class BaseIndex(Generic[IS], ABC):
26 """Base LlamaIndex.
27
28 Args:
29 nodes (List[Node]): List of nodes to index
30 show_progress (bool): Whether to show tqdm progress bars. Defaults to False.
31 """
・・・
361 def as_query_engine(
362 self, llm: Optional[LLMType] = None, **kwargs: Any
363 ) -> BaseQueryEngine:
364 """Convert the index to a query engine.
365
366 Calls `index.as_retriever(**kwargs)` to get the retriever and then wraps it in a
367 `RetrieverQueryEngine.from_args(retriever, **kwrags)` call.
368 """
369 # NOTE: lazy import
370 from llama_index.core.query_engine.retriever_query_engine import (
371 RetrieverQueryEngine,
372 )
373
# ⭐️ index.as_retrieverと同じ
374 retriever = self.as_retriever(**kwargs)
375 llm = (
376 resolve_llm(llm, callback_manager=self._callback_manager)
377 if llm
378 else Settings.llm
379 )
380
# ⭐️ retrieverとLLM情報を与えてRetrieverQueryEngineを生成
381 return RetrieverQueryEngine.from_args(
382 retriever,
383 llm=llm,
384 **kwargs,
385 )
(補足)Node Postprocessors利用時のindex.as_query_engine()サンプルを下記に示します。
query_engine = index.as_query_engine(
node_postprocessors=[
TimeWeightedPostprocessor(
time_decay=0.5, time_access_refresh=False, top_k=1
)
]
)
ここからは、RetrieverQueryEngine.from_args()
の内部で何が行われてるいるのかを解説して行きます。
RetrieverQueryEngine.from_args()
はクラスメソッドですので、ダイレクトに呼ばれます。この中で get_response_synthesizer()
を呼び出します(43行目)。
最後に、return cls()
(116行目)によりRetrieverQueryEngineのインスタンスが生成され返却されますが、この時、__init__()
がと呼ばれ、パラメータで受け取ったRetrieverインスタンス(42行目)、生成したResponse Synthesizerインスタンス(43行目)、(指定があれば)node_postprocessorsインスタスをインスタンス変数として登録します(48行目)。
RetrieverQueryEngineインスタンスにクエリ(質問)を与えることで、RAGの検索し生成AIへの問い合わせ行い応答生成までが実行されます。
・・・
25 class RetrieverQueryEngine(BaseQueryEngine):
26 """Retriever query engine.
27
28 Args:
29 retriever (BaseRetriever): A retriever object.
30 response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
31 object.
32 callback_manager (Optional[CallbackManager]): A callback manager.
33 """
34
35 def __init__(
36 self,
37 retriever: BaseRetriever,
38 response_synthesizer: Optional[BaseSynthesizer] = None,
39 node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
40 callback_manager: Optional[CallbackManager] = None,
41 ) -> None:
42 self._retriever = retriever
# ⭐️ response_synthesizerを登録
43 self._response_synthesizer = response_synthesizer or get_response_synthesizer(
44 llm=Settings.llm,
45 callback_manager=callback_manager or Settings.callback_manager,
46 )
47
48 self._node_postprocessors = node_postprocessors or []
49 callback_manager = (
50 callback_manager or self._response_synthesizer.callback_manager
51 )
# ⭐️ node_postprocessorの登録(指定があれば)
52 for node_postprocessor in self._node_postprocessors:
53 node_postprocessor.callback_manager = callback_manager
54 super().__init__(callback_manager=callback_manager)
55
56 def _get_prompt_modules(self) -> PromptMixinType:
57 """Get prompt sub-modules."""
58 return {"response_synthesizer": self._response_synthesizer}
59
60 @classmethod
61 def from_args(
62 cls,
63 retriever: BaseRetriever,
64 llm: Optional[LLM] = None,
65 response_synthesizer: Optional[BaseSynthesizer] = None,
66 node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
67 callback_manager: Optional[CallbackManager] = None,
68 # response synthesizer args
69 response_mode: ResponseMode = ResponseMode.COMPACT,
70 text_qa_template: Optional[BasePromptTemplate] = None,
71 refine_template: Optional[BasePromptTemplate] = None,
72 summary_template: Optional[BasePromptTemplate] = None,
73 simple_template: Optional[BasePromptTemplate] = None,
74 output_cls: Optional[Type[BaseModel]] = None,
75 use_async: bool = False,
76 streaming: bool = False,
77 **kwargs: Any,
78 ) -> "RetrieverQueryEngine":
79 """Initialize a RetrieverQueryEngine object.".
80
81 Args:
82 retriever (BaseRetriever): A retriever object.
83 llm (Optional[LLM]): An instance of an LLM.
84 response_synthesizer (Optional[BaseSynthesizer]): An instance of a response
85 synthesizer.
86 node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list of
87 node postprocessors.
88 callback_manager (Optional[CallbackManager]): A callback manager.
89 response_mode (ResponseMode): A ResponseMode object.
90 text_qa_template (Optional[BasePromptTemplate]): A BasePromptTemplate
91 object.
92 refine_template (Optional[BasePromptTemplate]): A BasePromptTemplate object.
93 summary_template (Optional[BasePromptTemplate]): A BasePromptTemplate object.
94 simple_template (Optional[BasePromptTemplate]): A BasePromptTemplate object.
95 output_cls (Optional[Type[BaseModel]]): The pydantic model to pass to the
96 response synthesizer.
97 use_async (bool): Whether to use async.
98 streaming (bool): Whether to use streaming.
99 """
100 llm = llm or Settings.llm
101
# ⭐️get_response_synthesizer()によりResponse Synthesizerを生成する
102 response_synthesizer = response_synthesizer or get_response_synthesizer(
103 llm=llm,
104 text_qa_template=text_qa_template,
105 refine_template=refine_template,
106 summary_template=summary_template,
107 simple_template=simple_template,
108 response_mode=response_mode,
109 output_cls=output_cls,
110 use_async=use_async,
111 streaming=streaming,
112 )
113
114 callback_manager = callback_manager or Settings.callback_manager
115
116 return cls(
117 retriever=retriever,
118 response_synthesizer=response_synthesizer,
119 callback_manager=callback_manager,
120 node_postprocessors=node_postprocessors,
121 )
・・・
query_engine/retriever_query_engine.py の全文はこちら
次に、上記102行目の get_response_synthesizer()
の処理を具体的に見ていきます。
get_response_synthesizer()
では、デフォルトで CompactAndRefine インスタンス が生成されます(81行目)。CompactAndRefineとは、最も基本的なパターンで必要がある場合のみ(簡単に言うと top_k の数が多い場合や 1ノードのサイズが大きく、テキストチャンクが LLM の入力トークン数に収まらなくなる場合など)チャンクをまとめて(compact)、その後 refine を行う処理を行います。
通常は、チャンクがプロンプト制約を超えそうな場合にのみ動作するため、常に refine 処理が走るわけではありません。
・・・
33 def get_response_synthesizer(
34 llm: Optional[LLM] = None,
35 prompt_helper: Optional[PromptHelper] = None,
36 text_qa_template: Optional[BasePromptTemplate] = None,
37 refine_template: Optional[BasePromptTemplate] = None,
38 summary_template: Optional[BasePromptTemplate] = None,
39 simple_template: Optional[BasePromptTemplate] = None,
40 response_mode: ResponseMode = ResponseMode.COMPACT,
41 callback_manager: Optional[CallbackManager] = None,
42 use_async: bool = False,
43 streaming: bool = False,
44 structured_answer_filtering: bool = False,
45 output_cls: Optional[Type[BaseModel]] = None,
46 program_factory: Optional[
47 Callable[[BasePromptTemplate], BasePydanticProgram]
48 ] = None,
49 verbose: bool = False,
50 ) -> BaseSynthesizer:
51 """Get a response synthesizer."""
52 text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
53 refine_template = refine_template or DEFAULT_REFINE_PROMPT_SEL
54 simple_template = simple_template or DEFAULT_SIMPLE_INPUT_PROMPT
55 summary_template = summary_template or DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
56
57 callback_manager = callback_manager or Settings.callback_manager
58 llm = llm or Settings.llm
59 prompt_helper = (
60 prompt_helper
61 or Settings._prompt_helper
62 or PromptHelper.from_llm_metadata(
63 llm.metadata,
64 )
65 )
66
67 if response_mode == ResponseMode.REFINE:
68 return Refine(
69 llm=llm,
70 callback_manager=callback_manager,
71 prompt_helper=prompt_helper,
72 text_qa_template=text_qa_template,
73 refine_template=refine_template,
74 output_cls=output_cls,
75 streaming=streaming,
76 structured_answer_filtering=structured_answer_filtering,
77 program_factory=program_factory,
78 verbose=verbose,
79 )
80 elif response_mode == ResponseMode.COMPACT:
# ⭐️CompactAndRefine インスタンス を生成
81 return CompactAndRefine(
82 llm=llm,
83 callback_manager=callback_manager,
84 prompt_helper=prompt_helper,
85 text_qa_template=text_qa_template,
86 refine_template=refine_template,
87 output_cls=output_cls,
88 streaming=streaming,
89 structured_answer_filtering=structured_answer_filtering,
90 program_factory=program_factory,
91 verbose=verbose,
92 )
93 elif response_mode == ResponseMode.TREE_SUMMARIZE:
・・・
104 elif response_mode == ResponseMode.SIMPLE_SUMMARIZE:
・・・
112 elif response_mode == ResponseMode.GENERATION:
・・・
120 elif response_mode == ResponseMode.ACCUMULATE:
・・・
130 elif response_mode == ResponseMode.COMPACT_ACCUMULATE:
・・・
140 elif response_mode == ResponseMode.NO_TEXT:
・・・
145 elif response_mode == ResponseMode.CONTEXT_ONLY:
・・・
150 else:
raise ValueError(f"Unknown mode: {response_mode}")
・・・
response_synthesizers/factory.pyの全文はこちら
CompactAndRefine クラスは __init__()
を持っておらず親クラスである Refine クラスの __init__()
が呼ばれます。 __init__()
では、プロンプトテンプレート( text_qa_template や refine_template )や program_factory が初期化されます。
program_factory は、LLMとのやり取り(プロンプトの送信とレスポンスの取得)を行うための プログラムオブジェクト(PydanticProgramを継承) を作成するファクトリ関数です。
146行目のprogram_factoryでは、_default_program_factory=DefaultRefineProgramクラスが設定されます。このprogram_factoryがLLMに問い合わせを行い結果を得ます。DefaultRefineProgramについては後述します。
・・・
108 class Refine(BaseSynthesizer):
109 """Refine a response to a query across text chunks."""
110
111 def __init__(
112 self,
113 llm: Optional[LLM] = None,
114 callback_manager: Optional[CallbackManager] = None,
115 prompt_helper: Optional[PromptHelper] = None,
116 text_qa_template: Optional[BasePromptTemplate] = None,
117 refine_template: Optional[BasePromptTemplate] = None,
118 output_cls: Optional[Type[BaseModel]] = None,
119 streaming: bool = False,
120 verbose: bool = False,
121 structured_answer_filtering: bool = False,
122 program_factory: Optional[
123 Callable[[BasePromptTemplate], BasePydanticProgram]
124 ] = None,
125 ) -> None:
126 super().__init__(
127 llm=llm,
128 callback_manager=callback_manager,
129 prompt_helper=prompt_helper,
130 streaming=streaming,
131 )
132 self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
133 self._refine_template = refine_template or DEFAULT_REFINE_PROMPT_SEL
134 self._verbose = verbose
135 self._structured_answer_filtering = structured_answer_filtering
136 self._output_cls = output_cls
137
138 if self._streaming and self._structured_answer_filtering:
139 raise ValueError(
140 "Streaming not supported with structured answer filtering."
141 )
142 if not self._structured_answer_filtering and program_factory is not None:
143 raise ValueError(
144 "Program factory not supported without structured answer filtering."
145 )
# ⭐️ _default_program_factory=DefaultRefineProgramクラスの設定
146 self._program_factory = program_factory or self._default_program_factory
・・・
response_synthesizers/refin.pyの全文はこちら
4. Queryを実行
④ response = query_engine.query("ポケモンについて教えて")
ここでは、生成した query_engine(RetrieverQueryEngineインスタンス) にクエリを与え、LLM を使って最終的な回答(生成AIの出力)を取得します。
query_engine.query()
は、まず BaseQueryEngine クラスの query()
(47行目)が実行され、そこから RetrieverQueryEngine クラスの _query()
(173行目)が呼びれます。
・・・
30 class BaseQueryEngine(ChainableMixin, PromptMixin, DispatcherSpanMixin):
31 """Base query engine."""
・・・
# ⭐️ query()の呼び出し
47 def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
48 dispatcher.event(QueryStartEvent(query=str_or_query_bundle))
49 with self.callback_manager.as_trace("query"):
50 if isinstance(str_or_query_bundle, str):
51 str_or_query_bundle = QueryBundle(str_or_query_bundle)
# ⭐️ RetrieverQueryEngine クラスの _query()を呼び出し
52 query_result = self._query(str_or_query_bundle)
53 dispatcher.event(
54 QueryEndEvent(query=str_or_query_bundle, response=query_result)
55 )
56 return query_result
57
・・・
この _query()
内では、大きく2つのフェーズがあります。
- フェーズ1
Retrieverインスタンスを使いクエリに類似するノードをVectorStoreIndexから取得(178行目) - フェーズ2
取得したノード情報を ResponseSynthesizer/LLM に渡して回答を取得する(179行目)
・・・
25 class RetrieverQueryEngine(BaseQueryEngine):
26 """Retriever query engine.
27
28 Args:
29 retriever (BaseRetriever): A retriever object.
30 response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
31 object.
32 callback_manager (Optional[CallbackManager]): A callback manager.
33 """
・・・
173 def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
174 """Answer a query."""
175 with self.callback_manager.event(
176 CBEventType.QUERY, payload={EventPayload.QUERY_STR: 177 query_bundle.query_str}
177 ) as query_event:
178
# ⭐️フェーズ1
178 nodes = self.retrieve(query_bundle)
# ⭐️フェーズ2
179 response = self._response_synthesizer.synthesize(
180 query=query_bundle,
181 nodes=nodes,
182 )
183 query_event.on_end(payload={EventPayload.RESPONSE: response})
184
185 return response
・・・
204 @property
205 def retriever(self) -> BaseRetriever:
206 """Get the retriever object."""
207 return self._retriever
query_engine/retriever_query_engine.pyの全文はこちら
以降、フェーズ1、フェーズ2に分けてコードを説明します。
4.1 フェーズ1(Retrieve)
ここでは、Retriever がノードを取得する流れを見ていきます。
(Retriverの処理は、「RAG(LlamaIndex)をDeepに理解しよう!コード解説(as_retriver)」と同一になりますので、こちらの記事を見た人は読み飛ばしてください)
245行目の self._retrieve()
で、VectorIndexRetriverインスタンスの _retrieve()
が実行されます。
・・・
42 class BaseRetriever(ChainableMixin, PromptMixin, DispatcherSpanMixin):
"""Base retriever."""
・・・
222 def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
223 """Retrieve nodes given query.
224
225 Args:
226 str_or_query_bundle (QueryType): Either a query string or
227 a QueryBundle object.
228
229 """
230 self._check_callback_manager()
231 dispatcher.event(
232 RetrievalStartEvent(
233 str_or_query_bundle=str_or_query_bundle,
234 )
235 )
236 if isinstance(str_or_query_bundle, str):
237 query_bundle = QueryBundle(str_or_query_bundle)
238 else:
239 query_bundle = str_or_query_bundle
240 with self.callback_manager.as_trace("query"):
241 with self.callback_manager.event(
242 CBEventType.RETRIEVE,
243 payload={EventPayload.QUERY_STR: query_bundle.query_str},
244 ) as retrieve_event:
# ⭐️VectorIndexRetriverインスタンスの _retrieve()を実行
245 nodes = self._retrieve(query_bundle)
246 nodes = self._handle_recursive_retrieval(query_bundle, nodes)
247 retrieve_event.on_end(
248 payload={EventPayload.NODES: nodes},
249 )
250 dispatcher.event(
251 RetrievalEndEvent(
252 str_or_query_bundle=str_or_query_bundle,
253 nodes=nodes,
254 )
255 )
256 return nodes
・・・
VectorIndexRetrieverクラスの _retrieve()
(92行目)では以下の処理を行います。
- _get_nodes_with_embeddings() を呼び出してリターンします(103行目)
_get_nodes_with_embeddings()
では、_build_vector_store_query()
を呼び出し、VectorStoreQuery オブジェクトを生成します(179行目) - VectorStore(SimpleVectorStore) を
query()
しノード(データ)を取得します(180行目) -
_build_node_list_from_query_result()
を呼び出し、検索されたノードindexからノードIDリストを作り、ドキュメントストアから対応するノード(データ)を取得します(146行目) - 検索されたデータ(Node)とScoreで、NodeWithScoreオブジェクトを生成し、リターンします(172-173行目)
・・・
24 class VectorIndexRetriever(BaseRetriever):
25 """Vector index retriever.
26
27 Args:
28 index (VectorStoreIndex): vector store index.
29 similarity_top_k (int): number of top k results to return.
30 vector_store_query_mode (str): vector store query mode
31 See reference for VectorStoreQueryMode for full list of supported modes.
32 filters (Optional[MetadataFilters]): metadata filters, defaults to None
33 alpha (float): weight for sparse/dense retrieval, only used for
34 hybrid query mode.
35 doc_ids (Optional[List[str]]): list of documents to constrain search.
36 vector_store_kwargs (dict): Additional vector store specific kwargs to pass
37 through to the vector store at query time.
38
39 """
・・・
# ⭐️VectorIndexのノード検索
92 def _retrieve(
93 self,
94 query_bundle: QueryBundle,
95 ) -> List[NodeWithScore]:
96 if self._vector_store.is_embedding_query:
97 if query_bundle.embedding is None and len(query_bundle.embedding_strs) > 0:
98 query_bundle.embedding = (
99 self._embed_model.get_agg_embedding_from_queries(
100 query_bundle.embedding_strs
101 )
102 )
# ⭐️ _get_nodes_with_embeddings()を呼び出す
103 return self._get_nodes_with_embeddings(query_bundle)
・・・
134 def _build_node_list_from_query_result(
135 self, query_result: VectorStoreQueryResult
136 ) -> List[NodeWithScore]:
137 if query_result.nodes is None:
138 # NOTE: vector store does not keep text and returns node indices.
139 # Need to recover all nodes from docstore
140 if query_result.ids is None:
141 raise ValueError(
142 "Vector store query result should return at "
143 "least one of nodes or ids."
144 )
145 assert isinstance(self._index.index_struct, IndexDict)
# ⭐️検索されたノードindexからノードIDリストを作り、ドキュメントストアから対応するノード(データ)を取得する
146 node_ids = [
147 self._index.index_struct.nodes_dict[idx] for idx in query_result.ids
148 ]
149 nodes = self._docstore.get_nodes(node_ids)
150 query_result.nodes = nodes
151 else:
・・・
64
165 log_vector_store_query_result(query_result)
166
167 node_with_scores: List[NodeWithScore] = []
168 for ind, node in enumerate(query_result.nodes):
169 score: Optional[float] = None
170 if query_result.similarities is not None:
171 score = query_result.similarities[ind]
# ⭐️検索されたデータ(Node)とScoreでNodeWithScoreオブジェクトを生成し追加する
172 node_with_scores.append(NodeWithScore(node=node, score=score))
173 return node_with_scores
・・・
176 def _get_nodes_with_embeddings(
177 self, query_bundle_with_embeddings: QueryBundle
178 ) -> List[NodeWithScore]:
179 query = self._build_vector_store_query(query_bundle_with_embeddings)
# ⭐️ VectorStore(vector_stores.simple.SimpleVectorStore)をquery()し結果を取得
180 query_result = self._vector_store.query(query, **self._kwargs)
181 return self._build_node_list_from_query_result(query_result)
・・・
indices/vector_sore/retrievres/retriever.pyの全文はこちら
_build_node_list_from_query_result()
は、先ほど述べたように元のデータと検索時のスコアを構造体(NodeWithScore)にして返却します。
・・・
134 def _build_node_list_from_query_result(
135 self, query_result: VectorStoreQueryResult
136 ) -> List[NodeWithScore]:
137 if query_result.nodes is None:
138 # NOTE: vector store does not keep text and returns node indices.
139 # Need to recover all nodes from docstore
140 if query_result.ids is None:
141 raise ValueError(
142 "Vector store query result should return at "
143 "least one of nodes or ids."
144 )
145 assert isinstance(self._index.index_struct, IndexDict)
146 node_ids = [
147 self._index.index_struct.nodes_dict[idx] for idx in query_result.ids
148 ]
149 nodes = self._docstore.get_nodes(node_ids)
150 query_result.nodes = nodes
151 else:
・・・
164
165 log_vector_store_query_result(query_result)
166
167 node_with_scores: List[NodeWithScore] = []
168 for ind, node in enumerate(query_result.nodes):
169 score: Optional[float] = None
170 if query_result.similarities is not None:
171 score = query_result.similarities[ind]
# ⭐️ 元のデータと検索時のスコアを構造体(NodeWithScore)にして返却
172 node_with_scores.append(NodeWithScore(node=node, score=score))
173
174 return node_with_scores
・・・
indices/vector_sore/retrievres/retriever.pyの全文はこちら
上記180行目の self._vector_store.query()
メソッドは、先ほどのStorageContexの生成で出てきた、storage/storage_context.py
の134行目 vector_stores = SimpleVectorStore.from_namespaced_persist_dir()
で生成した SimpleVectorStore
クラスが使われます。
- query()を呼び出し、類似度が高いVestorストア上のノードを検索する(317行目)
- デフォルトのqueryモードで実行され、
get_top_k_embeddings()
を呼び出す(376行目)
・・・
139 class SimpleVectorStore(BasePydanticVectorStore):
140 """Simple Vector Store.
141
142 In this vector store, embeddings are stored within a simple, in-memory dictionary.
143
144 Args:
145 simple_vector_store_data_dict (Optional[dict]): data dict
146 containing the embeddings and doc_ids. See SimpleVectorStoreData
147 for more details.
148 """
・・・
# ⭐️queryと類似度が高いVestorストア上のノードを検索
317 def query(
318 self,
319 query: VectorStoreQuery,
320 **kwargs: Any,
321 ) -> VectorStoreQueryResult:
322 """Get nodes for response."""
・・・
337
338 if query.node_ids is not None:
339 available_ids = set(query.node_ids)
340
341 def node_filter_fn(node_id: str) -> bool:
342 return node_id in available_ids
343
344 else:
・・・
349 node_ids = []
350 embeddings = []
351 # TODO: consolidate with get_query_text_embedding_similarities
352 for node_id, embedding in self.data.embedding_dict.items():
353 if node_filter_fn(node_id) and query_filter_fn(node_id):
354 node_ids.append(node_id)
355 embeddings.append(embedding)
356
357 query_embedding = cast(List[float], query.query_embedding)
358
359 if query.mode in LEARNER_MODES:
・・・
366 elif query.mode == MMR_MODE:
・・・
375 elif query.mode == VectorStoreQueryMode.DEFAULT:
# ⭐️デフォルトのqueryモードが実行される
376 top_similarities, top_ids = get_top_k_embeddings(
377 query_embedding,
378 embeddings,
379 similarity_top_k=query.similarity_top_k,
380 embedding_ids=node_ids,
381 )
382 else:
383 raise ValueError(f"Invalid query mode: {query.mode}")
384
385 return VectorStoreQueryResult(similarities=top_similarities, ids=top_ids)
・・・
get_top_k_embeddings()
では、指定された数(top_k)の類似度がデータを選択します。
-
get_top_k_embeddings()
が呼び出されます(11行目) - 類似度比較関数にdefalutの
base/embeddings/base/simirarity
が設定されます(23行目) - 類似度の検証を行います(30行目)
- heapqを利用し
top_k
数を超えたらキューの中から類似度が最も低いものを削除します(34行目)
・・・
# ⭐️get_top_k_embeddings()の呼び出し
11 def get_top_k_embeddings(
12 query_embedding: List[float],
13 embeddings: List[List[float]],
14 similarity_fn: Optional[Callable[..., float]] = None,
15 similarity_top_k: Optional[int] = None,
16 embedding_ids: Optional[List] = None,
17 similarity_cutoff: Optional[float] = None,
18 ) -> Tuple[List[float], List]:
19 """Get top nodes by similarity to the query."""
20 if embedding_ids is None:
21 embedding_ids = list(range(len(embeddings)))
22
# ⭐️defalutの類似度比較関数であるbase/embeddings/base/simirarityが呼ばれる
23 similarity_fn = similarity_fn or default_similarity_fn
24
25 embeddings_np = np.array(embeddings)
26 query_embedding_np = np.array(query_embedding)
27
28 similarity_heap: List[Tuple[float, Any]] = []
29 for i, emb in enumerate(embeddings_np):
# ⭐️類似度の検証
30 similarity = similarity_fn(query_embedding_np, emb) # type: ignore[arg-type]
31 if similarity_cutoff is None or similarity > similarity_cutoff:
32 heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))
33 if similarity_top_k and len(similarity_heap) > similarity_top_k:
# ⭐️top_k数を超えたらheapqで類似度が最も低いものを削除
34 heapq.heappop(similarity_heap)
35 result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)
36
37 result_similarities = [s for s, _ in result_tups]
38 result_ids = [n for _, n in result_tups]
39
40 return result_similarities, result_ids
・・・
similarityは、defaultではコサイン類似度で類似度が測定されます。
(indexにある全ノードと比較されるので件数が極端に多いときは処理性能には注意が必要と思われます)
・・・
50 def similarity(
51 embedding1: Embedding,
52 embedding2: Embedding,
53 mode: SimilarityMode = SimilarityMode.DEFAULT,
54 ) -> float:
55 """Get embedding similarity."""
56 if mode == SimilarityMode.EUCLIDEAN:
57 # Using -euclidean distance as similarity to achieve same ranking order
58 return -float(np.linalg.norm(np.array(embedding1) - np.array(embedding2)))
59 elif mode == SimilarityMode.DOT_PRODUCT:
60 return np.dot(embedding1, embedding2)
61 else:
62 # ⭐️COS類似度の算出
63 product = np.dot(embedding1, embedding2)
64 norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
65 return product / norm
・・・
以上で、Index(VetrorStore)から、contextとしてプロンプトに埋め込むデータが取得できました。
4.2 フェーズ2(ResponseSynthesizer/LLM)
フェーズ1で取得したノード情報(プロンプトに埋め込むcontextとクエリをResponseSynthesizer(LLM) に渡して回答生成を行います。
上述の response = self._response_synthesizer.synthesize()
で、response_synthesizers/base.py モジュールの synthesize()
が実行されます(199行目)。
241行目のself.get_response()では、(デフォルトのResponseSynthesizerである)CompactAndRefineクラスのget_response()が呼ばれます。
・・・
67 class BaseSynthesizer(ChainableMixin, PromptMixin, DispatcherSpanMixin):
68 """Response builder class."""
・・・
199 def synthesize(
200 self,
201 query: QueryTextType,
202 nodes: List[NodeWithScore],
203 additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
204 **response_kwargs: Any,
205 ) -> RESPONSE_TYPE:
206 dispatcher.event(
207 SynthesizeStartEvent(
208 query=query,
209 )
210 )
211
212 if len(nodes) == 0:
213 if self._streaming:
214 empty_response_stream = StreamingResponse(
215 response_gen=empty_response_generator()
216 )
217 dispatcher.event(
218 SynthesizeEndEvent(
219 query=query,
220 response=empty_response_stream,
221 )
222 )
223 return empty_response_stream
224 else:
225 empty_response = Response("Empty Response")
226 dispatcher.event(
227 SynthesizeEndEvent(
228 query=query,
229 response=empty_response,
230 )
231 )
232 return empty_response
233
234 if isinstance(query, str):
235 query = QueryBundle(query_str=query)
236
237 with self._callback_manager.event(
238 CBEventType.SYNTHESIZE,
239 payload={EventPayload.QUERY_STR: query.query_str},
240 ) as event:
# ⭐️ Compact & Refineのget_response()を呼び出す
241 response_str = self.get_response(
242 query_str=query.query_str,
243 text_chunks=[
244 n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
245 ],
246 **response_kwargs,
247 )
248
249 additional_source_nodes = additional_source_nodes or []
250 source_nodes = list(nodes) + list(additional_source_nodes)
251
252 response = self._prepare_response_output(response_str, source_nodes)
253
254 event.on_end(payload={EventPayload.RESPONSE: response})
255
256 dispatcher.event(
257 SynthesizeEndEvent(
258 query=query,
259 response=response,
260 )
261 )
262 return response
・・・
response_synthesizers/base.pyの全文はこちら
CompactAndRefineクラスのget_response()
では43行目で親クラス(Refine)の get_response()
を呼び出します。
・・・
11 class CompactAndRefine(Refine):
12 """Refine responses across compact text chunks."""
・・・
31 def get_response(
32 self,
33 query_str: str,
34 text_chunks: Sequence[str],
35 prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
36 **response_kwargs: Any,
37 ) -> RESPONSE_TEXT_TYPE:
38 """Get compact response."""
39 # use prompt helper to fix compact text_chunks under the prompt limitation
40 # TODO: This is a temporary fix - reason it's temporary is that
41 # the refine template does not account for size of previous answer.
42 new_texts = self._make_compact_text_chunks(query_str, text_chunks)
# ⭐️親クラス(Refine)の get_response()を呼び出す
43 return super().get_response(
44 query_str=query_str,
45 text_chunks=new_texts,
46 prev_response=prev_response,
47 **response_kwargs,
48 )
・・・
response_synthesizers/compact_and_refin.pyの全文はこちら
Refineクラスのget_response()が実際のメイン処理になり、この中でプロンプトの組立てから生成AIへの問い合わせまでを実施します。
この処理は長いですが順を追って説明します。
(1) CompactAndRefineクラスの get_response()
から親クラス Refine のget_response()が呼ばれる(163行目)
(2) 初回の処理では prev_response が None なので、_give_response_single() が呼び出される(179行目)
(3) _give_response_single()を実行される(220行目)
(4)この中で self._program_factory(text_qa_template) が呼ばれ、プロンプトテンプレートをもとに プログラムオブジェクト(LLM呼び出しラッパー) が生成される(233行目)
Refine クラスの __init__()
(146行目)で self._program_factory に self._default_program_factory が設定されており、structured_answer_filtering=False (LLMの応答を文字列で処理)の場合、DefaultRefineProgram(...) が返される(214行目)。DefaultRefineProgram は BasePydanticProgram を継承し、LLMにプロンプトを送って回答を得るプログラムオブジェクトで、このインスタンスは「関数のように呼び出せる」ため、後で program() として使えます(call() が実行される)。
(5) program()を実行しLLMに問い合わせを行い結果を取得する(241行目)
(6) DefaultRefineProgram.__call__()
が実行される(75行目)
(7) self._llm.predict()
によって、実際に LLM へプロンプトが送信され、回答が得られる(84行目 )
(8) 取得した回答を構造体で返す( 89行目)
(9) (8)でquery_satisfiedは固定でTrueを設定しており回答をresponseに設定(248行目)
(10) responseをリターンに処理を終了する(273行目)
・・・
38 class StructuredRefineResponse(BaseModel):
39 """
40 Used to answer a given query based on the provided context.
41
42 Also indicates if the query was satisfied with the provided answer.
43 """
44
45 answer: str = Field(
46 description="The answer for the given query, based on the context and not "
47 "prior knowledge."
48 )
49 query_satisfied: bool = Field(
50 description="True if there was enough context given to provide an answer "
51 "that satisfies the query."
52 )
53
54
55 class DefaultRefineProgram(BasePydanticProgram):
56 """
57 Runs the query on the LLM as normal and always returns the answer with
58 query_satisfied=True. In effect, doesn't do any answer filtering.
59 """
60
61 def __init__(
62 self,
63 prompt: BasePromptTemplate,
64 llm: LLM,
65 output_cls: Optional[Type[BaseModel]] = None,
66 ):
67 self._prompt = prompt
68 self._llm = llm
69 self._output_cls = output_cls
70
71 @property
72 def output_cls(self) -> Type[BaseModel]:
73 return StructuredRefineResponse
74
# ⭐️(6) '__call__()'が実行され、プロンプトをLLMに与えて回答を取得する
75 def __call__(self, *args: Any, **kwds: Any) -> StructuredRefineResponse:
76 if self._output_cls is not None:
77 answer = self._llm.structured_predict(
78 self._output_cls,
79 self._prompt,
80 **kwds,
81 )
82 if isinstance(answer, BaseModel):
83 answer = answer.model_dump_json()
84 else:
# ⭐️ (7) self._llm.predict() でLLMからの回答を取得
85 answer = self._llm.predict(
86 self._prompt,
87 **kwds,
88 )
# ⭐️ (8) 取得した回答を構造体で返す
89 return StructuredRefineResponse(answer=answer, query_satisfied=True)
・・・
108 class Refine(BaseSynthesizer):
109 """Refine a response to a query across text chunks."""
110
111 def __init__(
112 self,
113 llm: Optional[LLM] = None,
114 callback_manager: Optional[CallbackManager] = None,
115 prompt_helper: Optional[PromptHelper] = None,
116 text_qa_template: Optional[BasePromptTemplate] = None,
117 refine_template: Optional[BasePromptTemplate] = None,
118 output_cls: Optional[Type[BaseModel]] = None,
119 streaming: bool = False,
120 verbose: bool = False,
121 structured_answer_filtering: bool = False,
122 program_factory: Optional[
123 Callable[[BasePromptTemplate], BasePydanticProgram]
124 ] = None,
125 ) -> None:
126 super().__init__(
127 llm=llm,
128 callback_manager=callback_manager,
129 prompt_helper=prompt_helper,
130 streaming=streaming,
131 )
132 self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
133 self._refine_template = refine_template or DEFAULT_REFINE_PROMPT_SEL
134 self._verbose = verbose
135 self._structured_answer_filtering = structured_answer_filtering
136 self._output_cls = output_cls
137
138 if self._streaming and self._structured_answer_filtering:
139 raise ValueError(
140 "Streaming not supported with structured answer filtering."
141 )
142 if not self._structured_answer_filtering and program_factory is not None:
143 raise ValueError(
144 "Program factory not supported without structured answer filtering."
145 )
# ⭐️ _default_program_factory()を設定
146 self._program_factory = program_factory or self._default_program_factory
147
148 def _get_prompts(self) -> PromptDictType:
149 """Get prompts."""
150 return {
151 "text_qa_template": self._text_qa_template,
152 "refine_template": self._refine_template,
153 }
154
155 def _update_prompts(self, prompts: PromptDictType) -> None:
156 """Update prompts."""
157 if "text_qa_template" in prompts:
158 self._text_qa_template = prompts["text_qa_template"]
159 if "refine_template" in prompts:
160 self._refine_template = prompts["refine_template"]
161
162 @dispatcher.span
# ⭐️ (1) CompactAndRefineクラスのget_response()から呼ばれる
163 def get_response(
164 self,
165 query_str: str,
166 text_chunks: Sequence[str],
167 prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
168 **response_kwargs: Any,
169 ) -> RESPONSE_TEXT_TYPE:
170 """Give response over chunks."""
171 dispatcher.event(
172 GetResponseStartEvent(query_str=query_str, text_chunks=text_chunks)
173 )
174 response: Optional[RESPONSE_TEXT_TYPE] = None
175 for text_chunk in text_chunks:
176 if prev_response is None:
177 # if this is the first chunk, and text chunk already
178 # is an answer, then return it
# ⭐️ (2) self._give_response_single()を呼ぶ
179 response = self._give_response_single(
180 query_str, text_chunk, **response_kwargs
181 )
182 else:
183 # refine response if possible
184 response = self._refine_response_single(
185 prev_response, query_str, text_chunk, **response_kwargs
186 )
187 prev_response = response
188 if isinstance(response, str):
189 if self._output_cls is not None:
190 try:
191 response = self._output_cls.model_validate_json(response)
192 except ValidationError:
193 pass
194 else:
195 response = response or "Empty Response"
196 else:
197 response = cast(Generator, response)
198 dispatcher.event(GetResponseEndEvent())
199 return response
200
# ⭐️ Refineクラスのインスタンス生成時に呼ばれる
201 def _default_program_factory(
202 self, prompt: BasePromptTemplate
203 ) -> BasePydanticProgram:
204 if self._structured_answer_filtering:
205 from llama_index.core.program.utils import get_program_for_llm
206
207 return get_program_for_llm(
208 StructuredRefineResponse,
209 prompt,
210 self._llm,
211 verbose=self._verbose,
212 )
213 else:
# ⭐️ DefaultRefineProgramインスタスを生成し返却
214 return DefaultRefineProgram(
215 prompt=prompt,
216 llm=self._llm,
217 output_cls=self._output_cls,
218 )
219
# ⭐️ (3) _give_response_single()の実行
220 def _give_response_single(
221 self,
222 query_str: str,
223 text_chunk: str,
224 **response_kwargs: Any,
225 ) -> RESPONSE_TEXT_TYPE:
226 """Give response given a query and a corresponding text chunk."""
227 text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
228 text_chunks = self._prompt_helper.repack(
229 text_qa_template, [text_chunk], llm=self._llm
230 )
231
232 response: Optional[RESPONSE_TEXT_TYPE] = None
# ⭐️ (4) pogramファクトリを取得
233 program = self._program_factory(text_qa_template)
234 # TODO: consolidate with loop in get_response_default
235 for cur_text_chunk in text_chunks:
236 query_satisfied = False
237 if response is None and not self._streaming:
# ⭐️responseが作らていない(Refineなしの書)、Streamingでない場
238 try:
239 structured_response = cast(
240 StructuredRefineResponse,
# ⭐️ (5) LLMに問い合わせて結果を取得する
241 program(
242 context_str=cur_text_chunk,
243 **response_kwargs,
244 ),
245 )
246 query_satisfied = structured_response.query_satisfied
# ⭐️ (9) query_satisfiedは常にTrue、結果をresponseに設定
247 if query_satisfied:
248 response = structured_response.answer
249 except ValidationError as e:
250 logger.warning(
251 f"Validation error on structured response: {e}", exc_info=True
252 )
253 elif response is None and self._streaming:
・・・
260 else:
・・・
267 if response is None:
268 response = "Empty Response"
269 if isinstance(response, str):
270 response = response or "Empty Response"
271 else:
272 response = cast(Generator, response)
# ⭐️ (10) responseをリターンして終了
273 return response
・・・
response_synthesizers/refin.pyの全文はこちら
上記273行目のリターンが、呼び出し元の sysnthesize()
に戻ります。
さらに262行目のリターンによりフェーズ2が終了し、query_engine/retriever_query_engine.py
の185行目のreturn response
で④Queryの実行の結果として呼び出しも元のアプリに戻されます。
・・・
199 def synthesize(
200 self,
201 query: QueryTextType,
202 nodes: List[NodeWithScore],
203 additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
204 **response_kwargs: Any,
205 ) -> RESPONSE_TYPE:
・・
241 response_str = self.get_response(
242 query_str=query.query_str,
243 text_chunks=[
244 n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
245 ],
246 **response_kwargs,
247 )
248
249 additional_source_nodes = additional_source_nodes or []
250 source_nodes = list(nodes) + list(additional_source_nodes)
251
# ⭐️_prepare_response_output()は空振り
252 response = self._prepare_response_output(response_str, source_nodes)
253
254 event.on_end(payload={EventPayload.RESPONSE: response})
255
256 dispatcher.event(
257 SynthesizeEndEvent(
258 query=query,
259 response=response,
260 )
261 )
# ⭐️ responseをリターンしフェーズが終了
262 return response
・・・
以上が、QueryEngineでの内部の動きになります。
補足して、下記にSetting.LLMの動きも説明します。
Appendix
Settings.llm = OpenAI(model="gpt-4o", temperature=0.0, )
では、内部でどのようなコードが走っているのでしょうか?
OpenAI()
で、まずはOpenAIクラスのインスタスが生成されます。このインスタンスは指定されたパラメータをインスタンス変数に保持し(指定可能なパラメータは下記コードのArgsを参考にしてください)、LLMと対話を行うためのメソッド(Ex. _complete()' (非Streamingの場合)、
_stream_complete()`(Streamingの場合))を備えています。
この OpenAI クラスは、LlamaIndex の本体 (llama-index-core) ではなく、別パッケージである llama-index-llms-openai に含まれています。
・・・
126 class OpenAI(FunctionCallingLLM):
127 """
128 OpenAI LLM.
129
130 Args:
131 model: name of the OpenAI model to use.
132 temperature: a float from 0 to 1 controlling randomness in generation; higher will lead to more creative, less deterministic responses.
133 max_tokens: the maximum number of tokens to generate.
134 additional_kwargs: Add additional parameters to OpenAI request body.
135 max_retries: How many times to retry the API call if it fails.
136 timeout: How long to wait, in seconds, for an API call before failing.
137 reuse_client: Reuse the OpenAI client between requests. When doing anything with large volumes of async API calls, setting this to false can improve stability.
138 api_key: Your OpenAI api key
139 api_base: The base URL of the API to call
140 api_version: the version of the API to call
141 callback_manager: the callback manager is used for observability.
142 default_headers: override the default headers for API requests.
143 http_client: pass in your own httpx.Client instance.
144 async_http_client: pass in your own httpx.AsyncClient instance.
145
146 Examples:
147 `pip install llama-index-llms-openai`
148
149 ```python
150 import os
151 import openai
152
153 os.environ["OPENAI_API_KEY"] = "sk-..."
154 openai.api_key = os.environ["OPENAI_API_KEY"]
155
156 from llama_index.llms.openai import OpenAI
157
158 llm = OpenAI(model="gpt-3.5-turbo")
159
160 stream = llm.stream("Hi, write a short story")
161
162 for r in stream:
163 print(r.delta, end="")
164 ```
165 """
166
167 model: str = Field(
168 default=DEFAULT_OPENAI_MODEL, description="The OpenAI model to use."
169 )
170 temperature: float = Field(
171 default=DEFAULT_TEMPERATURE,
172 description="The temperature to use during generation.",
173 ge=0.0,
174 le=2.0,
175 )
176 max_tokens: Optional[int] = Field(
177 description="The maximum number of tokens to generate.",
178 gt=0,
179 )
180 logprobs: Optional[bool] = Field(
181 description="Whether to return logprobs per token.",
182 default=None,
183 )
184 top_logprobs: int = Field(
185 description="The number of top token log probs to return.",
186 default=0,
187 ge=0,
188 le=20,
189 )
190 additional_kwargs: Dict[str, Any] = Field(
191 default_factory=dict, description="Additional kwargs for the OpenAI API."
192 )
193 max_retries: int = Field(
194 default=3,
195 description="The maximum number of API retries.",
196 ge=0,
197 )
198 timeout: float = Field(
199 default=60.0,
200 description="The timeout, in seconds, for API requests.",
201 ge=0,
202 )
203 default_headers: Optional[Dict[str, str]] = Field(
204 default=None, description="The default headers for API requests."
205 )
206 reuse_client: bool = Field(
207 default=True,
208 description=(
209 "Reuse the OpenAI client between requests. When doing anything with large "
210 "volumes of async API calls, setting this to false can improve stability."
211 ),
212 )
213
214 api_key: str = Field(default=None, description="The OpenAI API key.")
215 api_base: str = Field(description="The base URL for OpenAI API.")
216 api_version: str = Field(description="The API version for OpenAI API.")
217 strict: bool = Field(
218 default=False,
219 description="Whether to use strict mode for invoking tools/using schemas.",
220 )
221 reasoning_effort: Optional[Literal["low", "medium", "high"]] = Field(
222 default=None,
223 description="The effort to use for reasoning models.",
224 )
225 modalities: Optional[List[str]] = Field(
226 default=None,
227 description="The output modalities to use for the model.",
228 )
229 audio_config: Optional[Dict[str, Any]] = Field(
230 default=None,
231 description="The audio configuration to use for the model.",
232 )
233
234 _client: Optional[SyncOpenAI] = PrivateAttr()
235 _aclient: Optional[AsyncOpenAI] = PrivateAttr()
236 _http_client: Optional[httpx.Client] = PrivateAttr()
237 _async_http_client: Optional[httpx.AsyncClient] = PrivateAttr()
238
239 def __init__(
240 self,
241 model: str = DEFAULT_OPENAI_MODEL,
242 temperature: float = DEFAULT_TEMPERATURE,
243 max_tokens: Optional[int] = None,
244 additional_kwargs: Optional[Dict[str, Any]] = None,
245 max_retries: int = 3,
246 timeout: float = 60.0,
247 reuse_client: bool = True,
248 api_key: Optional[str] = None,
249 api_base: Optional[str] = None,
250 api_version: Optional[str] = None,
251 callback_manager: Optional[CallbackManager] = None,
252 default_headers: Optional[Dict[str, str]] = None,
253 http_client: Optional[httpx.Client] = None,
254 async_http_client: Optional[httpx.AsyncClient] = None,
255 openai_client: Optional[SyncOpenAI] = None,
256 async_openai_client: Optional[AsyncOpenAI] = None,
257 # base class
258 system_prompt: Optional[str] = None,
259 messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
260 completion_to_prompt: Optional[Callable[[str], str]] = None,
261 pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
262 output_parser: Optional[BaseOutputParser] = None,
263 strict: bool = False,
264 reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
265 modalities: Optional[List[str]] = None,
266 audio_config: Optional[Dict[str, Any]] = None,
267 **kwargs: Any,
268 ) -> None:
269 # TODO: Support deprecated max_new_tokens
270 if "max_new_tokens" in kwargs:
271 max_tokens = kwargs["max_new_tokens"]
272 del kwargs["max_new_tokens"]
273
274 additional_kwargs = additional_kwargs or {}
275
276 api_key, api_base, api_version = resolve_openai_credentials(
277 api_key=api_key,
278 api_base=api_base,
279 api_version=api_version,
280 )
281
282 # TODO: Temp forced to 1.0 for o1
283 if model in O1_MODELS:
284 temperature = 1.0
285
286 super().__init__(
287 model=model,
288 temperature=temperature,
289 max_tokens=max_tokens,
290 additional_kwargs=additional_kwargs,
291 max_retries=max_retries,
292 callback_manager=callback_manager,
293 api_key=api_key,
294 api_version=api_version,
295 api_base=api_base,
296 timeout=timeout,
297 reuse_client=reuse_client,
298 default_headers=default_headers,
299 system_prompt=system_prompt,
300 messages_to_prompt=messages_to_prompt,
301 completion_to_prompt=completion_to_prompt,
302 pydantic_program_mode=pydantic_program_mode,
303 output_parser=output_parser,
304 strict=strict,
305 reasoning_effort=reasoning_effort,
306 modalities=modalities,
307 audio_config=audio_config,
308 **kwargs,
309 )
310
311 self._client = openai_client
312 self._aclient = async_openai_client
313 self._http_client = http_client
314 self._async_http_client = async_http_client
・・・
566 @llm_retry_decorator
567 def _complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
568 client = self._get_client()
569 all_kwargs = self._get_model_kwargs(**kwargs)
570 self._update_max_tokens(all_kwargs, prompt)
571
572 if self.reuse_client:
573 response = client.completions.create(
574 prompt=prompt,
575 stream=False,
576 **all_kwargs,
577 )
578 else:
579 with client:
580 response = client.completions.create(
581 prompt=prompt,
582 stream=False,
583 **all_kwargs,
584 )
585 text = response.choices[0].text
586
587 openai_completion_logprobs = response.choices[0].logprobs
588 logprobs = None
589 if openai_completion_logprobs:
590 logprobs = from_openai_completion_logprobs(openai_completion_logprobs)
591
592 return CompletionResponse(
593 text=text,
594 raw=response,
595 logprobs=logprobs,
596 additional_kwargs=self._get_response_token_counts(response),
597 )
598
599 @llm_retry_decorator
600 def _stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
601 client = self._get_client()
602 all_kwargs = self._get_model_kwargs(stream=True, **kwargs)
603 self._update_max_tokens(all_kwargs, prompt)
604
605 def gen() -> CompletionResponseGen:
606 text = ""
607 for response in client.completions.create(
608 prompt=prompt,
609 **all_kwargs,
610 ):
611 if len(response.choices) > 0:
612 delta = response.choices[0].text
613 if delta is None:
614 delta = ""
615 else:
616 delta = ""
617 text += delta
618 yield CompletionResponse(
619 delta=delta,
620 text=text,
621 raw=response,
622 additional_kwargs=self._get_response_token_counts(response),
623 )
624
625 return gen()
・・・
llama-index-llms-openai/llama_index/llms/openai/base.pyの全文はこちら
続いて、生成したOpenAIインスタスをSettingインスタンスの _llm
に設定します。
(248行目の Settings = _Settings()
でSettingはシングルトンで生成されます)
Settings.llmで44行目の@llm.setterが呼ばれ、46行目の resolve_llm()
ではデフォルトでは何もされず指定されたllmインスタンスがリターンされます。
・・・
17 @dataclass
18 class _Settings:
19 """Settings for the Llama Index, lazily initialized."""
20
21 # lazy initialization
22 _llm: Optional[LLM] = None
23 _embed_model: Optional[BaseEmbedding] = None
24 _callback_manager: Optional[CallbackManager] = None
25 _tokenizer: Optional[Callable[[str], List[Any]]] = None
26 _node_parser: Optional[NodeParser] = None
27 _prompt_helper: Optional[PromptHelper] = None
28 _transformations: Optional[List[TransformComponent]] = None
29
30 # ---- LLM ----
31
32 @property
33 def llm(self) -> LLM:
34 """Get the LLM."""
35 if self._llm is None:
36 self._llm = resolve_llm("default")
37
38 if self._callback_manager is not None:
39 self._llm.callback_manager = self._callback_manager
40
41 return self._llm
42
43 @llm.setter
44 def llm(self, llm: LLMType) -> None:
45 """Set the LLM."""
# ⭐️単に指定したllmがリターンされる
46 self._llm = resolve_llm(llm)
47
48 @property
49 def pydantic_program_mode(self) -> PydanticProgramMode:
50 """Get the pydantic program mode."""
51 return self.llm.pydantic_program_mode
52
53 @pydantic_program_mode.setter
54 def pydantic_program_mode(self, pydantic_program_mode: PydanticProgramMode) -> None:
55 """Set the pydantic program mode."""
56 self.llm.pydantic_program_mode = pydantic_program_mode
57
58 # ---- Embedding ----
・・・
76 # ---- Callbacks ----
・・・
106 # ---- Tokenizer ----
・・・
135 # ---- Node parser ----
・・・
185 # ---- Node parser alias ----
・・・
232 # ---- Transformations ----
・・・
247 # Singleton
# ⭐️ Settingsをシングルトンで生成
248 Settings = _Settings()
resolve_llm()の実装は下記になりますが、上述の通り(どの条件も該当しないので) llm
がそのままリターンされます。
・・・
16 def resolve_llm(
17 llm: Optional[LLMType] = None, callback_manager: Optional[CallbackManager] = None
18 ) -> LLM:
19 """Resolve LLM from string or LLM instance."""
20 from llama_index.core.settings import Settings
21
22 try:
23 from langchain.base_language import BaseLanguageModel # pants: no-infer-dep
24 except ImportError:
25 BaseLanguageModel = None # type: ignore
26
# llmはどのじ条件にも該当しないので、単にllmがリターンされる
27 if llm == "default":
・・・
58
59 if isinstance(llm, str):
・・・
86
87 elif BaseLanguageModel is not None and isinstance(llm, BaseLanguageModel):
・・・
98 elif llm is None:
・・・
101
102 assert isinstance(llm, LLM)
103
104 llm.callback_manager = callback_manager or Settings.callback_manager
105
106 return llm
・・・
以上のようにして、Settings.llm にセットされた LLM インスタンスは、LlamaIndex 内部の各モジュール(例:Retriever、Synthesizer、QueryEngine など)から統一的に参照され、LLM への問い合わせ処理に利用されるようになります。
おわりに
コードは意外にシンプルな実装かと思います。次回はAgent(Tools/Function Calling)のコードを解説したいと思います。
記載内容に誤りがあれば継続s修正していきますので、ご指摘をお待ちしています。