LoginSignup
0
0

Databricksにおける基盤モデルの指示ファインチューニング

Last updated at Posted at 2024-06-01

instruction-fine-tuning/01-llm-instruction-drug-extraction-fine-tuningをウォークスルーします。

基盤モデルの指示ファインチューニング: 固有表現抽出(Named Entity Recognition)

このデモでは、指示ファインチューニング(ITF)のためのモデルのチューニングにフォーカスし、テキストから薬品名を抽出する様にllama2を特化させます。このプロセスはNER(Named Entity Recognition)と呼ばれるものです。

オープンソースモデルを医療NERタスクにファインチューンすることで、モデルの出力は

  1. より正確に、かつ
  2. より効率的になりモデルサービング費用を削減します。

データセットの準備

この例をシンプルにするために、Huggingfaceにある既存のNERデータセットを使います。商用アプリケーションにおいては、モデルのパフォーマンスを改善するのに十分なサンプルを取得し、データのラベリングに投資することが多くの場合合理的なものとなります。

指示ファインチューニングでのデータ準備が鍵となります。Databricks Mosaic AIリサーチチームは、トレーニングデータのキュレーション戦略の確立における有用なガイドラインを公開しています。

%run ../_resources/00-setup
from datasets import load_dataset
import pandas as pd

hf_dataset_name = "allenai/drug-combo-extraction"

dataset_test = load_dataset(hf_dataset_name, split="test")

# データセットをpandasデータフレームに変換
df_test = pd.DataFrame(dataset_test)

# spansからエンティティを抽出
df_test["human_annotated_entities"] = df_test["spans"].apply(lambda spans: [span["text"] for span in spans])

df_test = df_test[["sentence", "human_annotated_entities"]]

display(df_test)

Screenshot 2024-06-01 at 19.34.01.png

エンティティを抽出するプロンプトテンプレートの構築

system_prompt = """
### INSTRUCTIONS:
You are a medical and pharmaceutical expert. Your task is to identify pharmaceutical drug names from the provided input and list them accurately. Follow these guidelines:

1. Do not add any commentary or repeat the instructions.
2. Extract the names of pharmaceutical drugs mentioned.
3. Place the extracted names in a Python list format. Ensure the names are enclosed in square brackets and separated by commas.
4. Maintain the order in which the drug names appear in the input.
5. Do not add any text before or after the list.
"""

ベースラインバージョンでエンティティを抽出 (ファインチューンなし)

ファインチューンしていないモデルであるベースラインを用いて最初のエンティティ抽出を行うところからスタートしましょう。

コストを節約するために、以前の../02-llm-evaluationノートブックで作成したdbdemos_llm_not_fine_tunedエンドポイントを使います。

事前にエンドポイントをセットアップしてこのノートブックを実行する様にしてください。

import mlflow
from langchain_community.chat_models.databricks import ChatDatabricks
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate

# ../02-llm-evaluation ノートブックで使用したモデルに対応するモデルであることを確認してください
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

input_sentence = "{sentence}"

def build_chain(llm):
    # Mistralではsystemロールをサポートしていません
    if "mistral" in base_model_name:
        messages = [("user", f"{system_prompt} \n {input_sentence}")]
    else:
        messages = [("system", system_prompt),
                    ("user", input_sentence)]
    return ChatPromptTemplate.from_messages(messages) | llm | StrOutputParser()
  
def extract_entities(df, endpoint_name):
  llm = ChatDatabricks(endpoint=endpoint_name, temperature=0.1)
  chain = build_chain(llm)
  predictions = chain.with_retry(stop_after_attempt=2) \
                                      .batch(df[["sentence"]].to_dict(orient="records"), config={"max_concurrency": 4})
  # テキストから配列を抽出。詳細は ../resource ノートブックをご覧ください。
  
  cleaned_predictions = [extract_json_array(p) for p in predictions]
  return predictions, cleaned_predictions 

# ベンチマークのメトリクスを収集するためにテストセットからいくつかのサンプルのみを取得
from sklearn.model_selection import train_test_split

df_validation, df_test_small = train_test_split(df_test, test_size=0.2, random_state=42)

# このエンドポイントは ../02-llm-evaluation ノートブックで作成されたものです。ベースラインとなるmistral 7bモデルでありファインチューンされていません。
# 最初にベースラインのモデルをデプロイするために事前にノートブックを実行してください。
serving_endpoint_baseline_name = "taka_dbdemos_llm_not_fine_tuned"

predictions, cleaned_predictions = extract_entities(df_test_small, serving_endpoint_baseline_name)
df_test_small['baseline_predictions'] = predictions
df_test_small['baseline_predictions_cleaned'] = cleaned_predictions
display(df_test_small[["sentence", "baseline_predictions_cleaned", "human_annotated_entities"]])

Screenshot 2024-06-01 at 19.42.09.png

ベースラインモデルの評価

我々のモデルがそれなりの数のエンティティを抽出していることを確認できますが、推論の前後でいくつかのランダムなテキストも追加してしまっています。

エンティティ抽出における精度と再現率

精度と再現率を計算することでモデルのベンチマークを行います。我々のテストデータセットのそれぞれの文でこれらの値を計算しましょう。

from sklearn.metrics import precision_score, recall_score

def compute_precision_recall(prediction, ground_truth):
    prediction_set = set([str(drug).lower() for drug in prediction])
    ground_truth_set = set([str(drug).lower() for drug in ground_truth])
    all_elements = prediction_set.union(ground_truth_set)

    # セットをバイナリのリストに変換
    prediction_binary = [int(element in prediction_set) for element in all_elements]
    ground_truth_binary = [int(element in ground_truth_set) for element in all_elements]
    
    precision = precision_score(ground_truth_binary, prediction_binary)
    recall = recall_score(ground_truth_binary, prediction_binary)

    return precision, recall
  
def precision_recall_series(row):
  precision, recall = compute_precision_recall(row['baseline_predictions_cleaned'], row['human_annotated_entities'])
  return pd.Series([precision, recall], index=['precision', 'recall'])

df_test_small[['baseline_precision', 'baseline_recall']] = df_test_small.apply(precision_recall_series, axis=1)
df_test_small[['baseline_precision', 'baseline_recall']].describe()

Screenshot 2024-06-01 at 19.42.51.png

このサンプルでは、ベースラインのLLMはおおよそ0.6936の再現率を示しており、テキストに存在する実際のすべての薬品名の69.36%をうまく特定できていることを確認できます。このメトリックは、ヘルスケアや関連領域において重要であり、薬品名を見逃すことは不完全あるいは不適切な情報処理につながります。

平均70.45%の精度はベースラインのLLMモデルが薬品名として識別したトークンや文の約70.45%が正しかったことを意味します。

モデルのファインチューニング

ファインチューニングデータの準備

ファインチューニングの前にトレーニングデータセットのサンプルにプロンプトテンプレートを適用し、ターゲットとするリストフォーマットに正解の薬品リストを抽出する必要があります。

我々はこれをDatabricksのカタログにテーブルとして保存します。通常、これは完全なデータエンジニアリングパイプラインの一部となります。

このステップはファインチューニングにおいて鍵となるものであり、ご自身のトレーニングデータセットを高品質にしてください!

dataset_train = load_dataset(hf_dataset_name, split="train")
df_train = pd.DataFrame(dataset_train)

# データセットをpandasデータフレームに変換
df_train = pd.DataFrame(df_train)

# spansからエンティティを抽出
df_train["human_annotated_entities"] = df_train["spans"].apply(lambda spans: [span["text"] for span in spans])

df_train = df_train[["sentence", "human_annotated_entities"]]

df_train

Screenshot 2024-06-01 at 19.44.09.png

from pyspark.sql.functions import pandas_udf, to_json
import pandas as pd

@pandas_udf("array<struct<role:string, content:string>>")
def create_conversation(sentence: pd.Series, entities: pd.Series) -> pd.Series:
    def build_message(s, e):
        # 固有のモデルの挙動をチェックするロジックを調整
        if "mistral" in base_model_name:
            # systemプロンプトなしのMistral固有の挙動
            return [
                {"role": "user", "content": f"{system_prompt} \n {s}"},
                {"role": "assistant", "content": e}]
        else:
            # systemプロンプトありのデフォルトの挙動
            return [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": str(s)},
                {"role": "assistant", "content": e}]
                
    # 文とエンティティのそれぞれのペアに build_message を適用
    return pd.Series([build_message(s, e) for s, e in zip(sentence, entities)])

# df_trainが定義され、カラム'sentence'と'entities'を持つSparkデータフレームとして適切にフォーマットされていることが前提
training_dataset = spark.createDataFrame(df_train).withColumn("human_annotated_entities", to_json("human_annotated_entities"))

# UDFの適用、テーブルへの書き込み、表示
training_dataset.select(create_conversation("sentence", "human_annotated_entities").alias('messages')).write.mode('overwrite').saveAsTable("ner_chat_completion_training_dataset")
display(spark.table("ner_chat_completion_training_dataset"))

Screenshot 2024-06-01 at 19.59.09.png

評価データセットも準備します。df_validationとして利用できる様にします。

eval_dataset = spark.createDataFrame(df_validation).withColumn("human_annotated_entities", to_json("human_annotated_entities"))

# UDFの適用、テーブルへの書き込み、表示
eval_dataset.select(create_conversation("sentence", "human_annotated_entities").alias('messages')).write.mode('overwrite').saveAsTable("ner_chat_completion_eval_dataset")
display(spark.table("ner_chat_completion_eval_dataset"))

Screenshot 2024-06-01 at 20.00.37.png

ファインチューニング

データの準備ができたらあとはファインチューニングAPIを呼び出すだけです。

from databricks.model_training import foundation_model as fm

# テストの後にモデル名をdrug_extraction_ftに戻します。
registered_model_name = f"{catalog}.{db}.drug_extraction_ft_" + re.sub(r'[^a-zA-Z0-9]', '_',  base_model_name.lower())

run = fm.create(
  data_prep_cluster_id = get_current_cluster_id(),  # トレーニングデータソースとしてDeltaテーブルを使用している場合には必須。データ準備ジョブで使いたいクラスターのIDとなります。詳細は ./_resources をご覧ください。
  model=base_model_name,
  train_data_path=f"{catalog}.{db}.ner_chat_completion_training_dataset",
  eval_data_path=f"{catalog}.{db}.ner_chat_completion_eval_dataset",
  task_type = "CHAT_COMPLETION",
  register_to=registered_model_name,
  training_duration='50ep' # ファインチューニングランの期間、デモをクイックに行う際は10エポックのみ。いつ止めるのか(平坦になった際)を判断するにはトレーニングメトリクスをチェックしてください。
)
print(run)

MLflowエクスペリメントを通じたモデルのファインチューニングのトラッキング

ご自身のファインチューニングエクスペリメントを追跡するためにMLflowエクスペリメントのランをオープンすることができます。これは、トレーニングランをどのようにチューニングするのかを知るのに役立ちます(例: ランの最後でモデルに改善の余地がある場合にはさらにエポックを追加する)。

displayHTML(f'Open the <a href="/ml/experiments/{run.experiment_id}/runs/{run.run_id}/model-metrics">training run on MLflow</a> to track the metrics')
display(run.get_events())

# ランが終了するのを待つヘルパー関数 - 詳細は _resources フォルダをご覧ください
wait_for_run_to_finish(run)

Screenshot 2024-06-01 at 20.11.24.png

Lossも順調に減っています。
Screenshot 2024-06-01 at 20.09.11.png

ファインチューニングしたモデルをサービングエンドポイントにデプロイ

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ServedEntityInput, EndpointCoreConfigInput, AutoCaptureConfigInput

serving_endpoint_name = "taka_dbdemos_llm_drug_extraction_fine_tuned"
w = WorkspaceClient()
endpoint_config = EndpointCoreConfigInput(
    name=serving_endpoint_name,
    served_entities=[
        ServedEntityInput(
            entity_name=registered_model_name,
            entity_version=get_latest_model_version(registered_model_name),
            min_provisioned_throughput=0, # エンドポイントがスケールダウンする最小秒間トークン数
            max_provisioned_throughput=100, # エンドポイントがスケールアップする最大秒間トークン数
            scale_to_zero_enabled=True
        )
    ],
    auto_capture_config = AutoCaptureConfigInput(catalog_name=catalog, schema_name=db, enabled=True, table_name_prefix="fine_tuned_drug_extraction_llm_inference")
)

force_update = False # 新規バージョンをリリースする際にはこれを True に設定(このデモではデフォルトで新規モデルバージョンにエンドポイントを更新しません)
existing_endpoint = next(
    (e for e in w.serving_endpoints.list() if e.name == serving_endpoint_name), None
)
if existing_endpoint == None:
    print(f"Creating the endpoint {serving_endpoint_name}, this will take a few minutes to package and deploy the endpoint...")
    w.serving_endpoints.create_and_wait(name=serving_endpoint_name, config=endpoint_config)
else:
  print(f"endpoint {serving_endpoint_name} already exist...")
  if force_update:
    w.serving_endpoints.update_config_and_wait(served_entities=endpoint_config.served_entities, name=serving_endpoint_name)

ファインチューニング後の評価

ファインチューニングしたモデルはUnity Catalogに登録され、数クリックだけでエンドポイントにデプロイされました。

再現率と精度のベンチマーク

それでは再度評価を行い、ベースラインモデルと新たなモデルの精度と再現率を比較しましょう。

# 新たにファインチューニングしたエンドポイントに対して推論を実行
predictions, cleaned_predictions = extract_entities(df_test_small, serving_endpoint_name)
df_test_small['fine_tuned_predictions'] = predictions
df_test_small['fine_tuned_predictions_cleaned'] = cleaned_predictions
display(df_test_small[["sentence", "human_annotated_entities", "baseline_predictions_cleaned", "fine_tuned_predictions_cleaned"]])

Screenshot 2024-06-01 at 20.36.45.png

# 新たなモデルで精度と再現率を計算
def precision_recall_series(row):
  precision, recall = compute_precision_recall(row['fine_tuned_predictions_cleaned'], row['human_annotated_entities'])
  return pd.Series([precision, recall], index=['precision', 'recall'])

df_test_small[['fine_tuned_precision', 'fine_tuned_recall']] = df_test_small.apply(precision_recall_series, axis=1)
df_test_small[['baseline_precision', 'fine_tuned_precision', 'baseline_recall', 'fine_tuned_recall']].describe()

ファインチューニングによって、精度(precision)、再現率(recall)共に改善されています。
Screenshot 2024-06-01 at 20.37.30.png

トークン出力の計測

我々の最初のモデルは結果の前後に不必要なテキストを追加してしまっていました。これは分析やパースを困難にするだけでなく、不必要なトークンの課金が発生することになります。

我々の新たなモデルが期待通りに動作するのかを見てみましょう。

df_test_small['baseline_predictions_len'] = df_test_small['baseline_predictions'].apply(lambda x: len(x))
df_test_small['fine_tuned_predictions_len'] = df_test_small['fine_tuned_predictions'].apply(lambda x: len(x))
df_test_small[['baseline_predictions_len', 'fine_tuned_predictions_len']].describe()

Screenshot 2024-06-01 at 20.38.35.png

精度を改善したことに加え、追加のテキストを削除することで出力(すなわちコスト)を削減できています!

まとめ

このノートブックでは、Databricksgaどのようにして固有表現抽出のための指示ファインチューニングを用いたファインチューニングとLLMのデプロイメントをシンプルにするのかを見てきました。

Databricksが、どのようにしてベースラインモデルとファインチューニングしたモデルとの間でのパフォーマンス改善を容易に評価できるようにするのかをカバーしました。

ファインチューニングは様々なユースケースに適用することができます。Chat APIを用いることで、システムはすぐに利用できるプロンプトを構成し、ファインチューニングをシンプルにしてくれるので、可能な際には常に使う様にしましょう!

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0