Optuna has a storage option that lets you share the optimization history, which makes distributed optimization possible. So the question was: can this be done with Google Colaboratory and Google Drive? I tried it, and it works.
I had assumed I would need something like pymysql to deal with SQL, but fortunately that was unnecessary.
Below is an example of the procedure.
Installing Optuna
!pip install optuna
import optuna
Collecting optuna
  Downloading https://files.pythonhosted.org/packages/1a/18/b49ca91cf592747e19f2d333c2a86cd7c81895b922a5a09adf6335471576/optuna-2.8.0-py3-none-any.whl (301kB)
Installing collected packages: colorlog, python-editor, Mako, alembic, cmaes, pbr, colorama, pyperclip, cmd2, stevedore, cliff, optuna
Successfully installed Mako-1.1.4 alembic-1.6.5 cliff-3.8.0 cmaes-0.8.2 cmd2-2.1.1 colorama-0.4.4 colorlog-5.0.1 optuna-2.8.0 pbr-5.6.0 pyperclip-1.8.2 python-editor-1.0.4 stevedore-3.3.0
Mounting Google Drive from Google Colaboratory
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
Creating a place to save the results
import os
directory_path = './drive/MyDrive/test_optuna_colab_db/'
if not os.path.exists(directory_path):
    os.makedirs(directory_path)
Data for machine learning
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
breast_cancer = load_breast_cancer()
X = breast_cancer.data
y = breast_cancer.target.ravel()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.4)
Defining the objective function and setting the hyperparameters
The objective is also set up to save the best model as it goes.
import numpy as np
import lightgbm as lgb
import pickle
class Objective:
    def __init__(self, study, directory_path):
        self.direction = study.direction
        self.best_score = None
        self.best_model = None
        self.directory_path = directory_path

    def __call__(self, trial):
        learning_rate = trial.suggest_loguniform('learning_rate', 0.1, 0.2)
        n_estimators = trial.suggest_int('n_estimators', 20, 200)
        max_depth = trial.suggest_int('max_depth', 3, 9)
        min_child_weight = trial.suggest_loguniform('min_child_weight', 0.5, 2)
        min_child_samples = trial.suggest_int('min_child_samples', 5, 20)
        model = lgb.LGBMClassifier(learning_rate=learning_rate,
                                   n_estimators=n_estimators,
                                   max_depth=max_depth,
                                   min_child_weight=min_child_weight,
                                   min_child_samples=min_child_samples,
                                   subsample=0.8, colsample_bytree=0.8,
                                   verbose=-1, num_leaves=80)
        model.fit(X_train, y_train)
        # L1 norm of the residuals between labels and predicted probabilities
        score = np.linalg.norm(y_train - model.predict_proba(X_train)[:, 1], ord=1)
        if self.best_score is None:
            self.save_best_model(model, score)
        elif self.direction == optuna.study.StudyDirection.MINIMIZE:
            if self.best_score > score:
                self.save_best_model(model, score)
        else:
            if self.best_score < score:
                self.save_best_model(model, score)
        return score

    def save_best_model(self, model, score):
        self.best_score = score
        self.best_model = model
        with open(self.directory_path + 'best_model.pkl', 'wb') as obj:
            pickle.dump(model, obj)
create_study
By specifying the options as follows, Optuna's history can be kept in the specified file on Google Drive.
strage_name = "optuna_strage.sql"
study_name = 'example-study'
study = optuna.create_study(
    study_name=study_name,
    storage='sqlite:///' + directory_path + strage_name,
    load_if_exists=True,
    direction='minimize')
[I 2021-07-05 05:50:14,078] A new study created in RDB with name: example-study
When the study is created for the first time, you get a message like the one above.
Running the optimization
objective = Objective(study, directory_path)
study.optimize(objective, timeout=5)
[I 2021-07-05 05:50:48,556] Trial 0 finished with value: 2.356739567027545 and parameters: {'learning_rate': 0.1329737694584257, 'n_estimators': 144, 'max_depth': 3, 'min_child_weight': 1.0664725230880512, 'min_child_samples': 9}. Best is trial 0 with value: 2.356739567027545.
[I 2021-07-05 05:50:48,795] Trial 1 finished with value: 1.2751185388688262 and parameters: {'learning_rate': 0.136003087726788, 'n_estimators': 187, 'max_depth': 4, 'min_child_weight': 0.6168448838920242, 'min_child_samples': 6}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:49,138] Trial 2 finished with value: 1.902826529992414 and parameters: {'learning_rate': 0.1349201769249344, 'n_estimators': 195, 'max_depth': 6, 'min_child_weight': 0.9073287829778793, 'min_child_samples': 11}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:49,520] Trial 3 finished with value: 1.5688048026680932 and parameters: {'learning_rate': 0.15455236973136044, 'n_estimators': 132, 'max_depth': 4, 'min_child_weight': 0.7505318184611741, 'min_child_samples': 5}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:49,794] Trial 4 finished with value: 5.473157716123604 and parameters: {'learning_rate': 0.13896503420806364, 'n_estimators': 40, 'max_depth': 9, 'min_child_weight': 1.0588163014144807, 'min_child_samples': 15}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:50,196] Trial 5 finished with value: 3.4970329848502684 and parameters: {'learning_rate': 0.12139535041307536, 'n_estimators': 145, 'max_depth': 8, 'min_child_weight': 1.5720596652676253, 'min_child_samples': 9}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:50,489] Trial 6 finished with value: 1.0347527068253135 and parameters: {'learning_rate': 0.11931876764006037, 'n_estimators': 194, 'max_depth': 9, 'min_child_weight': 0.503628096796818, 'min_child_samples': 8}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:50,787] Trial 7 finished with value: 2.8733882065744103 and parameters: {'learning_rate': 0.19845644375161245, 'n_estimators': 78, 'max_depth': 9, 'min_child_weight': 1.2666195605297148, 'min_child_samples': 15}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:51,171] Trial 8 finished with value: 7.4862918479992455 and parameters: {'learning_rate': 0.18904478778073058, 'n_estimators': 25, 'max_depth': 6, 'min_child_weight': 1.547647464095838, 'min_child_samples': 7}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:51,578] Trial 9 finished with value: 3.1729588964953788 and parameters: {'learning_rate': 0.10176635813025851, 'n_estimators': 176, 'max_depth': 9, 'min_child_weight': 1.4044663576887746, 'min_child_samples': 20}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:51,991] Trial 10 finished with value: 1.7777210037788729 and parameters: {'learning_rate': 0.10118251285670081, 'n_estimators': 85, 'max_depth': 7, 'min_child_weight': 0.5102011440667872, 'min_child_samples': 14}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:52,265] Trial 11 finished with value: 1.1045896889115712 and parameters: {'learning_rate': 0.11475193090668549, 'n_estimators': 195, 'max_depth': 4, 'min_child_weight': 0.5366112521913312, 'min_child_samples': 5}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:52,676] Trial 12 finished with value: 1.0379864720722323 and parameters: {'learning_rate': 0.11284320711236966, 'n_estimators': 169, 'max_depth': 4, 'min_child_weight': 0.5049889093073401, 'min_child_samples': 5}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:53,106] Trial 13 finished with value: 1.4451824436141336 and parameters: {'learning_rate': 0.11290913100885246, 'n_estimators': 163, 'max_depth': 5, 'min_child_weight': 0.6714594028344404, 'min_child_samples': 9}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:53,510] Trial 14 finished with value: 1.1868961883560274 and parameters: {'learning_rate': 0.15916412146889425, 'n_estimators': 108, 'max_depth': 3, 'min_child_weight': 0.5251964554559795, 'min_child_samples': 7}. Best is trial 6 with value: 1.0347527068253135.
And with that, the first optimization run is done.
Simulating a run from another notebook
To simulate running from a different notebook, execute exactly the same commands as before.
strage_name = "optuna_strage.sql"
study_name = 'example-study'
study = optuna.create_study(
    study_name=study_name,
    storage='sqlite:///' + directory_path + strage_name,
    load_if_exists=True,
    direction='minimize')  # minimize
[I 2021-07-05 05:51:25,634] Using an existing study with name 'example-study' instead of creating a new one.
If the history file already exists, you get a message like the one above.
Running the optimization
objective = Objective(study, directory_path)
study.optimize(objective, timeout=5)
[I 2021-07-05 05:51:37,279] Trial 15 finished with value: 4.375297095641079 and parameters: {'learning_rate': 0.11083671508676562, 'n_estimators': 169, 'max_depth': 7, 'min_child_weight': 1.9410177835395233, 'min_child_samples': 11}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:37,578] Trial 16 finished with value: 1.626356885892271 and parameters: {'learning_rate': 0.1229692220545294, 'n_estimators': 200, 'max_depth': 5, 'min_child_weight': 0.7793190426176172, 'min_child_samples': 18}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:37,987] Trial 17 finished with value: 1.5271207713964778 and parameters: {'learning_rate': 0.10507897856058691, 'n_estimators': 123, 'max_depth': 5, 'min_child_weight': 0.6174962287555804, 'min_child_samples': 8}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:38,381] Trial 18 finished with value: 1.1013763872183016 and parameters: {'learning_rate': 0.1196729352067664, 'n_estimators': 157, 'max_depth': 8, 'min_child_weight': 0.5148082431230212, 'min_child_samples': 11}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:38,795] Trial 19 finished with value: 1.2498171795395772 and parameters: {'learning_rate': 0.15272857601741555, 'n_estimators': 179, 'max_depth': 3, 'min_child_weight': 0.6054594596447477, 'min_child_samples': 5}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:39,087] Trial 20 finished with value: 2.070989398053077 and parameters: {'learning_rate': 0.10723480051132261, 'n_estimators': 113, 'max_depth': 7, 'min_child_weight': 0.801881930872385, 'min_child_samples': 7}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:39,543] Trial 21 finished with value: 1.0868236265530413 and parameters: {'learning_rate': 0.12332630452324375, 'n_estimators': 154, 'max_depth': 8, 'min_child_weight': 0.5026395575011376, 'min_child_samples': 11}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:39,983] Trial 22 finished with value: 1.2150657340703361 and parameters: {'learning_rate': 0.1244639125603962, 'n_estimators': 152, 'max_depth': 8, 'min_child_weight': 0.5612743648886653, 'min_child_samples': 12}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:40,439] Trial 23 finished with value: 1.4234063287737972 and parameters: {'learning_rate': 0.12881256406386032, 'n_estimators': 178, 'max_depth': 8, 'min_child_weight': 0.6863433594892258, 'min_child_samples': 9}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:40,901] Trial 24 finished with value: 1.0312969226538424 and parameters: {'learning_rate': 0.11686319575294582, 'n_estimators': 200, 'max_depth': 9, 'min_child_weight': 0.5022939601479719, 'min_child_samples': 13}. Best is trial 24 with value: 1.0312969226538424.
[I 2021-07-05 05:51:41,359] Trial 25 finished with value: 1.181746641164211 and parameters: {'learning_rate': 0.11560369926886244, 'n_estimators': 195, 'max_depth': 9, 'min_child_weight': 0.5736122888762535, 'min_child_samples': 13}. Best is trial 24 with value: 1.0312969226538424.
[I 2021-07-05 05:51:41,664] Trial 26 finished with value: 1.4164596689410804 and parameters: {'learning_rate': 0.14542346313222995, 'n_estimators': 200, 'max_depth': 9, 'min_child_weight': 0.681448040661294, 'min_child_samples': 16}. Best is trial 24 with value: 1.0312969226538424.
[I 2021-07-05 05:51:42,132] Trial 27 finished with value: 1.9020898722583752 and parameters: {'learning_rate': 0.10824526779236447, 'n_estimators': 184, 'max_depth': 7, 'min_child_weight': 0.8946273336639666, 'min_child_samples': 17}. Best is trial 24 with value: 1.0312969226538424.
And indeed, the trial numbering continues from 15, so the run is picking up the existing history as intended.
Using the best model
The code above also saved the best model automatically. To load and use it later:
with open(directory_path + "best_model.pkl", 'rb') as pkl:
    model = pickle.load(pkl)
model
LGBMClassifier(boosting_type='gbdt', class_weight=None, colsample_bytree=0.8,
               importance_type='split', learning_rate=0.11686319575294582,
               max_depth=9, min_child_samples=13,
               min_child_weight=0.5022939601479719, min_split_gain=0.0,
               n_estimators=200, n_jobs=-1, num_leaves=80, objective=None,
               random_state=None, reg_alpha=0.0, reg_lambda=0.0, silent=True,
               subsample=0.8, subsample_for_bin=200000, subsample_freq=0,
               verbose=-1)
That's all it takes.
Closing thoughts
It apparently uses SQL under the hood, yet you can use it without knowing any SQL, which is convenient. And with Google Colaboratory + Drive, you don't even have to set up a local environment, which is also convenient.