概要
- Optuna で SVM のハイパーパラメータ探索をする。
- 今回は rbf カーネルに対するハイパーパラメータ探索をする。
- SVC の内容については、参考を参照。
動作確認済み Python バージョン
Python 3.12.8
Python ライブラリのインストール
pip install scikit-learn==1.7.1
pip install optuna==4.4.0
pip install plotly==6.2.0
プログラム(svc_optuna.py)
import optuna
import pickle
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
# データ読み込み(適宜データを変えてください)
X, y = load_iris(return_X_y=True)
# 目的関数(精度最大化)
def objective(trial):
# カーネル選択
kernel = 'rbf' # デフォルトの rbs だけ試す
# カーネルも探索したい場合は例えば次のように書く
# kernel = trial.suggest_categorical('kernel', ['linear', 'rbf', 'poly', 'sigmoid'])
# 共通パラメータ
C = trial.suggest_loguniform('C', 1e-3, 1e3)
gamma = trial.suggest_loguniform('gamma', 1e-4, 1e1)
# カーネル依存パラメータ
if kernel == 'poly':
degree = trial.suggest_int('degree', 2, 5)
coef0 = trial.suggest_uniform('coef0', 0.0, 1.0)
elif kernel == 'sigmoid':
degree = 3 # 未使用(デフォルト値)
coef0 = trial.suggest_uniform('coef0', 0.0, 1.0)
else: # rbf
degree = 3 # 未使用(デフォルト値)
coef0 = 0.0
# class_weight はバランス有無で比較
class_weight = trial.suggest_categorical('class_weight', [None, 'balanced'])
# モデル定義(標準化を含めたパイプライン)
model = Pipeline([
('scaler', StandardScaler()),
('svc', SVC(
kernel=kernel,
C=C,
gamma=gamma,
degree=degree,
coef0=coef0,
class_weight=class_weight
))
])
# 交差検証でスコアを取得(平均 accuracy)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
accuracy = cross_val_score(model, X, y, cv=cv, scoring='accuracy').mean()
return accuracy
# スタディの作成と最適化
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=200)
# 結果表示
print("Best Score:", study.best_value)
print("Best Params:")
for k, v in study.best_trial.params.items():
print(f" {k}: {v}")
# スタディの保存
with open("./study.pkl", "wb") as fout:
pickle.dump(study, fout)
# スコアの変位をプロット
fig = optuna.visualization.plot_optimization_history(study)
fig.show()
# パラメータの重要度をプロット
fig = optuna.visualization.plot_param_importances(study)
fig.show()
実行
python svc_optuna.py
結果
- 標準出力の結果は次の通りである。50 回の試行で既にスコアが収束しているように思われる。
Best Score: 0.9733333333333334
Best Params:
C: 4.67863079344352
gamma: 0.07159877091542621
class_weight: None
- 出力される画像は次の通りである。