scikit-learnでGridSearchCVを使ってパラメータをチューニングする際に、各パラメータでの結果をヒートマップで表示して観察することがよく行われる。このヒートマップをGridSearchCVオブジェクトから簡単に作成する方法を調査した。
ヒートマップを簡単に作成する方法
GridSearchCV.cv_results_(グリッドサーチの結果が格納されている)を一度pandas.DataFrameに変換して、seaborn.heatmapに渡してヒートマップを表示する。
- DataFrameに変換することでヒートマップに軸名を別途指定する必要がなくなる。
- cv_results_をDataFrameに変換するときには必要なキーのみ抽出すること。(cv_results_に含まれる非推奨のキーを参照することで警告が出るのを防ぐため。)
実装例
# GridSearchCVの結果をヒートマップで表示する。
def plot_heatmap_from_grid(grid):
# チューニング対象のパラメータを特定する。
params = [k for k in grid.cv_results_.keys() if k.startswith('param_')]
if len(params) != 2: raise Exception('grid has to have exact 2 parameters.')
# ヒートマップの行、列、値に使うキーを定義する。
index = params[0]
columns = params[1]
values = 'mean_test_score'
# gridから必要なキーのみを抽出する。
df_dict = {k: grid.cv_results_[k] for k in grid.cv_results_.keys() & {index, columns, values}}
# dictをDataFrameに変換してseabornでヒートマップを表示する。
import pandas as pd
df = pd.DataFrame(df_dict)
data = df.pivot(index=index, columns=columns, values=values)
import seaborn as sns
sns.heatmap(data, annot=True, fmt='.3f')