こちらのサンプルノートブックをウォークスルーします。
DSPyに関してはこちらをご覧ください。ちなみに、DSPyはDeclarative Self-improving Python(宣言型自己改善Python)の略だそうです。なるほど。
テキスト分類器のDSPyプログラムの作成
このノートブックでは、DSPyとDatabricksがホスティングする大規模言語モデルを使用してテキスト分類プログラムを作成する方法を示します。
どのように機能するのか?
DSPyは、手動のプロンプトエンジニアリングを構造化されたテキスト変換グラフに置き換えることで、言語モデル(LM)パイプラインの構築を簡素化します。これらのグラフは、推論、検索、複雑な質問への回答などのLMタスクを自動化および最適化する柔軟な学習モジュールを使用します。
ハイレベルでは、DSPyはプロンプトを最適化し、最適な言語モデルを選択し、トレーニングデータを使用してモデルをファインチューンすることができます。
プロセスは、ほとんどのDSPy オプティマイザーに共通する次の3つのステップに従います:
-
候補生成: DSPyはプログラム内のすべての
Predict
モジュールを見つけ、プロンプトの例などの指示やデモンストレーションのバリエーションを生成します。このステップでは、次の段階のための候補セットを作成します。 - パラメータ最適化: 次にDSPyは、ランダムサーチ、TPE、Optunaなどの方法を使用して最適な構成を選択します。この段階でモデルのファインチューンも行うことができます。
- 高次最適化: 最後に、DSPyはプログラムの構造を変更することができ、例えば、プログラムの異なるバリエーションのアンサンブルを作成してパフォーマンスを向上させることができます。
パッケージのインストール
%pip install -U dspy-ai>=0.2.5
dbutils.library.restartPython()
LLMのセットアップ
関連する依存関係をインストールした後、以下の例では、Databricksの基盤モデルサービングエンドポイントを使用して、Meta Llama 3.3 70B Instructモデルを利用します。
import dspy
# Databricksモデルサービングエンドポイント名を定義
model_serving_endpoint_name = "databricks-meta-llama-3-3-70b-instruct"
# Databricksコンテキスト内のデフォルト認証を活用(ノートブック、ワークフローなど)
llama_33_70b = dspy.LM(
model=f"databricks/{model_serving_endpoint_name}",
max_tokens=500,
temperature=0.1,
)
dspy.settings.configure(lm=llama_33_70b)
データのセットアップ
以下では、HuggingfaceからReuters 21578データセットをダウンロードし、トレーニングセットとテストセットのスプリットが同じラベルを持つようにするユーティリティを作成します。
import numpy as np
import pandas as pd
from dspy.datasets.dataset import Dataset
from sklearn.model_selection import StratifiedShuffleSplit
def read_data_and_subset_to_categories() -> tuple[pd.DataFrame]:
"""
reuters-21578データセットを読み込みます。ドキュメントは以下のURLにあります:
https://huggingface.co/datasets/yangwang825/reuters-21578
"""
# トレイン/テスト分割を読み込み
file_path = "hf://datasets/yangwang825/reuters-21578/{}.json"
train = pd.read_json(file_path.format("train"))
test = pd.read_json(file_path.format("test"))
# ラベルをクリーンアップ
label_map = {
0: "acq",
1: "crude",
2: "earn",
3: "grain",
4: "interest",
5: "money-fx",
6: "ship",
7: "trade",
}
train["label"] = train["label"].map(label_map)
test["label"] = test["label"].map(label_map)
return train, test
class CSVDataset(Dataset):
def __init__(
self, n_train_per_label: int = 20, n_test_per_label: int = 10, *args, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.n_train_per_label = n_train_per_label
self.n_test_per_label = n_test_per_label
self._create_train_test_split_and_ensure_labels()
def _create_train_test_split_and_ensure_labels(self) -> None:
"""`dev`にあるラベルが`train`にも含まれるようにトレイン/テスト分割を行います。"""
# データを読み込み
train_df, test_df = read_data_and_subset_to_categories()
# 各ラベルごとにサンプリング
train_samples_df = pd.concat([
group.sample(n=self.n_train_per_label)
for _, group in train_df.groupby('label')
])
test_samples_df = pd.concat([
group.sample(n=self.n_test_per_label)
for _, group in test_df.groupby('label')
])
# DSPyクラス変数を設定
self._train = train_samples_df.to_dict(orient="records")
self._dev = test_samples_df.to_dict(orient="records")
# ブートストラップの価値を示すために小さなデータセットに制限
dataset = CSVDataset(n_train_per_label=3, n_test_per_label=1)
# DSPyを含むトレーニングセットとテストセットを作成
# 期待される入力値の名前を指定する必要があることに注意
train_dataset = [example.with_inputs("text") for example in dataset.train]
test_dataset = [example.with_inputs("text") for example in dataset.dev]
print(len(train_dataset), len(test_dataset))
print(f"Train labels: {set([example.label for example in dataset.train])}")
print(train_dataset[0])
24 8
Train labels: {'money-fx', 'trade', 'earn', 'interest', 'crude', 'grain', 'ship', 'acq'}
Example({'label': 'interest', 'text': 'belgium cuts discount rate to pct from official'}) (input_keys={'text'})
DSPyシグネチャとモジュールのセットアップ
最後に、テキスト分類タスクを定義します。
DSPyシグネチャの動作にガイドラインを提供する方法はさまざまです。現在、DSPyではユーザーが以下を指定することができます:
- クラスのドキュメンテーション文字列を介した高レベルの目標。
- オプションのメタデータを含む入力フィールドのセット。
- オプションのメタデータを含む出力フィールドのセット。
DSPyはこの情報を活用して最適化を行います。
次の例では、TextClassificationSignature
クラスに対象データセットに関する情報がないことに注意してください。これは、文脈なしでテキストを分類するように指示されたベースのLLMを効果的に活用していることを意味します。この白紙の状態から、DSPyを使用してトレーニングのみで分類を学習することができます。
本番環境では、トレーニング時間を短縮するためにシグネチャにメタデータを提供する必要があることに注意してください。このシナリオはデモンストレーション目的のためのものです。
# 入力フィールドと出力フィールドを定義
class TextClassificationSignature(dspy.Signature):
text = dspy.InputField()
label = dspy.OutputField(desc="Label of predicted_class")
# モデルクラスを定義
class TextClassifier(dspy.Module):
def __init__(self):
super().__init__()
self.generate_classification = dspy.Predict(TextClassificationSignature)
def forward(self, text: str):
return self.generate_classification(text=text)
Hello worldの例を実行する
以下は、DSPyモジュールと関連するシグネチャを使用して予測する方法を示しています。予想通り、プログラムはラベルが与えられていないため、テキストを誤って分類します。
from copy import copy
# impact_improvementクラスの初期化
text_classifier = copy(TextClassifier())
message = "I am interested in space"
print(text_classifier(text=message))
message = "I enjoy ice skating"
print(text_classifier(text=message))
Prediction(
label='Space Enthusiast'
)
Prediction(
label='positive'
)
トレーニング
トレーニングには、トレーニングセットからブートストラップサンプルを取り、ランダムサーチ戦略を活用して予測精度を最適化するオプティマイザであるBootstrapFewShotWithRandomSearchを使用できます。
次の例では、validate_classification
で定義されているように、単純なメトリック定義である完全一致を使用していますが、dspy.Metricsには、精度を適切に評価するための複雑なロジックやLMベースのロジックが含まれている場合があります。
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
# 完全一致で分類結果を検証
def validate_classification(example, prediction, trace=None) -> bool:
return example.label == prediction.label
# Optimizerの初期化
optimizer = BootstrapFewShotWithRandomSearch(
metric=validate_classification, # 検証関数を指定
num_candidate_programs=5, # 候補プログラムの数
max_bootstrapped_demos=2, # ブートストラップデモの最大数
num_threads=1, # 使用するスレッド数
)
# モデルのコンパイル
compiled_pe = optimizer.compile(copy(TextClassifier()), trainset=train_dataset)
Going to sample between 1 and 2 traces per predictor.
Will attempt to bootstrap 5 candidate sets.
:
:
Average Metric: 0.00 / 24 (0.0%): 100%|██████████| 24/24 [00:10<00:00, 2.21it/s]2025/01/14 11:38:58 INFO dspy.evaluate.evaluate: Average Metric: 0.0 / 24 (0.0%)
New best score: 0.0 for seed -3
Scores so far: [0.0]
Best score so far: 0.0
Average Metric: 23.00 / 24 (95.8%): 100%|██████████| 24/24 [00:22<00:00, 1.05it/s]2025/01/14 11:39:21 INFO dspy.evaluate.evaluate: Average Metric: 23 / 24 (95.8%)
New best score: 95.83 for seed -2
Scores so far: [0.0, 95.83]
Best score so far: 95.83
8%|▊ | 2/24 [00:00<00:09, 2.32it/s]
Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.
Average Metric: 22.00 / 24 (91.7%): 100%|██████████| 24/24 [00:22<00:00, 1.08it/s]2025/01/14 11:39:44 INFO dspy.evaluate.evaluate: Average Metric: 22 / 24 (91.7%)
Scores so far: [0.0, 95.83, 91.67]
Best score so far: 95.83
12%|█▎ | 3/24 [00:02<00:16, 1.26it/s]
Bootstrapped 2 full traces after 3 examples for up to 1 rounds, amounting to 3 attempts.
Average Metric: 21.00 / 24 (87.5%): 100%|██████████| 24/24 [00:23<00:00, 1.03it/s]2025/01/14 11:40:10 INFO dspy.evaluate.evaluate: Average Metric: 21 / 24 (87.5%)
Scores so far: [0.0, 95.83, 91.67, 87.5]
Best score so far: 95.83
4%|▍ | 1/24 [00:00<00:20, 1.13it/s]
Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.
Average Metric: 22.00 / 24 (91.7%): 100%|██████████| 24/24 [00:19<00:00, 1.23it/s]2025/01/14 11:40:30 INFO dspy.evaluate.evaluate: Average Metric: 22 / 24 (91.7%)
Scores so far: [0.0, 95.83, 91.67, 87.5, 91.67]
Best score so far: 95.83
4%|▍ | 1/24 [00:00<00:22, 1.02it/s]
Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.
Average Metric: 22.00 / 24 (91.7%): 100%|██████████| 24/24 [00:22<00:00, 1.07it/s]2025/01/14 11:40:54 INFO dspy.evaluate.evaluate: Average Metric: 22 / 24 (91.7%)
Scores so far: [0.0, 95.83, 91.67, 87.5, 91.67, 91.67]
Best score so far: 95.83
4%|▍ | 1/24 [00:01<00:23, 1.03s/it]
Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.
Average Metric: 24.00 / 24 (100.0%): 100%|██████████| 24/24 [00:20<00:00, 1.17it/s]2025/01/14 11:41:15 INFO dspy.evaluate.evaluate: Average Metric: 24 / 24 (100.0%)
New best score: 100.0 for seed 3
Scores so far: [0.0, 95.83, 91.67, 87.5, 91.67, 91.67, 100.0]
Best score so far: 100.0
8%|▊ | 2/24 [00:01<00:16, 1.37it/s]
Bootstrapped 1 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.
Average Metric: 24.00 / 24 (100.0%): 100%|██████████| 24/24 [00:20<00:00, 1.20it/s]2025/01/14 11:41:37 INFO dspy.evaluate.evaluate: Average Metric: 24 / 24 (100.0%)
Scores so far: [0.0, 95.83, 91.67, 87.5, 91.67, 91.67, 100.0, 100.0]
Best score so far: 100.0
8 candidate programs found.
DSPyトレーニングが返すパスを解析するのは難しい場合がありますが、いくつかのAPIが公開されています。以下は、n
個のメッセージのフォーマットされたカラフルなメッセージ履歴を表示する方法を示しています。
llama_31_70b.inspect_history(n=1)
[2025-01-14T11:41:37.150879]
System message:
Your input fields are:
1. `text` (str)
Your output fields are:
1. `label` (str): Label of predicted_class
All interactions will be structured in the following way, with the appropriate values filled in.
[[ ## text ## ]]
{text}
[[ ## label ## ]]
{label}
[[ ## completed ## ]]
In adhering to this structure, your objective is:
Given the fields `text`, produce the fields `label`.
User message:
[[ ## text ## ]]
group seeks prime medical pmsi holder list a group of investment firms led by far hills n j investor natalie koether said it is seeking information about the shareholders of prime medical services inc over which it has said it is seeking control in a filing with the securities and exchange commission the group which includes shamrock associates included a march letter to prime which asks for a complete list of all shareholders and their addresses the group said it wants the information so it can contact shareholders on issues including election of an opposition slate of directors to the board and other proxy contests the koether group s letter gives the company five days to respond to its request if there is no response from prime medical the group said it would consider the demand refused and would take other proper steps to get the information the group said it already holds prime medical shares or pct of the total it said it has taken legal action to try to force the company to set an annual meeting and require all directors to stand for election in a previous sec filing the group has said it has decided to try to seek control of prime medical through a tender offer exchange offer proxy contest or other ways reuter
Respond with the corresponding output fields, starting with the field `[[ ## label ## ]]`, and then ending with the marker for `[[ ## completed ## ]]`.
Assistant message:
[[ ## label ## ]]
acq
[[ ## completed ## ]]
:
:
モデルの精度を比較する
最後に、訓練されたモデルが未知のテストデータをどれだけ正確に予測できるかを確認できます。
def check_accuracy(classifier, test_data: pd.DataFrame = test_dataset) -> float:
residuals = []
predictions = []
for example in test_data:
prediction = classifier(text=example["text"])
residuals.append(int(validate_classification(example, prediction)))
predictions.append(prediction)
return residuals, predictions
# 未コンパイルモデルの精度をチェック
uncompiled_residuals, uncompiled_predictions = check_accuracy(copy(TextClassifier()))
display(f"Uncompiled accuracy: {np.mean(uncompiled_residuals)}")
# コンパイル済みモデルの精度をチェック
compiled_residuals, compiled_predictions = check_accuracy(compiled_pe)
display(f"Compiled accuracy: {np.mean(compiled_residuals)}")
'Uncompiled accuracy: 0.0'
'Compiled accuracy: 1.0'
上記のように、コンパイルされていない精度は予想通り0です。ベースのLLMは分類ラベルを認識していませんでした。しかし、トレーニングだけで、プロンプト、デモンストレーション、および入力と出力のシグネチャが更新され、モデルの精度は100%に達しました。
for uncompiled_residual, uncompiled_prediction in zip(uncompiled_residuals, uncompiled_predictions):
is_correct = "Correct" if bool(uncompiled_residual) else "Incorrect"
prediction = uncompiled_prediction.label
print(f"{is_correct} prediction: {' ' * (12 - len(is_correct))}{prediction}")
Incorrect prediction: Economy
Incorrect prediction: War/Conflict
Incorrect prediction: fraud
Incorrect prediction: Earnings Report
Incorrect prediction: Economic
Incorrect prediction: Economics
Incorrect prediction: Economy
Incorrect prediction: International Relations
for compiled_residual, compiled_prediction in zip(compiled_residuals, compiled_predictions):
is_correct = "Correct" if bool(compiled_residual) else "Incorrect"
prediction = compiled_prediction.label
print(f"{is_correct} prediction: {' ' * (12 - len(is_correct))}{prediction}")
Correct prediction: interest
Correct prediction: crude
Correct prediction: money-fx
Correct prediction: earn
Correct prediction: acq
Correct prediction: grain
Correct prediction: trade
Correct prediction: ship
まとめ
これはDSPyの動作を紹介する入門例です。
- より多くのDatabricksの例については、DSPy on Databricksを参照してください。
- オープンソースの例については、DSPyのチュートリアルおよびドキュメントを参照してください。