Heat Map for Grid Search with Matplotlib

Posted at 2017-07-23

Grid Search

http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
First, for the grid search method, you need to select which parameters to optimize and define the set of candidate values for each of them.

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import make_scorer, roc_auc_score
from sklearn.model_selection import GridSearchCV

# Base estimator and the parameter values to search over
learner = RandomForestClassifier(random_state=2)
n_estimators = [12, 24, 36, 48, 60]
min_samples_leaf = [1, 2, 4, 8, 16]
parameters = {'n_estimators': n_estimators, 'min_samples_leaf': min_samples_leaf}

In this case, AUC is used as the scoring metric, so you create your own scorer for AUC.

def auc_scorer(target_score, prediction):
    # roc_auc_score expects the true labels first, then the predictions
    auc_value = roc_auc_score(target_score, prediction)
    return auc_value

scorer = make_scorer(auc_scorer, greater_is_better=True)
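
As a rough illustration of what such a scorer computes, the sketch below applies it to a small synthetic dataset. The make_classification data, the 150/50 split, and the variable names are assumptions for the example, not part of the original setup. Because make_scorer passes the output of predict() by default, the AUC here is computed from hard class labels; scikit-learn's built-in 'roc_auc' scorer, shown for comparison, works from predicted probabilities or decision values instead.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import get_scorer, make_scorer, roc_auc_score

# Synthetic binary data, purely illustrative (not the article's data)
X, y = make_classification(n_samples=200, n_features=8, random_state=0)
X_train, y_train = X[:150], y[:150]
X_test, y_test = X[150:], y[150:]

clf = RandomForestClassifier(n_estimators=12, random_state=2).fit(X_train, y_train)

# Custom scorer: score_func is called as score_func(y_true, y_pred),
# where y_pred comes from clf.predict(), i.e. hard class labels
label_scorer = make_scorer(roc_auc_score, greater_is_better=True)
auc_from_labels = label_scorer(clf, X_test, y_test)

# Built-in scorer: AUC computed from predicted probabilities
proba_scorer = get_scorer('roc_auc')
auc_from_proba = proba_scorer(clf, X_test, y_test)

print('AUC from hard labels:  ', auc_from_labels)
print('AUC from probabilities:', auc_from_proba)
```

The two values generally differ, since thresholded labels discard the ranking information that AUC is designed to measure. Passing scoring='roc_auc' to GridSearchCV is one way to use the probability-based variant.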

Finally, you can define the grid search object.

# scoring must be passed by keyword in recent scikit-learn versions
grid_obj = GridSearchCV(learner, parameters, scoring=scorer)
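
The grid search object only produces scores after it has been fitted. The following is a minimal end-to-end sketch under stated assumptions: the training data is a synthetic stand-in generated with make_classification, and cv=3 is an arbitrary choice to keep the run short. Neither comes from the original article.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import make_scorer, roc_auc_score
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in for the (unspecified) training data
X, y = make_classification(n_samples=300, n_features=10, random_state=0)

learner = RandomForestClassifier(random_state=2)
n_estimators = [12, 24, 36, 48, 60]
min_samples_leaf = [1, 2, 4, 8, 16]
parameters = {'n_estimators': n_estimators, 'min_samples_leaf': min_samples_leaf}
scorer = make_scorer(roc_auc_score, greater_is_better=True)

# Fit one model per parameter combination and cross-validation fold
grid_obj = GridSearchCV(learner, parameters, scoring=scorer, cv=3)
grid_obj.fit(X, y)

print(grid_obj.best_params_)
# One mean test score per parameter combination (5 x 5 = 25)
print(grid_obj.cv_results_['mean_test_score'].shape)  # (25,)
```

After fit() returns, cv_results_ holds a flat array of mean test scores, one entry per parameter combination.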

Heat Map

http://scikit-learn.org/stable/auto_examples/svm/plot_rbf_parameters.html
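To create a heat map, you first need a 2-dimensional matrix. Once the grid search object has been fitted (by calling its fit method on the training data), you can retrieve the mean cross-validated score for every parameter combination from cv_results_. In the example below, all results are put into scores, with one row per min_samples_leaf value and one column per n_estimators value.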

# Rows: min_samples_leaf, columns: n_estimators (parameter names are iterated in sorted order)
scores = grid_obj.cv_results_['mean_test_score'].reshape(len(min_samples_leaf), len(n_estimators))

Note: scores contains the following array.

[[ 0.91803961  0.92444425  0.9264368   0.92730609  0.92808348]
 [ 0.91263539  0.91757799  0.91892211  0.91957058  0.91950196]
 [ 0.90143663  0.90590379  0.90669241  0.90751479  0.90758263]
 [ 0.89168321  0.89370183  0.89414698  0.89497685  0.89506426]
 [ 0.88276445  0.88386261  0.88380793  0.88408826  0.88448689]]
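
The row and column layout of this array follows from the order in which GridSearchCV evaluates the candidates, which is the order produced by ParameterGrid: parameter names are sorted, and the last name varies fastest. The sketch below checks this without fitting anything. The shorter lists of unequal length are hypothetical, chosen so that a wrong reshape order would be visible.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

# Hypothetical lists of different lengths (3 and 4 values)
n_estimators = [12, 24, 36]
min_samples_leaf = [1, 2, 4, 8]
parameters = {'n_estimators': n_estimators, 'min_samples_leaf': min_samples_leaf}

# Same candidate order as cv_results_['params'] after a grid search
candidates = list(ParameterGrid(parameters))

shape = (len(min_samples_leaf), len(n_estimators))
leaf_grid = np.array([c['min_samples_leaf'] for c in candidates]).reshape(shape)
est_grid = np.array([c['n_estimators'] for c in candidates]).reshape(shape)

print(leaf_grid)  # each row holds a single min_samples_leaf value
print(est_grid)   # each row runs through all n_estimators values
```

With this shape, row i corresponds to min_samples_leaf[i] and column j to n_estimators[j], which matches the axis labels used for the heat map.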

Then, you can use scores for plotting a heat map.

import matplotlib.pyplot as plt
import numpy as np

plt.figure(figsize=(8, 6))
plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('n_estimators')
plt.ylabel('min_samples_leaf')
plt.colorbar()
plt.xticks(np.arange(len(n_estimators)), n_estimators)
plt.yticks(np.arange(len(min_samples_leaf)), min_samples_leaf)
plt.title('Grid Search AUC Score')
plt.show()
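
Since the scores in this grid differ only in the second or third decimal place, it can help to print the value inside each cell. The variant below is a sketch of one way to do that with Matplotlib's object-oriented interface. It reuses the scores printed above; the Agg backend and the black/white text rule are assumptions made for this example.

```python
import matplotlib
matplotlib.use('Agg')  # assumption: headless rendering; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

n_estimators = [12, 24, 36, 48, 60]
min_samples_leaf = [1, 2, 4, 8, 16]
# Scores copied from the array printed above
scores = np.array([
    [0.91803961, 0.92444425, 0.9264368, 0.92730609, 0.92808348],
    [0.91263539, 0.91757799, 0.91892211, 0.91957058, 0.91950196],
    [0.90143663, 0.90590379, 0.90669241, 0.90751479, 0.90758263],
    [0.89168321, 0.89370183, 0.89414698, 0.89497685, 0.89506426],
    [0.88276445, 0.88386261, 0.88380793, 0.88408826, 0.88448689],
])

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(scores, interpolation='nearest', cmap='hot')
fig.colorbar(im, ax=ax)
ax.set_xticks(np.arange(len(n_estimators)))
ax.set_xticklabels(n_estimators)
ax.set_yticks(np.arange(len(min_samples_leaf)))
ax.set_yticklabels(min_samples_leaf)
ax.set_xlabel('n_estimators')
ax.set_ylabel('min_samples_leaf')
ax.set_title('Grid Search AUC Score')

# Write each score into its cell; dark text on bright cells, light text on dark cells
midpoint = (scores.min() + scores.max()) / 2
for i in range(scores.shape[0]):
    for j in range(scores.shape[1]):
        color = 'black' if scores[i, j] > midpoint else 'white'
        ax.text(j, i, f'{scores[i, j]:.3f}', ha='center', va='center', color=color)

fig.canvas.draw()  # render the figure; use plt.show() or fig.savefig(...) to view or save it
```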

Finally, you get a heat map like the one below.

[Image: heat map of the grid search AUC scores]
