こちらのノートブックで日本語のチューニングデータを準備して実行します。
同僚の方が準備されたこちらのデータセットを使います。
from datasets import load_dataset
dataset = load_dataset("yulanfmy/databricks-qa-ja")
dataset.set_format(type="pandas")
df = dataset["train"][:]
training_dataset = (
spark.createDataFrame(df)
.withColumnRenamed("response", "answer")
.withColumnRenamed("instruction", "question")
.withColumnRenamed("source", "url")
.drop("context")
.drop("category")
)
display(training_dataset)
training_dataset.write.saveAsTable("japan_qiita_q_and_a")
上のデータにはコンテンツの情報が入っていないので、自分で準備したデータとjoinしています。
%sql
CREATE TABLE japan_qiita_qa_and_content
AS SELECT
qa.*,
qiita.body as content
FROM
japan_qiita_q_and_a qa
join takaakiyayoi_catalog.qiita_2023.taka_qiita qiita on qa.url = qiita.url
プロンプトも日本語にします。
from pyspark.sql.functions import pandas_udf
import pandas as pd
#base_model_name = "meta-llama/Llama-2-7b-hf"
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
system_prompt = """あなたは高度な知識を持ちプロフェッショナルなDatabricksサポートエージェントです。あなたのゴールは、Databricksに関連する質問や問題を持つユーザーを支援することです。質問に対して可能な限り正確な回答を行い、明確で簡潔な情報を提供します。回答がわからない場合には「わかりません」と回答してください。回答は礼儀正しくプロフェッショナルとして回答してください。Databricksに関連する正確で詳細な情報を提供してください。回答が明確ではない場合、明確化するための質問を行ってください。\n"""
@pandas_udf("array<struct<role:string, content:string>>")
def create_conversation(content: pd.Series, question: pd.Series, answer: pd.Series) -> pd.Series:
def build_message(c,q,a):
user_input = f"こちらが適切と考えられるドキュメントです: {c}。これに基づいて以下の質問に回答してください: {q}"
if "mistral" in base_model_name:
# Mistralはsystemプロンプトをサポートしていません
return [
{"role": "user", "content": f"{system_prompt} \n{user_input}"},
{"role": "assistant", "content": a}]
else:
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_input},
{"role": "assistant", "content": a}]
return pd.Series([build_message(c,q,a) for c, q, a in zip(content, question, answer)])
training_data, eval_data = training_dataset.randomSplit([0.9, 0.1], seed=42)
training_data.select(create_conversation("content", "question", "answer").alias('messages')).write.mode('overwrite').saveAsTable("chat_completion_training_dataset")
eval_data.write.mode('overwrite').saveAsTable("chat_completion_evaluation_dataset")
display(spark.table('chat_completion_training_dataset'))
トレーニングデータセットは469件でした。
from databricks.model_training import foundation_model as fm
# データセットを読み取り、ファインチューニングクラスターに送信するために使う現行のクラスターのIDを返却。https://docs.databricks.com/en/large-language-models/foundation-model-training/create-fine-tune-run.html#cluster-id をご覧ください
def get_current_cluster_id():
import json
return json.loads(dbutils.notebook.entry_point.getDbutils().notebook().getContext().safeToJson())['attributes']['clusterId']
# モデル名をきれいにしましょう
registered_model_name = f"{catalog}.{db}." + re.sub(r'[^a-zA-Z0-9]', '_', base_model_name)
run = fm.create(
data_prep_cluster_id=get_current_cluster_id(), # トレーニングデータソースとしてDeltaテーブルを使っている際には必要。これは、データ準備ジョブで使用するクラスターのIDとなります。
model=base_model_name, # ベースラインとしてどのモデルを使うのかを定義
train_data_path=f"{catalog}.{db}.chat_completion_training_dataset",
task_type="CHAT_COMPLETION", # コンプリーションためにファインチューニングAPIを使う際には task_type="INSTRUCTION_FINETUNE" に変更
register_to=registered_model_name,
training_duration="5ep", # デモを加速するために5エポックのみ。この数を増やすかどうかを確認するにはMLflowエクスペリメントのメトリクスをチェックしてください
learning_rate="5e-7",
)
print(run)
前回とは別のエンドポイント、推論テーブルを用いてデプロイします。
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ServedEntityInput, EndpointCoreConfigInput, AutoCaptureConfigInput
serving_endpoint_name = "taka_dbdemos_llm_fine_tuned_jpn" # エンドポイント名
w = WorkspaceClient()
endpoint_config = EndpointCoreConfigInput(
name=serving_endpoint_name,
served_entities=[
ServedEntityInput(
entity_name=registered_model_name,
entity_version=get_latest_model_version(registered_model_name),
min_provisioned_throughput=0, # エンドポイントがスケールダウンする最小秒間トークン数
max_provisioned_throughput=100,# エンドポイントがスケールアップする最大秒間トークン数
scale_to_zero_enabled=True
)
],
auto_capture_config = AutoCaptureConfigInput(catalog_name=catalog, schema_name=db, enabled=True, table_name_prefix="jpn_fine_tuned_llm_inference" )
)
force_update = True # 新規バージョンをリリースする際にはこれを True に設定(このデモではデフォルトで新規モデルバージョンにエンドポイントを更新しません)
existing_endpoint = next(
(e for e in w.serving_endpoints.list() if e.name == serving_endpoint_name), None
)
if existing_endpoint == None:
print(f"Creating the endpoint {serving_endpoint_name}, this will take a few minutes to package and deploy the endpoint...")
w.serving_endpoints.create_and_wait(name=serving_endpoint_name, config=endpoint_config)
else:
print(f"endpoint {serving_endpoint_name} already exist...")
if force_update:
w.serving_endpoints.update_config_and_wait(served_entities=endpoint_config.served_entities, name=serving_endpoint_name)
モデルサービングエンドポイントがデプロイされたら動作確認します。
import mlflow
from mlflow import deployments
# system + userロールのみを取得する様に回答を削除
test_dataset = spark.table('chat_completion_training_dataset').selectExpr("slice(messages, 1, size(messages)-1) as messages").limit(1)
# 最初のメッセージの取得
messages = test_dataset.toPandas().iloc[0].to_dict()['messages'].tolist()
client = mlflow.deployments.get_deploy_client("databricks")
client.predict(endpoint=serving_endpoint_name, inputs={"messages": messages, "max_tokens": 100})
ここでのmessage
はノートブックのテストの方法を教えてください。 です。
レスポンスは以下の通りとなっています。
{'id': 'chatcmpl-d9489e9788aa4d0aba6b99eace997b3c',
'object': 'chat.completion',
'created': 1717394866,
'choices': [{'index': 0,
'message': {'role': 'assistant',
'content': ' 多くのユニットテストのライブラリはノートブック内で直接動作します。例えば、ノートブックのコードをテストするためにビルトインのPythonの[\\`unittest\\`](https://docs.python.org/3/library/unittest.html)パッケージを使用する'},
'finish_reason': 'length',
'logprobs': None}],
'usage': {'prompt_tokens': 2029,
'completion_tokens': 100,
'total_tokens': 2129}}
それなりの回答が返ってきました。
AI Playgroundで比較してみます。左が日本語データセットでファインチューニングしたもの、右が英語データセットでファインチューニングしたものです。
人の目で見ても、右の英語データセットでファインチューニングしたものは若干トンチンカンな回答になっていますが、左の日本語データセット(Qiita記事をベースとした日本語Q&A)の方が適切なアウトプットを生成しています。
しかし、こんなにお手軽にファインチューニングできるのは嬉しいです。いろいろ試行錯誤できます。