8
10

More than 3 years have passed since last update.

optunaを使って多目的最適化問題を解いてみた

Posted at
  • 製造業出身のデータサイエンティストがお送りする記事
  • 今回はoptunaを使って多目的最適化問題を解いてみました。

はじめに

過去に多目的最適化問題については、記事を何個か書いておりますので参考にして頂けますと幸いです。

使用するライブラリー(optuna)

今回は最適化ライブラリoptunaを使って実装しました。

optunaを使って多目的最適化を実装

今回は簡単な問題を試してみました。

# 必要なライブラリーのインストール
import optuna
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.simplefilter('ignore')

次に目的関数を定義します。

# 目的関数の定義。複数の目的変数を戻り値とする
def objective(trial):
    x = trial.suggest_uniform("x", 0, 5) # 変数xを上下限0~5の範囲で連続値
    y = trial.suggest_uniform("y", 0, 3) # 変数yを上下限0~3の範囲で連続値
    v0 = 4 * x ** 2 + 4 * y ** 2
    v1 = (x - 5) ** 2 + (y - 5) ** 2
    return v0, v1

次に最適化の条件を設定します。

# 最適化の条件設定
study = optuna.multi_objective.create_study(
    directions=["minimize", "minimize"], # "minimize" "maximize"
    sampler=optuna.multi_objective.samplers.NSGAIIMultiObjectiveSampler(seed = 1)
)
# 最適化の実行
study.optimize(objective, n_trials=200)

結果を可視化します。

# 最適化過程で得た履歴データの取得。get_trials()メソッドを使用
trials = {str(trial.values): trial for trial in study.get_trials()}
trials = list(trials.values())
# グラフにプロットするため、目的変数をリストに格納する
y1_all_list = []
y2_all_list = []
for i, trial in enumerate(trials, start=1):
    y1_all_list.append(trial.values[0])
    y2_all_list.append(trial.values[1])

# パレート解の取得。get_pareto_front_trials()メソッドを使用
trials = {str(trial.values): trial for trial in study.get_pareto_front_trials()}
trials = list(trials.values())
trials.sort(key=lambda t: t.values)
# グラフプロット用にリストで取得。またパレート解の目的変数と説明変数をcsvに保存する
y1_list = []
y2_list = []
with open('pareto_data_real.csv', 'w') as f:
    for i, trial in enumerate(trials, start=1):
        if i == 1:
            columns_name_str = 'trial_no,y1,y2'
        data_list = []
        data_list.append(trial.number)
        y1_value = trial.values[0]
        y2_value = trial.values[1]
        y1_list.append(y1_value)
        y2_list.append(y2_value)
        data_list.append(y1_value)
        data_list.append(y2_value)    
        for key, value in trial.params.items():
            data_list.append(value)
            if i == 1:
                columns_name_str += ',' + key 
        if i == 1:
            f.write(columns_name_str + '\n')
        data_list = list(map(str, data_list))
        data_list_str = ','.join(data_list)
        f.write(data_list_str + '\n')

# パレート解を図示
plt.rcParams["font.size"] = 16
plt.figure(dpi=120)
plt.title("multiobjective optimization")
plt.xlabel("Y1")
plt.ylabel("Y2")
plt.grid()
plt.scatter(y1_all_list, y2_all_list, c='blue', label='all trials')
plt.scatter(y1_list, y2_list, c='red', label='pareto front')
plt.legend()
plt.tight_layout()
plt.savefig("pareto_graph_real.png")
plt.close()

image.png

さいごに

最後まで読んで頂き、ありがとうございました。
今回は、optunaを使って多目的最適化問題を解いてみました。
optunaでは、NSGA-2を多目的最適化のアルゴリズムとして選択できました。
他にも手法が選択できましたので、時間がある時に試してみようと思います。

訂正要望がありましたら、ご連絡頂けますと幸いです。

8
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
10