1
0

DatabricksのファインチューニングAPIを試してみる(日本語編)

Last updated at Posted at 2024-06-03

こちらのノートブックで日本語のチューニングデータを準備して実行します。

同僚の方が準備されたこちらのデータセットを使います。

from datasets import load_dataset
dataset = load_dataset("yulanfmy/databricks-qa-ja")
dataset.set_format(type="pandas")

df = dataset["train"][:]

training_dataset = (
    spark.createDataFrame(df)
    .withColumnRenamed("response", "answer")
    .withColumnRenamed("instruction", "question")
    .withColumnRenamed("source", "url")
    .drop("context")
    .drop("category")
)
display(training_dataset)

training_dataset.write.saveAsTable("japan_qiita_q_and_a")

Screenshot 2024-06-03 at 14.38.34.png

上のデータにはコンテンツの情報が入っていないので、自分で準備したデータとjoinしています。

%sql
CREATE TABLE japan_qiita_qa_and_content
AS SELECT
  qa.*,
  qiita.body as content
FROM
  japan_qiita_q_and_a qa
  join takaakiyayoi_catalog.qiita_2023.taka_qiita qiita on qa.url = qiita.url

Screenshot 2024-06-03 at 14.39.27.png

プロンプトも日本語にします。

from pyspark.sql.functions import pandas_udf
import pandas as pd

#base_model_name = "meta-llama/Llama-2-7b-hf"
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

system_prompt = """あなたは高度な知識を持ちプロフェッショナルなDatabricksサポートエージェントです。あなたのゴールは、Databricksに関連する質問や問題を持つユーザーを支援することです。質問に対して可能な限り正確な回答を行い、明確で簡潔な情報を提供します。回答がわからない場合には「わかりません」と回答してください。回答は礼儀正しくプロフェッショナルとして回答してください。Databricksに関連する正確で詳細な情報を提供してください。回答が明確ではない場合、明確化するための質問を行ってください。\n"""

@pandas_udf("array<struct<role:string, content:string>>")
def create_conversation(content: pd.Series, question: pd.Series, answer: pd.Series) -> pd.Series:
    def build_message(c,q,a):
        user_input = f"こちらが適切と考えられるドキュメントです: {c}。これに基づいて以下の質問に回答してください: {q}"
        if "mistral" in base_model_name:
            # Mistralはsystemプロンプトをサポートしていません
            return [
                {"role": "user", "content": f"{system_prompt} \n{user_input}"},
                {"role": "assistant", "content": a}]
        else:
            return [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": a}]
    return pd.Series([build_message(c,q,a) for c, q, a in zip(content, question, answer)])


training_data, eval_data = training_dataset.randomSplit([0.9, 0.1], seed=42)

training_data.select(create_conversation("content", "question", "answer").alias('messages')).write.mode('overwrite').saveAsTable("chat_completion_training_dataset")
eval_data.write.mode('overwrite').saveAsTable("chat_completion_evaluation_dataset")

display(spark.table('chat_completion_training_dataset'))

トレーニングデータセットは469件でした。

Screenshot 2024-06-03 at 14.57.05.png

from databricks.model_training import foundation_model as fm
# データセットを読み取り、ファインチューニングクラスターに送信するために使う現行のクラスターのIDを返却。https://docs.databricks.com/en/large-language-models/foundation-model-training/create-fine-tune-run.html#cluster-id をご覧ください
def get_current_cluster_id():
  import json
  return json.loads(dbutils.notebook.entry_point.getDbutils().notebook().getContext().safeToJson())['attributes']['clusterId']


# モデル名をきれいにしましょう
registered_model_name = f"{catalog}.{db}." + re.sub(r'[^a-zA-Z0-9]', '_',  base_model_name)

run = fm.create(
    data_prep_cluster_id=get_current_cluster_id(),  # トレーニングデータソースとしてDeltaテーブルを使っている際には必要。これは、データ準備ジョブで使用するクラスターのIDとなります。
    model=base_model_name,  # ベースラインとしてどのモデルを使うのかを定義
    train_data_path=f"{catalog}.{db}.chat_completion_training_dataset",
    task_type="CHAT_COMPLETION",  # コンプリーションためにファインチューニングAPIを使う際には task_type="INSTRUCTION_FINETUNE" に変更
    register_to=registered_model_name,
    training_duration="5ep", # デモを加速するために5エポックのみ。この数を増やすかどうかを確認するにはMLflowエクスペリメントのメトリクスをチェックしてください
    learning_rate="5e-7",
)

print(run)

日本語データセットでもlossが減少していっています。
Screenshot 2024-06-03 at 14.55.41.png

前回とは別のエンドポイント、推論テーブルを用いてデプロイします。

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ServedEntityInput, EndpointCoreConfigInput, AutoCaptureConfigInput

serving_endpoint_name = "taka_dbdemos_llm_fine_tuned_jpn" # エンドポイント名
w = WorkspaceClient()
endpoint_config = EndpointCoreConfigInput(
    name=serving_endpoint_name,
    served_entities=[
        ServedEntityInput(
            entity_name=registered_model_name,
            entity_version=get_latest_model_version(registered_model_name),
            min_provisioned_throughput=0, # エンドポイントがスケールダウンする最小秒間トークン数
            max_provisioned_throughput=100,# エンドポイントがスケールアップする最大秒間トークン数
            scale_to_zero_enabled=True
        )
    ],
    auto_capture_config = AutoCaptureConfigInput(catalog_name=catalog, schema_name=db, enabled=True, table_name_prefix="jpn_fine_tuned_llm_inference" )
)

force_update = True # 新規バージョンをリリースする際にはこれを True に設定(このデモではデフォルトで新規モデルバージョンにエンドポイントを更新しません)
existing_endpoint = next(
    (e for e in w.serving_endpoints.list() if e.name == serving_endpoint_name), None
)
if existing_endpoint == None:
    print(f"Creating the endpoint {serving_endpoint_name}, this will take a few minutes to package and deploy the endpoint...")
    w.serving_endpoints.create_and_wait(name=serving_endpoint_name, config=endpoint_config)
else:
  print(f"endpoint {serving_endpoint_name} already exist...")
  if force_update:
    w.serving_endpoints.update_config_and_wait(served_entities=endpoint_config.served_entities, name=serving_endpoint_name)

モデルサービングエンドポイントがデプロイされたら動作確認します。

import mlflow
from mlflow import deployments
# system + userロールのみを取得する様に回答を削除
test_dataset = spark.table('chat_completion_training_dataset').selectExpr("slice(messages, 1, size(messages)-1) as messages").limit(1)
# 最初のメッセージの取得
messages = test_dataset.toPandas().iloc[0].to_dict()['messages'].tolist()

client = mlflow.deployments.get_deploy_client("databricks")
client.predict(endpoint=serving_endpoint_name, inputs={"messages": messages, "max_tokens": 100})

ここでのmessageノートブックのテストの方法を教えてください。 です。

レスポンスは以下の通りとなっています。

{'id': 'chatcmpl-d9489e9788aa4d0aba6b99eace997b3c',
 'object': 'chat.completion',
 'created': 1717394866,
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': ' 多くのユニットテストのライブラリはノートブック内で直接動作します。例えば、ノートブックのコードをテストするためにビルトインのPythonの[\\`unittest\\`](https://docs.python.org/3/library/unittest.html)パッケージを使用する'},
   'finish_reason': 'length',
   'logprobs': None}],
 'usage': {'prompt_tokens': 2029,
  'completion_tokens': 100,
  'total_tokens': 2129}}

それなりの回答が返ってきました。

AI Playgroundで比較してみます。左が日本語データセットでファインチューニングしたもの、右が英語データセットでファインチューニングしたものです。
Screenshot 2024-06-03 at 15.19.47.png
Screenshot 2024-06-03 at 15.24.59.png

人の目で見ても、右の英語データセットでファインチューニングしたものは若干トンチンカンな回答になっていますが、左の日本語データセット(Qiita記事をベースとした日本語Q&A)の方が適切なアウトプットを生成しています。

しかし、こんなにお手軽にファインチューニングできるのは嬉しいです。いろいろ試行錯誤できます。

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0