1
0

Databricks VectorSearchやFoundation Models APIを用いたRAGアプリケーションの構築

Posted at

久々にチュートリアル(旧dbdemos)を見たらアップデートされてましたのでウォークスルーします。

注意
このノートブックを実行するにはVectorSearchFoundation Models APIが必要です。

こちらのノートブックでは、以下のDatabricksの機能を活用しています。

  • VectorSearch: ソーステーブルの更新に応じて自動で同期されるベクトルDBです。
  • Foundation Model API: 様々なLLMを統一されたAPIでアクセスできるので、用途に合わせて様々なLLMを組み合わせることができます。
  • モデルサービングエンドポイント: 構築したLLMモデルをREST API経由で呼び出せるのでアプリケーションの構築を加速します。

環境構築

小規模なクラスターにノートブックをアタッチしてリソースをダウンロードします。

%pip install dbdemos

データを格納するカタログとスキーマ(データベース)を指定できるようになってました。指定するカタログやスキーマを先に作成しておきます。

import dbdemos
dbdemos.install('llm-rag-chatbot', catalog='takaakiyayoi_catalog', schema='rag_chatbot')

クラスターやダッシュボードも作成されています。
Screenshot 2024-02-01 at 20.54.15.png

以降は作成されたクラスターでノートブックを実行していきます。

config

最初にconfigノートブックを変更します。catalogdbNameは上のコマンドで指定したものが入っていますが、ここではDATABRICKS_SITEMAP_URLを日本のサイトに変更します。

VECTOR_SEARCH_ENDPOINT_NAME="dbdemos_vs_endpoint"

DATABRICKS_SITEMAP_URL = "https://docs.databricks.com/ja/doc-sitemap.xml"

catalog = "takaakiyayoi_catalog"

#email = spark.sql('select current_user() as user').collect()[0]['user']
#username = email.split('@')[0].replace('.', '_')
#dbName = db = f"rag_chatbot_{username}"
dbName = db = "rag_chatbot"

_resources/00-init

次に_resources/00-initを修正します。そのまま実行するとNameError: name 'catalog' is not definedというエラーになります。

%pip install mlflow==2.9.0 lxml==4.9.3 transformers==4.30.2 langchain==0.0.344 databricks-vectorsearch==0.22
dbutils.library.restartPython()

の後で

%run ../config

が実行されるようにします。これが逆になっていたのでdbutils.library.restartPython()で変数がクリアされてました。issueも上がっているのでそのうち修正されると思います。

01-Data-Preparation-and-Index

01-Data-Preparation-and-Indexを実行していきます。データを準備し、ベクトルインデックスを作成します。

%pip install mlflow==2.9.0 lxml==4.9.3 transformers==4.30.2 langchain==0.0.344 databricks-vectorsearch==0.22
dbutils.library.restartPython()
%run ../_resources/00-init $reset_all_data=false

マニュアルサイトからデータを取得します。

if not table_exists("raw_documentation") or spark.table("raw_documentation").isEmpty():
    # Download Databricks documentation to a DataFrame (see _resources/00-init for more details)
    doc_articles = download_databricks_documentation_articles()
    #Save them as a raw_documentation table
    doc_articles.write.mode('overwrite').saveAsTable("raw_documentation")

display(spark.table("raw_documentation").limit(2))

Screenshot 2024-02-01 at 21.00.38.png

長いドキュメントをチャンクに分割します。

from langchain.text_splitter import HTMLHeaderTextSplitter, RecursiveCharacterTextSplitter
from transformers import AutoTokenizer

max_chunk_size = 500

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(tokenizer, chunk_size=max_chunk_size, chunk_overlap=50)
html_splitter = HTMLHeaderTextSplitter(headers_to_split_on=[("h2", "header2")])

# Split on H2, but merge small h2 chunks together to avoid too small. 
def split_html_on_h2(html, min_chunk_size = 20, max_chunk_size=500):
  if not html:
      return []
  h2_chunks = html_splitter.split_text(html)
  chunks = []
  previous_chunk = ""
  # Merge chunks together to add text before h2 and avoid too small docs.
  for c in h2_chunks:
    # Concat the h2 (note: we could remove the previous chunk to avoid duplicate h2)
    content = c.metadata.get('header2', "") + "\n" + c.page_content
    if len(tokenizer.encode(previous_chunk + content)) <= max_chunk_size/2:
        previous_chunk += content + "\n"
    else:
        chunks.extend(text_splitter.split_text(previous_chunk.strip()))
        previous_chunk = content + "\n"
  if previous_chunk:
      chunks.extend(text_splitter.split_text(previous_chunk.strip()))
  # Discard too small chunks
  return [c for c in chunks if len(tokenizer.encode(c)) > min_chunk_size]
 
# Let's try our chunking function
html = spark.table("raw_documentation").limit(1).collect()[0]['text']
split_html_on_h2(html)
['ワークスペースの外観設定を管理する  \nワークスペース管理者は、ビジュアライゼーションで使用されるカラー パレットから日付と時刻の形式まで、ワークスペースの外観に関連するさまざまな設定を管理できます。  \n外観設定を見つける\n外観設定を見つける\nワークスペースの外観の設定は、[ 外観 ]タブの 管理者設定 にあります。\n\n使用可能な外観設定',
 '使用可能な外観設定\n使用可能な設定には、次のものがあります。  \nPlotly: この設定は、Plotlyツールバーをグラフのビジュアリゼーションに表示するかどうかを決定します。  \nカラーパレット: ダッシュボードでビジュアライゼーション用のカスタムカラーパレットを作成します。 すべてのワークスペース ユーザーは、必要に応じて、個々のダッシュボードにこのカスタム パレットをインポートできます。  \n言語: クエリーの名前と説明のマルチバイト(中国語、日本語、韓国語)検索を有効にするかどうか。 有効にすると、検索が遅くなります。  \n日付形式: クエリーの視覚化で使用されるデフォルトの日付と時刻の形式。 個々のクエリーの視覚化の日付と時刻の形式を変更するには、「 一般的なデータ型」を参照してください。',
 '時間形式: クエリーの視覚化で使用されるデフォルトの時間形式。 個々のクエリーの視覚化の日付と時刻の形式を変更するには、「 一般的なデータ型」を参照してください。']

ドキュメントを格納するテーブルを作成し、delta.enableChangeDataFeed = trueでチェンジデータキャプチャをオンにします。これはVectorSearchのソーステーブルとして差分を追加できるようにするためです。これによってVectorSearchのインデックスはソーステーブルの更新に合わせて自動で同期されるようになります。

%sql
--Note that we need to enable Change Data Feed on the table to create the index
CREATE TABLE IF NOT EXISTS databricks_documentation (
  id BIGINT GENERATED BY DEFAULT AS IDENTITY,
  url STRING,
  content STRING
) TBLPROPERTIES (delta.enableChangeDataFeed = true); 
# Let's create a user-defined function (UDF) to chunk all our documents with spark
@pandas_udf("array<string>")
def parse_and_split(docs: pd.Series) -> pd.Series:
    return docs.apply(split_html_on_h2)
    
(spark.table("raw_documentation")
      .filter('text is not null')
      .withColumn('content', F.explode(parse_and_split('text')))
      .drop("text")
      .write.mode('overwrite').saveAsTable("databricks_documentation"))

display(spark.table("databricks_documentation"))

Screenshot 2024-02-01 at 21.03.45.png

エンベディングを作成するためのエンベディングエンドポイントを呼び出して動作確認します。

import mlflow.deployments
deploy_client = mlflow.deployments.get_deploy_client("databricks")

#Embeddings endpoints convert text into a vector (array of float). Here is an example using BGE:
response = deploy_client.predict(endpoint="databricks-bge-large-en", inputs={"input": ["What is Apache Spark?"]})
embeddings = [e['embedding'] for e in response.data]
print(embeddings)
[[0.0186004638671875, -0.0141448974609375, -0.0574951171875, 0.0034027099609375, 0.008453369140625, -0.0216064453125, -0.02471923828125, -0.004688262939453125, 0.0136566162109375, 0.050384521484375, -0.0272064208984375, -0.01470184326171875, 0.054718017578125, -0.0538330078125, -0.01035308837890625, -0.0162200927734375, -0.0188140869140625, -0.017242431640625, -0.051300048828125, 0.0177764892578125, 0.00434112548828125, 0.0284423828125, -0.055633544921875, -0.037689208984375, -0.001373291015625, 0.0203704833984375, -0.046661376953125, 0.01580810546875, 0.0938720703125, 0.0195770263671875, -0.044647216796875, -0.0124359130859375, -0.0062255859375, ...

VectorSearchのエンドポイントを作成します。

from databricks.vector_search.client import VectorSearchClient
vsc = VectorSearchClient()

if VECTOR_SEARCH_ENDPOINT_NAME not in [e['name'] for e in vsc.list_endpoints().get('endpoints', [])]:
    vsc.create_endpoint(name=VECTOR_SEARCH_ENDPOINT_NAME, endpoint_type="STANDARD")

wait_for_vs_endpoint_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME)
print(f"Endpoint named {VECTOR_SEARCH_ENDPOINT_NAME} is ready.")

ソーステーブルとインデックスの格納テーブルを指定します。

from databricks.sdk import WorkspaceClient
import databricks.sdk.service.catalog as c

#The table we'd like to index
source_table_fullname = f"{catalog}.{db}.databricks_documentation"
# Where we want to store our index
vs_index_fullname = f"{catalog}.{db}.databricks_documentation_vs_index"

if not index_exists(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname):
  print(f"Creating index {vs_index_fullname} on endpoint {VECTOR_SEARCH_ENDPOINT_NAME}...")
  vsc.create_delta_sync_index(
    endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME,
    index_name=vs_index_fullname,
    source_table_name=source_table_fullname,
    pipeline_type="TRIGGERED",
    primary_key="id",
    embedding_source_column='content', #The column containing our text
    embedding_model_endpoint_name='databricks-bge-large-en' #The embedding endpoint used to create the embeddings
  )
else:
  #Trigger a sync to update our vs content with the new data saved in the table
  vsc.get_index(VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname).sync()

#Let's wait for the index to be ready and all our embeddings to be created and indexed
wait_for_index_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname)
print(f"index {vs_index_fullname} on table {source_table_fullname} is ready")

カタログエクスプローラで処理の進捗を確認できます。
Screenshot 2024-02-01 at 21.05.50.png

完了しました。
Screenshot 2024-02-01 at 21.10.08.png

動作確認します。エンべディングによる類似検索を行います。

import mlflow.deployments
deploy_client = mlflow.deployments.get_deploy_client("databricks")

question = "機械学習モデルはどのように管理されますか?"

results = vsc.get_index(VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname).similarity_search(
  query_text=question,
  columns=["url", "content"],
  num_results=1)
docs = results.get('result', {}).get('data_array', [])
docs

とりあえず動いてます。

[['https://docs.databricks.com/ja/getting-started/concepts.html',
  'モデルとモデルレジストリ  \nモデルレジストリに登録されているトレーニング済みの機械学習またはディープラーニングモデルを指します。',
  0.76031744]]

02-Deploy-RAG-Chatbot-Model

チャットbotとしてRAGをデプロイしていきます。

%pip install mlflow==2.9.0 langchain==0.0.344 databricks-vectorsearch==0.22 databricks-sdk==0.12.0 mlflow[databricks]
dbutils.library.restartPython()
%run ../_resources/00-init $reset_all_data=false

上で構築したベクトルサーチインデックスのためのエンドポイントにアクセスするには、パーソナルアクセストークンの設定されたシークレットが必要です。私はシークレットスコープdemo-token-takaaki.yayoiにキーtokenとしてパーソナルアクセストークンを設定しています。

index_name=f"{catalog}.{db}.databricks_documentation_vs_index"
host = "https://" + spark.conf.get("spark.databricks.workspaceUrl")

test_demo_permissions(host, secret_scope="demo-token-takaaki.yayoi", secret_key="token", vs_endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME, index_name=index_name, embedding_endpoint_name="databricks-bge-large-en")

Langchain retrieverを使って動作確認します。

# url used to send the request to your model from the serverless endpoint
host = "https://" + spark.conf.get("spark.databricks.workspaceUrl")
os.environ['DATABRICKS_TOKEN'] = dbutils.secrets.get("demo-token-takaaki.yayoi", "token")
from databricks.vector_search.client import VectorSearchClient
from langchain.vectorstores import DatabricksVectorSearch
from langchain.embeddings import DatabricksEmbeddings

# Test embedding Langchain model
#NOTE: your question embedding model must match the one used in the chunk in the previous model 
embedding_model = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
print(f"Test embeddings: {embedding_model.embed_query('Apache Sparkとは何ですか?')[:20]}...")

def get_retriever(persist_dir: str = None):
    os.environ["DATABRICKS_HOST"] = host
    #Get the vector search index
    vsc = VectorSearchClient(workspace_url=host, personal_access_token=os.environ["DATABRICKS_TOKEN"])
    vs_index = vsc.get_index(
        endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME,
        index_name=index_name
    )

    # Create the retriever
    vectorstore = DatabricksVectorSearch(
        vs_index, text_column="content", embedding=embedding_model
    )
    return vectorstore.as_retriever()


# test our retriever
vectorstore = get_retriever()
similar_documents = vectorstore.get_relevant_documents("料金を確認するにはどうすればいいですか?")
print(f"Relevant documents: {similar_documents[0]}")

動いているようです。

Test embeddings: [0.00991058349609375, -0.0022792816162109375, -0.0133056640625, -0.016845703125, -0.007175445556640625, -0.01067352294921875, 0.01282501220703125, 0.0014181137084960938, 0.014129638671875, 0.072265625, -0.0168914794921875, -0.028656005859375, 0.036376953125, -0.04656982421875, -0.035552978515625, -0.0323486328125, 0.0033550262451171875, -0.0253448486328125, -0.057403564453125, 0.049530029296875]...
[NOTICE] Using a Personal Authentication Token (PAT). Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True to VectorSearchClient().
WARNING:langchain.vectorstores.databricks_vector_search:embedding model is not used in delta-sync index with Databricks-managed embeddings.
Relevant documents: page_content='課金ポータルから使用状況を表示する' metadata={'id': 5657.0}

次に基盤モデルとしてLlama2を使います。DatabricksではFoundation Models APIを用いることで、簡単にLLMにアクセスできます。

# Test Databricks Foundation LLM model
from langchain.chat_models import ChatDatabricks
chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens = 200)
print(f"Test chat model: {chat_model.predict('Apache Sparkとは何ですか?')}")

こちらも動いています。

Test chat model: 
Apache Sparkは、大規模データの処理を行うためのオープンソースのデータプロセッシングエンジンです。Sparkは、Hadoopの上で動作するように設計されており、Hadoopのような分散型ストレージシステムに保存されたデータを処理することができます。Sparkは、メモリ上でデータを処理することにより、Hadoopのような disk-based システムよりもはるかに高速な処理が可能になります。

Sparkには、次のような特徴がありま

RAGチェーンを組み立てます。

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatDatabricks

TEMPLATE = """あなたはDatabricksユーザーに対するアシスタントです。あなたは、Databricksに関係するPython、コード、SQL、データエンジニアリング、Spark、データサイエンス、データウェアハウスやプラットフォーム、APIやインフラストラクチャ管理の質問に日本語で回答します。これらのトピックに関係ない質問の場合には、優しく回答を拒否してください。回答がわからない場合にはわからないと回答し、答えを作り出そうとはしないでください。可能な限り簡潔な回答にしてください。
質問に回答するには最後にあるコンテキストの箇所を使ってください:
{context}
質問: {question}
回答:
"""
prompt = PromptTemplate(template=TEMPLATE, input_variables=["context", "question"])

chain = RetrievalQA.from_chain_type(
    llm=chat_model,
    chain_type="stuff",
    retriever=get_retriever(),
    chain_type_kwargs={"prompt": prompt}
)

動作確認します。

# langchain.debug = True #uncomment to see the chain details and the full prompt being sent
question = {"query": "MLflowにはどのような機能がありますか?"}
answer = chain.run(question)
print(answer)
  MLflowには、以下のような機能があります。

1. 実験管理: MLflowでは、機械学習実験を管理することができます。実験のメタデータ、入力データ、出力データ、および実験の状態をtrackingすることができます。
2. モデル管理: MLflowでは、機械学習モデルを管理することができます。モデルのメタデータ、モデルの構成、およびモデルのパフォーマンスをtrackingすることができます。
3. ワ

これでモデルが出来上がったのでUnity Catalogに保存します。

from mlflow.models import infer_signature
import mlflow
import langchain

mlflow.set_registry_uri("databricks-uc")
model_name = f"{catalog}.{db}.dbdemos_chatbot_model"

with mlflow.start_run(run_name="dbdemos_chatbot_rag") as run:
    signature = infer_signature(question, answer)
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=get_retriever,  # Load the retriever with DATABRICKS_TOKEN env as secret (for authentication).
        artifact_path="chain",
        registered_model_name=model_name,
        pip_requirements=[
            "mlflow==" + mlflow.__version__,
            "langchain==" + langchain.__version__,
            "databricks-vectorsearch",
        ],
        input_example=question,
        signature=signature
    )

保存されました。私は前のバージョンがあったのでバージョン3になってます。
Screenshot 2024-02-01 at 21.33.51.png

チャットモデルをサービングエンドポイントにデプロイします。

# Create or update serving endpoint
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedModelInput

serving_endpoint_name = f"dbdemos_endpoint_{catalog}_{db}"[:63]
latest_model_version = get_latest_model_version(model_name)

w = WorkspaceClient()
endpoint_config = EndpointCoreConfigInput(
    name=serving_endpoint_name,
    served_models=[
        ServedModelInput(
            model_name=model_name,
            model_version=latest_model_version,
            workload_size="Small",
            scale_to_zero_enabled=True,
            environment_vars={
                "DATABRICKS_TOKEN": "{{secrets/demo-token-takaaki.yayoi/token}}",  # <scope>/<secret> that contains an access token
            }
        )
    ]
)

existing_endpoint = next(
    (e for e in w.serving_endpoints.list() if e.name == serving_endpoint_name), None
)
serving_endpoint_url = f"{host}/ml/endpoints/{serving_endpoint_name}"
if existing_endpoint == None:
    print(f"Creating the endpoint {serving_endpoint_url}, this will take a few minutes to package and deploy the endpoint...")
    w.serving_endpoints.create_and_wait(name=serving_endpoint_name, config=endpoint_config)
else:
    print(f"Updating the endpoint {serving_endpoint_url} to version {latest_model_version}, this will take a few minutes to package and deploy the endpoint...")
    w.serving_endpoints.update_config_and_wait(served_models=endpoint_config.served_models, name=serving_endpoint_name)
    
displayHTML(f'Your Model Endpoint Serving is now available. Open the <a href="/ml/endpoints/{serving_endpoint_name}">Model Serving Endpoint page</a> for more details.')

これによってREST API経由でモデルを呼び出せるようになります。
Screenshot 2024-02-01 at 21.46.34.png

モデルサービングエンドポイントが稼働したら動作確認します。

question = "MLflowにはどのような機能がありますか?"

answer = w.serving_endpoints.query(serving_endpoint_name, inputs=[{"query": question}])
print(answer.predictions[0])

動きました!エンべディングモデルやトークン数の調整など、まだ微調整は必要ですが一通りウォークスルーできました。

  MLflow は、機械学習 (ML) モデルのライフサイクル管理を支援するためのオープンソースのプラットフォームです。以下は、MLflow によって提供される機能の一部です。

1. モデルの登録と管理: MLflow では、モデルを登録し、バージョンを管理することができます。これにより、モデルの変更履歴を追跡し、適切なバージョンを使用することができます。
2. モデルの実行と

あとは画面との繋ぎ込みをすればチャットbotアプリケーションの出来上がりです。そちらは日を改めて。

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0