
Using Optuna in Google Colaboratory and saving the best model and history to Google Drive

Posted at 2021-07-05

Optuna has a storage option that lets multiple processes share the optimization history, which makes distributed optimization possible. So, can this be done with Google Colaboratory and Google Drive? I gave it a try, and it works.

I had assumed I would need to deal with SQL through pymysql or the like, but fortunately that was unnecessary.

Example steps follow.

Installing Optuna

!pip install optuna

Collecting optuna
  Downloading https://files.pythonhosted.org/packages/1a/18/b49ca91cf592747e19f2d333c2a86cd7c81895b922a5a09adf6335471576/optuna-2.8.0-py3-none-any.whl (301kB)
     |████████████████████████████████| 307kB 3.9MB/s
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from optuna) (20.9)
Collecting colorlog
  Downloading https://files.pythonhosted.org/packages/32/e6/e9ddc6fa1104fda718338b341e4b3dc31cd8039ab29e52fc73b508515361/colorlog-5.0.1-py2.py3-none-any.whl
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from optuna) (1.19.5)
Collecting alembic
  Downloading https://files.pythonhosted.org/packages/d5/80/ef186e599a57d0e4cb78fc76e0bfc2e6953fa9716b2a5cf2de0117ed8eb5/alembic-1.6.5-py2.py3-none-any.whl (164kB)
     |████████████████████████████████| 174kB 20.4MB/s
Requirement already satisfied: sqlalchemy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from optuna) (1.4.18)
Requirement already satisfied: scipy!=1.4.0 in /usr/local/lib/python3.7/dist-packages (from optuna) (1.4.1)
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from optuna) (4.41.1)
Collecting cmaes>=0.8.2
  Downloading https://files.pythonhosted.org/packages/01/1f/43b01223a0366171f474320c6e966c39a11587287f098a5f09809b45e05f/cmaes-0.8.2-py3-none-any.whl
Collecting cliff
  Downloading https://files.pythonhosted.org/packages/87/11/aea1cacbd4cf8262809c4d6f95dcb3f2802594de1f51c5bd454d69bf15c5/cliff-3.8.0-py3-none-any.whl (80kB)
     |████████████████████████████████| 81kB 8.5MB/s
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->optuna) (2.4.7)
Requirement already satisfied: python-dateutil in /usr/local/lib/python3.7/dist-packages (from alembic->optuna) (2.8.1)
Collecting python-editor>=0.3
  Downloading https://files.pythonhosted.org/packages/c6/d3/201fc3abe391bbae6606e6f1d598c15d367033332bd54352b12f35513717/python_editor-1.0.4-py3-none-any.whl
Collecting Mako
  Downloading https://files.pythonhosted.org/packages/f3/54/dbc07fbb20865d3b78fdb7cf7fa713e2cba4f87f71100074ef2dc9f9d1f7/Mako-1.1.4-py2.py3-none-any.whl (75kB)
     |████████████████████████████████| 81kB 8.9MB/s
Requirement already satisfied: greenlet!=0.4.17; python_version >= "3" in /usr/local/lib/python3.7/dist-packages (from sqlalchemy>=1.1.0->optuna) (1.1.0)
Requirement already satisfied: importlib-metadata; python_version < "3.8" in /usr/local/lib/python3.7/dist-packages (from sqlalchemy>=1.1.0->optuna) (4.5.0)
Collecting pbr!=2.1.0,>=2.0.0
  Downloading https://files.pythonhosted.org/packages/18/e0/1d4702dd81121d04a477c272d47ee5b6bc970d1a0990b11befa275c55cf2/pbr-5.6.0-py2.py3-none-any.whl (111kB)
     |████████████████████████████████| 112kB 18.7MB/s
Collecting cmd2>=1.0.0
  Downloading https://files.pythonhosted.org/packages/e3/6a/e929ec70ca05c5962f6541ef29fb9c207dd41f0f2333680fa39f44fa4357/cmd2-2.1.1-py3-none-any.whl (140kB)
     |████████████████████████████████| 143kB 18.5MB/s
Requirement already satisfied: PyYAML>=3.12 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna) (3.13)
Requirement already satisfied: PrettyTable>=0.7.2 in /usr/local/lib/python3.7/dist-packages (from cliff->optuna) (2.1.0)
Collecting stevedore>=2.0.1
  Downloading https://files.pythonhosted.org/packages/d4/49/b602307aeac3df3384ff1fcd05da9c0376c622a6c48bb5325f28ab165b57/stevedore-3.3.0-py3-none-any.whl (49kB)
     |████████████████████████████████| 51kB 5.9MB/s
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil->alembic->optuna) (1.15.0)
Requirement already satisfied: MarkupSafe>=0.9.2 in /usr/local/lib/python3.7/dist-packages (from Mako->alembic->optuna) (2.0.1)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < "3.8"->sqlalchemy>=1.1.0->optuna) (3.4.1)
Requirement already satisfied: typing-extensions>=3.6.4; python_version < "3.8" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < "3.8"->sqlalchemy>=1.1.0->optuna) (3.7.4.3)
Collecting colorama>=0.3.7
  Downloading https://files.pythonhosted.org/packages/44/98/5b86278fbbf250d239ae0ecb724f8572af1c91f4a11edf4d36a206189440/colorama-0.4.4-py2.py3-none-any.whl
Requirement already satisfied: wcwidth>=0.1.7 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna) (0.2.5)
Requirement already satisfied: attrs>=16.3.0 in /usr/local/lib/python3.7/dist-packages (from cmd2>=1.0.0->cliff->optuna) (21.2.0)
Collecting pyperclip>=1.6
  Downloading https://files.pythonhosted.org/packages/a7/2c/4c64579f847bd5d539803c8b909e54ba087a79d01bb3aba433a95879a6c5/pyperclip-1.8.2.tar.gz
Building wheels for collected packages: pyperclip
  Building wheel for pyperclip (setup.py) ... done
  Created wheel for pyperclip: filename=pyperclip-1.8.2-cp37-none-any.whl size=11136 sha256=4dc9b64db7863c0f74ada694cbda69c409892a0a919f9a11fef03fca82dcd886
  Stored in directory: /root/.cache/pip/wheels/25/af/b8/3407109267803f4015e1ee2ff23be0c8c19ce4008665931ee1
Successfully built pyperclip
Installing collected packages: colorlog, python-editor, Mako, alembic, cmaes, pbr, colorama, pyperclip, cmd2, stevedore, cliff, optuna
Successfully installed Mako-1.1.4 alembic-1.6.5 cliff-3.8.0 cmaes-0.8.2 cmd2-2.1.1 colorama-0.4.4 colorlog-5.0.1 optuna-2.8.0 pbr-5.6.0 pyperclip-1.8.2 python-editor-1.0.4 stevedore-3.3.0

import optuna

Mounting Google Drive from Google Colaboratory

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Creating a place to save to

import os

directory_path = './drive/MyDrive/test_optuna_colab_db/'
if not os.path.exists(directory_path):
    os.makedirs(directory_path)

Data for machine learning

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split 

breast_cancer = load_breast_cancer()
X = breast_cancer.data
y = breast_cancer.target.ravel()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.4) 
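Here test_size=.4 holds out 40% of the samples for testing. The mechanics of such a split can be sketched with the standard library alone (simple_split below is a hypothetical stand-in for scikit-learn's shuffled splitter, not part of this article's pipeline):

```python
import random

def simple_split(X, y, test_size=0.4, seed=0):
    # Shuffle the indices, then carve off a test_size fraction for testing
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)
    n_test = int(len(X) * test_size)
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return ([X[i] for i in train_idx], [X[i] for i in test_idx],
            [y[i] for i in train_idx], [y[i] for i in test_idx])

X_demo = list(range(10))
y_demo = [v % 2 for v in X_demo]
X_tr, X_te, y_tr, y_te = simple_split(X_demo, y_demo)
print(len(X_tr), len(X_te))  # 6 4
```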

Defining the objective function and setting up the hyperparameters

The objective is also set up here so that the best model gets saved along the way.

import numpy as np
import lightgbm as lgb
import pickle

class Objective:
    def __init__(self, study, directory_path):
        self.direction = study.direction
        self.best_score = None
        self.best_model = None
        self.directory_path = directory_path

    def __call__(self, trial):
        learning_rate = trial.suggest_loguniform('learning_rate', 0.1, 0.2)
        n_estimators = trial.suggest_int('n_estimators', 20, 200)
        max_depth = trial.suggest_int('max_depth', 3, 9)
        min_child_weight = trial.suggest_loguniform('min_child_weight', 0.5, 2)
        min_child_samples = trial.suggest_int('min_child_samples', 5, 20)
        model = lgb.LGBMClassifier(learning_rate=learning_rate,
                                   n_estimators=n_estimators,
                                   max_depth=max_depth,
                                   min_child_weight=min_child_weight,
                                   min_child_samples=min_child_samples,
                                   subsample=0.8, colsample_bytree=0.8,
                                   verbose=-1, num_leaves=80)
        model.fit(X_train, y_train)
        # L1 norm of the training-set residuals (smaller is better)
        score = np.linalg.norm(y_train - model.predict_proba(X_train)[:, 1], ord=1)

        if self.best_score is None:
            self.save_best_model(model, score)
        elif self.direction == optuna.study.StudyDirection.MINIMIZE:
            if self.best_score > score:
                self.save_best_model(model, score)
        else:
            if self.best_score < score:
                self.save_best_model(model, score)

        return score

    def save_best_model(self, model, score):
        self.best_score = score
        self.best_model = model
        with open(self.directory_path + 'best_model.pkl', 'wb') as obj:
            pickle.dump(model, obj)
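Stripped of the Optuna and LightGBM specifics, the bookkeeping above reduces to a direction-aware comparison. A standalone sketch (is_improvement is a hypothetical helper name, not part of the article's code):

```python
def is_improvement(direction, best_score, score):
    # First trial: anything counts as the best so far
    if best_score is None:
        return True
    # minimize: lower scores win; maximize: higher scores win
    if direction == 'minimize':
        return score < best_score
    return score > best_score

print(is_improvement('minimize', None, 5.0))  # True
print(is_improvement('minimize', 2.4, 1.3))   # True
print(is_improvement('maximize', 2.4, 1.3))   # False
```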

create_study

By specifying the options as follows, Optuna's history is kept in the designated file on Google Drive.

strage_name = "optuna_strage.sql"
study_name = 'example-study'
study = optuna.create_study(
    study_name = study_name,
    storage='sqlite:///' + directory_path + strage_name, 
    load_if_exists=True,
    direction='minimize')
[I 2021-07-05 05:50:14,078] A new study created in RDB with name: example-study

When the study is created for the first time, a message like the one above is shown.
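As an aside, the storage argument is nothing more than a SQLAlchemy-style database URL; for a SQLite file on Drive it is assembled like this (a standalone sketch reusing the same variable values as above):

```python
# Same values as in the create_study call above
directory_path = './drive/MyDrive/test_optuna_colab_db/'
strage_name = "optuna_strage.sql"

storage_url = 'sqlite:///' + directory_path + strage_name
print(storage_url)  # sqlite:///./drive/MyDrive/test_optuna_colab_db/optuna_strage.sql
```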

Running the optimization

objective = Objective(study, directory_path)
study.optimize(objective, timeout=5)
[I 2021-07-05 05:50:48,556] Trial 0 finished with value: 2.356739567027545 and parameters: {'learning_rate': 0.1329737694584257, 'n_estimators': 144, 'max_depth': 3, 'min_child_weight': 1.0664725230880512, 'min_child_samples': 9}. Best is trial 0 with value: 2.356739567027545.
[I 2021-07-05 05:50:48,795] Trial 1 finished with value: 1.2751185388688262 and parameters: {'learning_rate': 0.136003087726788, 'n_estimators': 187, 'max_depth': 4, 'min_child_weight': 0.6168448838920242, 'min_child_samples': 6}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:49,138] Trial 2 finished with value: 1.902826529992414 and parameters: {'learning_rate': 0.1349201769249344, 'n_estimators': 195, 'max_depth': 6, 'min_child_weight': 0.9073287829778793, 'min_child_samples': 11}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:49,520] Trial 3 finished with value: 1.5688048026680932 and parameters: {'learning_rate': 0.15455236973136044, 'n_estimators': 132, 'max_depth': 4, 'min_child_weight': 0.7505318184611741, 'min_child_samples': 5}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:49,794] Trial 4 finished with value: 5.473157716123604 and parameters: {'learning_rate': 0.13896503420806364, 'n_estimators': 40, 'max_depth': 9, 'min_child_weight': 1.0588163014144807, 'min_child_samples': 15}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:50,196] Trial 5 finished with value: 3.4970329848502684 and parameters: {'learning_rate': 0.12139535041307536, 'n_estimators': 145, 'max_depth': 8, 'min_child_weight': 1.5720596652676253, 'min_child_samples': 9}. Best is trial 1 with value: 1.2751185388688262.
[I 2021-07-05 05:50:50,489] Trial 6 finished with value: 1.0347527068253135 and parameters: {'learning_rate': 0.11931876764006037, 'n_estimators': 194, 'max_depth': 9, 'min_child_weight': 0.503628096796818, 'min_child_samples': 8}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:50,787] Trial 7 finished with value: 2.8733882065744103 and parameters: {'learning_rate': 0.19845644375161245, 'n_estimators': 78, 'max_depth': 9, 'min_child_weight': 1.2666195605297148, 'min_child_samples': 15}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:51,171] Trial 8 finished with value: 7.4862918479992455 and parameters: {'learning_rate': 0.18904478778073058, 'n_estimators': 25, 'max_depth': 6, 'min_child_weight': 1.547647464095838, 'min_child_samples': 7}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:51,578] Trial 9 finished with value: 3.1729588964953788 and parameters: {'learning_rate': 0.10176635813025851, 'n_estimators': 176, 'max_depth': 9, 'min_child_weight': 1.4044663576887746, 'min_child_samples': 20}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:51,991] Trial 10 finished with value: 1.7777210037788729 and parameters: {'learning_rate': 0.10118251285670081, 'n_estimators': 85, 'max_depth': 7, 'min_child_weight': 0.5102011440667872, 'min_child_samples': 14}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:52,265] Trial 11 finished with value: 1.1045896889115712 and parameters: {'learning_rate': 0.11475193090668549, 'n_estimators': 195, 'max_depth': 4, 'min_child_weight': 0.5366112521913312, 'min_child_samples': 5}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:52,676] Trial 12 finished with value: 1.0379864720722323 and parameters: {'learning_rate': 0.11284320711236966, 'n_estimators': 169, 'max_depth': 4, 'min_child_weight': 0.5049889093073401, 'min_child_samples': 5}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:53,106] Trial 13 finished with value: 1.4451824436141336 and parameters: {'learning_rate': 0.11290913100885246, 'n_estimators': 163, 'max_depth': 5, 'min_child_weight': 0.6714594028344404, 'min_child_samples': 9}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:50:53,510] Trial 14 finished with value: 1.1868961883560274 and parameters: {'learning_rate': 0.15916412146889425, 'n_estimators': 108, 'max_depth': 3, 'min_child_weight': 0.5251964554559795, 'min_child_samples': 7}. Best is trial 6 with value: 1.0347527068253135.

And there it is: the first round of optimization is done.

Simulating a run from another notebook

To simulate running from a separate notebook, execute exactly the same commands as before.

strage_name = "optuna_strage.sql"
study_name = 'example-study'
study = optuna.create_study(
    study_name = study_name,
    storage='sqlite:///' + directory_path + strage_name, 
    load_if_exists=True,
    direction='minimize') # minimize
[I 2021-07-05 05:51:25,634] Using an existing study with name 'example-study' instead of creating a new one.

If the history file already exists, a message like the one above is shown instead.

Running the optimization

objective = Objective(study, directory_path)
study.optimize(objective, timeout=5)
[I 2021-07-05 05:51:37,279] Trial 15 finished with value: 4.375297095641079 and parameters: {'learning_rate': 0.11083671508676562, 'n_estimators': 169, 'max_depth': 7, 'min_child_weight': 1.9410177835395233, 'min_child_samples': 11}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:37,578] Trial 16 finished with value: 1.626356885892271 and parameters: {'learning_rate': 0.1229692220545294, 'n_estimators': 200, 'max_depth': 5, 'min_child_weight': 0.7793190426176172, 'min_child_samples': 18}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:37,987] Trial 17 finished with value: 1.5271207713964778 and parameters: {'learning_rate': 0.10507897856058691, 'n_estimators': 123, 'max_depth': 5, 'min_child_weight': 0.6174962287555804, 'min_child_samples': 8}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:38,381] Trial 18 finished with value: 1.1013763872183016 and parameters: {'learning_rate': 0.1196729352067664, 'n_estimators': 157, 'max_depth': 8, 'min_child_weight': 0.5148082431230212, 'min_child_samples': 11}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:38,795] Trial 19 finished with value: 1.2498171795395772 and parameters: {'learning_rate': 0.15272857601741555, 'n_estimators': 179, 'max_depth': 3, 'min_child_weight': 0.6054594596447477, 'min_child_samples': 5}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:39,087] Trial 20 finished with value: 2.070989398053077 and parameters: {'learning_rate': 0.10723480051132261, 'n_estimators': 113, 'max_depth': 7, 'min_child_weight': 0.801881930872385, 'min_child_samples': 7}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:39,543] Trial 21 finished with value: 1.0868236265530413 and parameters: {'learning_rate': 0.12332630452324375, 'n_estimators': 154, 'max_depth': 8, 'min_child_weight': 0.5026395575011376, 'min_child_samples': 11}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:39,983] Trial 22 finished with value: 1.2150657340703361 and parameters: {'learning_rate': 0.1244639125603962, 'n_estimators': 152, 'max_depth': 8, 'min_child_weight': 0.5612743648886653, 'min_child_samples': 12}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:40,439] Trial 23 finished with value: 1.4234063287737972 and parameters: {'learning_rate': 0.12881256406386032, 'n_estimators': 178, 'max_depth': 8, 'min_child_weight': 0.6863433594892258, 'min_child_samples': 9}. Best is trial 6 with value: 1.0347527068253135.
[I 2021-07-05 05:51:40,901] Trial 24 finished with value: 1.0312969226538424 and parameters: {'learning_rate': 0.11686319575294582, 'n_estimators': 200, 'max_depth': 9, 'min_child_weight': 0.5022939601479719, 'min_child_samples': 13}. Best is trial 24 with value: 1.0312969226538424.
[I 2021-07-05 05:51:41,359] Trial 25 finished with value: 1.181746641164211 and parameters: {'learning_rate': 0.11560369926886244, 'n_estimators': 195, 'max_depth': 9, 'min_child_weight': 0.5736122888762535, 'min_child_samples': 13}. Best is trial 24 with value: 1.0312969226538424.
[I 2021-07-05 05:51:41,664] Trial 26 finished with value: 1.4164596689410804 and parameters: {'learning_rate': 0.14542346313222995, 'n_estimators': 200, 'max_depth': 9, 'min_child_weight': 0.681448040661294, 'min_child_samples': 16}. Best is trial 24 with value: 1.0312969226538424.
[I 2021-07-05 05:51:42,132] Trial 27 finished with value: 1.9020898722583752 and parameters: {'learning_rate': 0.10824526779236447, 'n_estimators': 184, 'max_depth': 7, 'min_child_weight': 0.8946273336639666, 'min_child_samples': 17}. Best is trial 24 with value: 1.0312969226538424.

Good: it has picked up the previous history (trial numbering continues at 15, and trial 6 is still remembered as the best) and keeps going from there.

Using the best model

The code above also saved the best model automatically. To load it back and use it later:

with open(directory_path + "best_model.pkl", 'rb') as pkl:
    model = pickle.load(pkl)

model
LGBMClassifier(boosting_type='gbdt', class_weight=None, colsample_bytree=0.8,
               importance_type='split', learning_rate=0.11686319575294582,
               max_depth=9, min_child_samples=13,
               min_child_weight=0.5022939601479719, min_split_gain=0.0,
               n_estimators=200, n_jobs=-1, num_leaves=80, objective=None,
               random_state=None, reg_alpha=0.0, reg_lambda=0.0, silent=True,
               subsample=0.8, subsample_for_bin=200000, subsample_freq=0,
               verbose=-1)

and that's all there is to it.
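The save/load pair used here is an ordinary pickle round-trip, so it can be exercised in isolation with a plain dict standing in for the fitted LGBMClassifier (the temp directory below stands in for the Drive folder):

```python
import os
import pickle
import tempfile

directory_path = tempfile.mkdtemp()  # stand-in for the Drive folder
model = {'learning_rate': 0.117, 'n_estimators': 200}  # stand-in for the fitted model

# Save, as Objective.save_best_model does
with open(os.path.join(directory_path, 'best_model.pkl'), 'wb') as obj:
    pickle.dump(model, obj)

# Load it back later, possibly from a different notebook
with open(os.path.join(directory_path, 'best_model.pkl'), 'rb') as pkl:
    restored = pickle.load(pkl)

print(restored == model)  # True
```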

Closing thoughts

SQL is apparently used under the hood, but you can use all of this without knowing any SQL, which is convenient. And with Google Colaboratory + Google Drive, you don't even need to set up a local environment.
