0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Mosaic AI Model TrainingによるLLMのファインチューニング

Posted at

こちらの続編です。

以前から更新があり、テキスト分類やNER(固有表現抽出)のノートブックが追加され構成も変更されています。APIの名称も、Mosaic AI Model Training APIに変更になっています。

新たな01-classification-fine-tuning-customer-supportノートブックをウォークスルーします。

注意
Mosaic AI Model Trainingが利用できるリージョンは限られています。執筆時点では日本リージョンでは利用できません。価格はこちらをご覧ください。

Databricks Mosaic AIを使用したLLMのファインチューニング

LLMのファインチューニングは、既存のモデルを自分の要件に特化させる行為です。
技術的には、既存のファウンデーションモデル(DBRXやLLAMAなど)の重みから始め、自分のデータセットに基づいて別のトレーニングラウンドを追加します。

ファインチューニングの理由

Databricksは、既存のファウンデーションモデルを特化させる簡単な方法を提供し、モデルを所有および制御しながら、パフォーマンスの向上、コストの削減、セキュリティとプライバシーの強化を実現します。

典型的なファインチューニングのユースケースは以下の通りです:

  • 特定の内部知識に基づいてモデルをトレーニングする
  • モデルの振る舞いをカスタマイズする、例:特化したエンティティ抽出
  • モデルサイズと推論コストを削減しながら、回答品質を向上させる

継続的な事前トレーニングと指示に基づくファインチューニングの違いは?

DatabricksファインチューニングAPIを使用すると、モデルをいくつかの異なる方法で適応させることができます:

  • 教師ありファインチューニング:構造化されたプロンプト-レスポンスデータでモデルをトレーニングします。これを使用して、モデルを新しいタスクに適応させたり、応答スタイルを変更したり、指示に従う能力を追加したりします。
  • 継続的な事前トレーニング:追加のラベルなしテキストデータでモデルをトレーニングします。これを使用して、モデルに新しい知識を追加したり、モデルを特定のドメインに集中させたりします。関連性を持たせるには数百万のトークンが必要です。
  • チャット完了:ユーザーとAIアシスタント間のチャットログでモデルをトレーニングします。この形式は、実際のチャットログだけでなく、質問応答や会話テキストの標準形式としても使用できます。テキストは、特定のモデルに適したチャット形式に自動的にフォーマットされます。

チャット完了APIを使用します。これを使用すると、Databricksは基盤となるモデルに基づいてシステムプロンプトを適切にフォーマットします。

より高度なデモにアクセスしたいですか?

ファインチューニングAPIの使用方法をすでに知っている場合は、直接高度なデモにジャンプできます:

顧客のチケットを分類して解決までの時間を短縮するためのファインチューニング

このデモでは、緊急/重要なチケットを分類し、それらをキューの上位に配置するためにLLMを特化させる方法をご紹介します。

小規模なLlama3-7Bをファインチューニングして、精度を向上させつつコストを削減します。

そのために、Databricksはモデルをファインチューニングし、そのパフォーマンスを評価するためのシンプルで組み込みのAPIを提供しています。さあ、始めましょう!

Databricksでのファインチューニングに関するドキュメント:

# ライブラリをインストールする
%pip install --quiet databricks-genai==1.0.8 mlflow==2.14.2
%pip install --quiet databricks-sdk==0.29.0
dbutils.library.restartPython()
%run ./_resources/00-setup

トレーニングデータの準備

現在のサポートチケットを確認しましょう

チケットのテキスト(メールなど)に基づいて、モデルをトレーニングして優先度を予測することを目指します:

%sql
select * from customer_tickets limit 10

Screenshot 2024-10-08 at 20.15.59.png

チケットの分類のためのプロンプト作成

チケットを分類するためのプロンプト例を作成しましょう。できるだけ良い結果を得るために、少数の例を用いたプロンプト例を含めます:

system_prompt = """
### Instruction
You are tasked as a service request processing agent responsible for reading the descriptions of tickets and categorizing each into one of the predefined categories. Please categorize each ticket into one of the following specific categories: 

Not Urgent 
Urgent
Impacting the Prod

Do not create or use any categories other than those explicitly listed above. Return only single category as response. If there is confusion between multiple categories error on the side of assigning higher severity.

Impacting the Prod is more severe than Urgent
Urgent is more severe than Not Urgent

###Example Input
We have noticed an issue with our Databricks workspace objects, specifically with clusters. Some of our production ETL pipelines and ad-hoc analytics jobs are being affected. The clusters seem to be unresponsive and we are unable to run any commands. This is impacting our prod and we need urgent assistance.

###Response
Impacting the Prod

Based on the above categorize the following issue: \n\n"""

スタンドアロンのmixtral 8x7bモデルを使って、まずはバニラモデルでテストしてみましょう。ご覧の通り、これは理想的ではありません。多くのテキストが追加されており、データセットを適切に分類していません。

spark.sql(f"""SELECT 
            ai_query("databricks-mixtral-8x7b-instruct", concat("{system_prompt}", description)) AS mixtral_small_classification,
            description
        FROM customer_tickets 
        LIMIT 5""").display()

Screenshot 2024-10-08 at 20.17.03.png

チャットコンプリーション用のデータセットの準備

データブリックスでは、コンプリーションAPIの使用を常に推奨しています。データブリックスが最終的なトレーニングプロンプトを適切にフォーマットしてくれるためです。

チャットコンプリーションには、OpenAIの標準に従ったrolepromptのリストが必要です。この標準は、入力をLLMの指示パターンに従ったプロンプトに変換する利点があります。

各基礎モデルは異なる指示タイプでトレーニングされる場合があるため、ファインチューニング時には同じタイプを使用することが最適です。

可能な限り、チャットコンプリーションを使用することをお勧めします。

[
  {"role": "system", "content": "[system prompt]"},
  {"role": "user", "content": "Here is a documentation page:[RAG context]. Based on this, answer the following question: [user question]"},
  {"role": "assistant", "content": "[answer]"}
]

ファインチューニングデータセットはあなたのRAGアプリケーションで使用するのと同じフォーマットである必要があることを覚えておいてください。

トレーニングデータのタイプ

Databricksでは様々な種類のデータセットフォーマット(ボリュームのファイル、Deltaテーブル、公開されている.jsonlフォーマットのHugging Faceデータセット)をサポートしていますが、プロダクションの品質を確実にするために、適切なデータパイプラインの一部として、Unity Catalog内のDeltaテーブルとしてデータセットを準備することをお勧めします。

覚えておいてください、このステップは重要なものであり、ご自身のトレーニングデータセットが高品質であることを確実にする必要があります。

最終的なチャットコンプリーションデータセットの作成をサポートする小さなpandas UDFを作成しましょう。

spark.sql(f"""
CREATE OR REPLACE TABLE ticket_priority_training_dataset AS
SELECT 
    ARRAY(
        STRUCT('user' AS role, CONCAT('{system_prompt}', '\n', description) AS content),
        STRUCT('assistant' AS role, priority AS content)
    ) AS messages
FROM customer_tickets;
""")

spark.table('ticket_priority_training_dataset').display()

Screenshot 2024-10-08 at 20.17.57.png

ファインチューニングの実行を開始する

トレーニングが完了すると、モデルは自動的にUnity Catalogに保存され、サービスとして利用できるようになります!

このデモでは、作成したテーブル上のAPIを使用してLLMをプログラム的にファインチューニングします。

ただし、UIから新しいファインチューニング実験を作成することもできます!

from databricks.model_training import foundation_model as fm
import mlflow

mlflow.set_registry_uri("databricks-uc")

base_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# モデル名を綺麗にしましょう
registered_model_name = f"{catalog}.{db}.classif_" + re.sub(r'[^a-zA-Z0-9]', '_',  base_model_name)
    
run = fm.create(
    data_prep_cluster_id=get_current_cluster_id(),  
    model=base_model_name,  
    train_data_path=f'{catalog}.{db}.ticket_priority_training_dataset',
    task_type="CHAT_COMPLETION",  
    training_duration="10ep",  # デモなので10エポックのみ
    register_to=registered_model_name,
    learning_rate="5e-7",
)

print(run)

MLFlowエクスペリメントを通じてファインチューニングの実行を追跡する

進行中または過去のファインチューニング実行の進捗を監視するには、MLFlowエクスペリメントから実行を開くことができます。ここでは、将来の実行を改善するためにどのように調整するかについての貴重な情報を見つけることができます。例えば:

  • 実行の終わりにモデルがまだ改善している場合は、エポック数を増やす
  • 損失が減少しているが非常に遅い場合は、学習率を上げる
  • 損失が大きく変動している場合は、学習率を下げる
displayHTML(f'Open the <a href="/ml/experiments/{run.experiment_id}/runs/{run.run_id}/model-metrics">training run on MLFlow</a> to track the metrics')
# ランの詳細の追跡
display(run.get_events())

# ランが完了するまで待つヘルパー関数 - 詳細は _resources フォルダをご覧ください
wait_for_run_to_finish(run)

Screenshot 2024-10-08 at 20.20.06.png

ファインチューニングされたモデルがUnity Catalogに登録されています。

Screenshot 2024-10-08 at 20.21.31.png

ファインチューニングされたモデルをサービングエンドポイントにデプロイ

準備ができたら、モデルはUnity Catalogで利用可能になります。

ここから、UIを使用してモデルをデプロイすることも、APIを使用することもできます。再現性のために、以下でAPIを使用します:

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ServedEntityInput, EndpointCoreConfigInput

serving_endpoint_name = "dbdemos_classification_fine_tuned"
w = WorkspaceClient()
endpoint_config = EndpointCoreConfigInput(
    name=serving_endpoint_name,
    served_entities=[
        ServedEntityInput(
            entity_name=registered_model_name,
            entity_version=get_latest_model_version(registered_model_name),
            min_provisioned_throughput=0, # エンドポイントがスケールダウンできる最小のトークン数/秒
            max_provisioned_throughput=100,# エンドポイントがスケールアップできる最大のトークン数/秒
            scale_to_zero_enabled=True
        )
    ]
)

force_update = False # Trueに設定すると、新しいバージョンに更新されます(デフォルトではデモではエンドポイントを新しいモデルバージョンに更新しません)
try:
  existing_endpoint = w.serving_endpoints.get(serving_endpoint_name)
  print(f"エンドポイント {serving_endpoint_name} は既に存在します - 強制更新 = {force_update}...")
  if force_update:
    w.serving_endpoints.update_config_and_wait(served_entities=endpoint_config.served_entities, name=serving_endpoint_name)
except:
    print(f"エンドポイント {serving_endpoint_name} を作成中です。エンドポイントのパッケージ化とデプロイには数分かかります...")
    w.serving_endpoints.create_and_wait(name=serving_endpoint_name, config=endpoint_config)

数分待つと、エンドポイントが起動します。

Screenshot 2024-10-08 at 20.32.02.png

モデルのエンドポイントのテスト

これで、ファインチューニングされたモデルを提供し、質問を始める準備が整いました!

応答は、Databricksのドキュメントと私たちのRAGチャットボットのフォーマットされた出力から改善され、特化されます!

df = spark.sql(f"""
        SELECT 
            ai_query("dbdemos_classification_fine_tuned", concat("{system_prompt}", description)) AS fine_tuned_prediction,
            description,
            email
        FROM customer_tickets 
        LIMIT 5""")
display(df)

期待した通りに簡潔な結果を返すようになりました!

Screenshot 2024-10-08 at 20.34.59.png

次のステップ

このノートブックでは、Databricks Mosaic AI FT APIを使用したファインチューニングの基礎をカバーしました。

本番環境に展開する前に、ファインチューニングモデルをテストして評価するために通常いくつかの追加のステップが必要です。

以下の2つのユースケースを探索して、次のことを発見してください:

チャットボット/アシスタントRAGモデルのファインチューニング

03.1-llm-rag-fine-tuning ノートブックを開いて、Databricksの組み込み評価機能を使用してLLMを評価する方法を探索してください。

次のサンプルでは、こちらのノートブックの一部を使うので、こちらから取り組んだ方がいいです。

エンティティの抽出と評価

02.1-llm-entity-extraction-drug-fine-tuning ノートブックを開いて、Fine TunedモデルとベースモデルのNamed Entity Extraction(NER)の例をベンチマークにしてください。

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?