0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Azure で gpt-4o-mini ファインチューニングでハイパーパラメータ探索

Posted at

Azure の gpt-4o-mini に対してoptuna使ってハイパーパラメータ探索をしました。そのときの記録です。時間なかったので全般的に整理が甘いです。やったのが2025年4月なので、いろいろとアップデートあると思います。
時間があれば、どんなパラメータだったかや知見も追記、または別記事で書きます。

Pythonでのファインチューニング自体はこちらの記事を参照ください。

Python Script

1. パッケージロード

あまり整理できていないので、不要なパッケージを含んでいるかもしれません。
plotlyは明示的に呼んでいないのですが、importしていないとエラーになったので、含めています(少し記憶が曖昧)。matplotlibは不要だったかもしれません。

import json
from logging import getLogger, DEBUG, FileHandler, Formatter
import os
import time

from dotenv import load_dotenv
import matplotlib.pyplot as plt
from openai import AzureOpenAI
import optuna
import pandas as pd
import plotly

2. 初期処理

環境変数を読込、Azure OpenAIのクライアントを設定します。

load_dotenv(override=True)
endpoint = os.getenv("AZURE_OPENAI_TRAIN_ENDPOINT")
key = os.getenv("AZURE_OPENAI_TRAIN_API_KEY")
model_name = "gpt-4o-mini"

client = AzureOpenAI(
  azure_endpoint = endpoint,
  api_key = key,
  api_version = "2024-08-01-preview"
)

ファイル.evnの中身です(値は省略)。

.env
AZURE_OPENAI_TRAIN_API_KEY=
AZURE_OPENAI_TRAIN_ENDPOINT=

3. データ読込

ローカルPCからデータを読み込みます。

DIR = './data/'
TRAIN_FILE = DIR + 'train.jsonl'
VALIDATION_FILE = DIR + 'validation.jsonl'

def read_jsonl(file_name: str) -> list:
    with open(file_name, 'r', encoding='utf-8') as f:
        dataset = [json.loads(line) for line in f]

    print(f"Number of examples in {file_name}: {len(dataset)}")
    print(f"First example in {file_name}:")
    for message in dataset[0]["messages"]:
        print(message)
    return dataset

train_dataset = read_jsonl(TRAIN_FILE)
print("----------")
validation_dataset = read_jsonl(VALIDATION_FILE)

ターミナル出力です。

Number of examples in ./data/train.jsonl: 100
First example in ./data/train.jsonl:
{'role': 'system', 'content': 'システムプロンプト'}
{'role': 'user', 'content': 'ユーザプロンプト'}
{'role': 'assistant', 'content': '正解'}
----------
Number of examples in ./data/validation.jsonl: 10
First example in ./data/validation.jsonl:
{'role': 'system', 'content': 'システムプロンプト'}
{'role': 'user', 'content': 'ユーザプロンプト'}
{'role': 'assistant', 'content': '正解'}

4. ファイルアップロード

ファイルをアップロードします。

# Upload the training and validation dataset files to Azure OpenAI with the SDK.
training_response = client.files.create(
    file = open(TRAIN_FILE, "rb"), purpose="fine-tune"
)

validation_response = client.files.create(
    file = open(VALIDATION_FILE, "rb"), purpose="fine-tune"
)

print("Training file ID:", training_response.id)
print("Validation file ID:", validation_response.id)

5. ファインチューニング実施関数定義

ファインチューニングを実施する関数の定義。

def train(client, parameters: dict, logger) -> float:

    # モデル名からハイパーパラメータがわかるように設定
    lrm = str(int(parameters["learning_rate_multiplier"]*1000))
    suffix = "lr"+lrm+"-b"+str(parameters["batch_size"])+"-e"+str(parameters["n_epochs"])
    response = client.fine_tuning.jobs.create(
        training_file=train_file,
        validation_file=validation_file,
        model=model_name,
        suffix=suffix,
        seed=42,
        hyperparameters=parameters,
    )
    job_id = response.id
    logger.info(f"{job_id=}")

    start_time = time.time()

    # Get the status of our fine-tuning job.
    response = client.fine_tuning.jobs.retrieve(job_id)
    status = response.status

    # ファインチューニングジョブが終わるまで待機
    while status not in ["succeeded", "failed"]:
        time.sleep(60)
        response = client.fine_tuning.jobs.retrieve(job_id)
        logger.debug(response.model_dump_json(indent=2))
        logger.debug("Elapsed time: {} minutes {} seconds".format(int((time.time() - start_time) // 60), int((time.time() - start_time) % 60)))
        status = response.status
        logger.debug(f'Status: {status}')
    
    logger.info(f'Fine-tuning job {job_id} finished with status: {status}')

    response = client.fine_tuning.jobs.list_events(fine_tuning_job_id=job_id, limit=30)

    # Full Valid Lossを取得
    for event in response.data:
        if event.type == "metrics" and event.data.get("full_valid_loss"):
            logger.info(event.data)
            return event.data["full_valid_loss"]

    # 取得できない場合は警告出力してOvbjectiveを大きい値に設定
    logger.warning(f"No metrics found for job {job_id}")
    return 999.999

6. Optuna実行処理関数定義

Optunaで実行する処理の関数定義です。ObjectiveとするFull Loss Validationを返します。

def objective(trial, client: AzureOpenAI, logger) -> float:
    parameters = {
            "learning_rate_multiplier": trial.suggest_float('learning_rate_multiplier', 1e-2, 3e-1, log=True), # from 0.01 to 0.3
            "batch_size": trial.suggest_int('batch_size', 1, 32), #from 1 to 32
            "n_epochs": trial.suggest_int('n_epochs', 1, 10),
    }
    score = train(client, parameters, logger)
    logger.info(f"Trial {trial.number} finished with score: {score}")    
    return score

7. ログ出力関数定義

処理中にログを出力するための関数です。ローカルファイルにログを吐き出しています。

def get_module_logger():
    logger = optuna.logging.get_logger("optuna")
    handler = FileHandler('./logs/optuna.log')
    handler.setLevel(DEBUG)
    handler.setFormatter(Formatter('[%(levelname)s] %(asctime)s: %(message)s'))
    logger.addHandler(handler)

    logger.propagate = False
    return logger

logger = get_module_logger()

以前、私が書いた記事を参考にしています。

8. ハイパーパラメータ探索実行

ハイパーパラメータ探索実行処理です。
処理が中断しても途中から再開できます。sqliteを使う形にしていて、再開時にsqliteに保存された内容を見に行くからです。sqliteの設定は特別する必要がありませんでした。

  • n_trials: 実行試行数
  • n_jobs: 並列実行数
%%time
study_name = "fine_tune"
study = optuna.create_study(study_name=study_name,
                            direction="minimize",
                            storage='sqlite:///../finetune_study00.db',
                            load_if_exists=True)
study.optimize(lambda trial: objective(trial, client, logger), n_trials=20, n_jobs=4)

9. 実行結果確認

9.1. 実行グラフ結果

グラフで実行結果の確認ができます。

def plot_optimization_history(study):
    fig = optuna.visualization.plot_optimization_history(study)
    fig.show()

def plot_param_importances(study):
    fig = optuna.visualization.plot_param_importances(study)
    fig.show()

def plot_hyperparameter_distributions(study):
    fig = optuna.visualization.plot_slice(study)
    fig.show()

def plot_parallel_coordinate(study):
    fig = optuna.visualization.plot_parallel_coordinate(study)
    fig.show()

def plot_edf(study):
    fig = optuna.visualization.plot_edf(study)
    fig.show()

def plot_rank(study):
    fig = optuna.visualization.plot_rank(study)
    fig.show()

plot_optimization_history(study)
plot_param_importances(study)
plot_hyperparameter_distributions(study)
plot_parallel_coordinate(study)
plot_edf(study)
plot_rank(study)

9.2. データフレーム出力

以下の関数で実行内容のデーターフレームを確認できます。

study.trials_dataframe()

その他

試行内容リカバリー方法

途中ネットワークが切れて、実行は成功しているのだけど、Python上ではエラーになったことがありました。
その時には、手でリカバリーしました。
create_trialdistributionsvalue指定している点は誤りがあるかもしれません。

trial = optuna.trial.create_trial(
    params={'learning_rate_multiplier': 0.12345, 'batch_size': 10, 'n_epochs': 3},
    distributions={"x": FloatDistribution(0, 10)},
    value=4.0,
)
for trial in study.trials:
    # print(f"Trial {trial.number}: {trial.value}, {trial.params}, {trial.state}")
        if trial.number == 21 and trial.state == optuna.trial.TrialState.FAIL:
            trial.number = 44
            trial.value = 0.565770392
            trial.state = optuna.trial.TrialState.COMPLETE
            study.add_trial(trial)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?