LoginSignup
5
5

More than 5 years have passed since last update.

GridSearchCVの結果をヒートマップで表示する

Last updated at Posted at 2018-05-04

scikit-learnでGridSearchCVを使ってパラメータをチューニングする際に、各パラメータでの結果をヒートマップで表示して観察することがよく行われる。このヒートマップをGridSearchCVオブジェクトから簡単に作成する方法を調査した。

ヒートマップを簡単に作成する方法

GridSearchCV.cv_results_(グリッドサーチの結果が格納されている)を一度pandas.DataFrameに変換して、seaborn.heatmapに渡してヒートマップを表示する。

  • DataFrameに変換することでヒートマップに軸名を別途指定する必要がなくなる。
  • cv_results_をDataFrameに変換するときには必要なキーのみ抽出すること。(cv_results_に含まれる非推奨のキーを参照することで警告が出るのを防ぐため。)

実装例

# GridSearchCVの結果をヒートマップで表示する。
def plot_heatmap_from_grid(grid):
    # チューニング対象のパラメータを特定する。
    params = [k for k in grid.cv_results_.keys() if k.startswith('param_')]
    if len(params) != 2: raise Exception('grid has to have exact 2 parameters.') 

    # ヒートマップの行、列、値に使うキーを定義する。
    index = params[0]
    columns = params[1]
    values = 'mean_test_score'

    # gridから必要なキーのみを抽出する。
    df_dict = {k: grid.cv_results_[k] for k in grid.cv_results_.keys() & {index, columns, values}}

    # dictをDataFrameに変換してseabornでヒートマップを表示する。
    import pandas as pd
    df = pd.DataFrame(df_dict)
    data = df.pivot(index=index, columns=columns, values=values)
    import seaborn as sns
    sns.heatmap(data, annot=True, fmt='.3f')
5
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
5