1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LangChainのPySparkデータフレームローダーとMLflowの活用

Posted at

こちらのPySpark DataFrame loader and MLFlow in Langchain ノートブックをウォークスルーします。

このノートブックでは、PySparkとLangchainのインテグレーションを説明し、以下の方法を説明します:

  1. PySparkデータフレームからのLangchainドキュメントローダーの作成
  2. ドキュメントローダーを用いたLangchain RetrievalQAインスタンスの作成
  3. サンプルのRetrievalQAを保存するためにMLflowを活用

要件

  • Databricks Runtime 13.3 ML以降
  • MLflow 2.5以降

インポート

%pip install --upgrade langchain faiss-cpu mlflow

# GPUクラスターでは以下を使います 
# %pip install --upgrade langchain faiss-gpu mlflow
dbutils.library.restartPython()

PySparkベースのドキュメントロードによるRetrievalQAチェーンの作成

/databricks-datasets/内のWikipediaデータセットを使いましょう。以下のセルにお使いのOpenAIのAPIキーを追加してください。

import os

os.environ["OPENAI_API_KEY"] = "OpenAI APIキー"
number_of_articles = 20

wikipedia_dataframe = spark.read.parquet("databricks-datasets/wikipedia-datasets/data-001/en_wikipedia/articles-only-parquet/*").limit(number_of_articles)
display(wikipedia_dataframe)

Screenshot 2023-11-30 at 15.29.47.png

以下の行が、PySparkデータフレームからLangchainにデータをロードするために必要なことの全てです。

from langchain.document_loaders import PySparkDataFrameLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

loader = PySparkDataFrameLoader(spark, wikipedia_dataframe, page_content_column="text")
documents = loader.load()

# 再帰的にテキストを分割
text_splitter = RecursiveCharacterTextSplitter(chunk_size=3000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
print(f"Number of documents: {len(texts)}")
Number of documents: 421

HuggingfaceEmbeddingsを用いたFAISSベクトルストアの作成

FAISSベクトルストアは、MLflowでモデルを記録できるようにするための中間ステップとなります。

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

embeddings = OpenAIEmbeddings()
db = FAISS.from_documents(texts, embeddings)

RetrievalQAチェーンの作成

from langchain.chains import RetrievalQA
from langchain import OpenAI

retrieval_qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=db.as_retriever())

RetrievalQAチェーンへのクエリー

query = "Who is Harrison Schmitt"
result = retrieval_qa({"query": query})
print("Result:", result["result"])
Result:  Harrison Hagan "Jack" Schmitt is an American geologist, retired NASA astronaut, university professor and former U.S. Senator from New Mexico. In 1972, he was a member of the crew on board Apollo 17 and became the twelfth person to set foot on the Moon. He resigned from NASA in 1975 to run for election to the United States Senate.

Mlflowによるチェーンのロギング

import mlflow

persist_directory = "langchain/faiss_index"
db.save_local(persist_directory)

def load_retriever(persist_directory):
  embeddings = OpenAIEmbeddings()
  db = FAISS.load_local(persist_directory, embeddings)
  return db.as_retriever()

# RetrievalQAチェーンの記録
with mlflow.start_run() as mlflow_run:
  logged_model = mlflow.langchain.log_model(
    retrieval_qa,
    "retrieval_qa_chain",
    loader_fn=load_retriever,
    persist_dir=persist_directory,
  )

記録されました。
Screenshot 2023-11-30 at 15.32.48.png

MLFlowを用いたチェーンのロード

他の機械学習モデルと同じようにロードできます。

model_uri = f"runs:/{ mlflow_run.info.run_id }/retrieval_qa_chain"

loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri)
langchain_input = {"query": "Who is Harrison Schmitt"}
loaded_pyfunc_model.predict([langchain_input])
[' Harrison Schmitt is an American geologist, retired NASA astronaut, university professor and former U.S. senator from New Mexico. He was the 12th person to set foot on the Moon and the second-to-last person to step off of the Moon.']

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?