8
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

機械学習で競馬予想(その2)

Last updated at Posted at 2016-11-23

前回の続きです。

今回のテーマ

今回は簡単にランダムフォレストのパラメータのチューニングについて書きたいと思います。

パラメータについて

前回のプログラムに書いたscikit-learnのRandomForestClassifierは指定できるパラメータが数々あります。
参考 : RandomForestClassifier

このパラメータのチューニングで結果が大きく変わることもあります。
ただし精度を上げるために分岐の数を増やしたらするとトレードオフとして処理に時間がかかります。
また、増やしすぎても精度が下がることもあります。

そこで今回は最適なパラメータを決める手法を紹介します。

パラメータチューニング(Grid Search)

scikit-learnで最適なパラメータを探すにはGrid Searchが便利です。

早速ですがサンプルプラムを書きます。
(注 : 試すパラメータを増やすとめちゃめちゃ時間がかかります)

プログラム

#coding:utf-8
import csv

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV


class BestParameter:
  def __init__ (self) :
    self.horse_data = []  
    self.train_data = []  
    self.train_target = [] 
    # テスト対象
    self.test_row_no = -1 

    self.master = {
      1 : {
        "福島": 0, "小倉": 1, "京都": 2, "函館": 3,
        "中山": 4, "札幌": 5, "東京": 6,
        "阪神": 7, "中京": 8, "新潟": 9 
      },
      4 :  { "" : 0, "ダート" : 1, "障害" : 2 },
      5 :  { "" : 0, "" : 1, "" : 2, "直線" : 3, "右2周" : 4 },
      7 :  { "不良" : 0,  "" : 1, "稍重" : 2, "" : 3 },
      14 : {"" : 0, "" : 1, "せん" : 2},
      29 : {"A" : 0, "B" : 1, "C" : 2, "D" : 3, "E" : 4, "nan" : -1 }
    }

  def best_parameter(self):
    hurdle_race_count = 0
    header = []
    label = []
    with open("data/jra_race_result.csv", "r") as f:
      reader = csv.reader(f)
      # 障害は除くデータで予測データを作成
      for idx, row in enumerate(reader):
        if idx == 0:
          for i, col in enumerate(row):
            header = row
          continue
        elif row[4] == '障害' :
          hurdle_race_count += 1
          continue
        horse = []
        parameter = []
        # マスタデータで数値化
        for i, col in enumerate(row):
          if i in {3, 13, 16, 18, 19, 26, 27, 28}:
            horse.append(col)
            continue
          elif i == 0 : 
            if self.test_row_no == -1 and col == '2016-09-17' :
              self.test_row_no = (idx - hurdle_race_count)
            parameter.append(col.replace('-',''))
          elif i == 10 : 
            label.append(header[i])
            horse.append(col)
            self.train_target.append(col)
          elif self.master.has_key(i) :
            if i == 1 :
              horse.append(col)
            label.append(header[i])
            parameter.append(self.master[i][col])
          else :
            if i in (2, 12) :
              horse.append(col)
            label.append(header[i])
            if col == ''  or col == ' - ': 
              col = -1
            parameter.append(float(col))
        self.horse_data.append(horse)
        self.train_data.append(parameter)
    
    # fitで学習 (9/17までを学習)
    # modelをシリアライズする場合
    # joblib.dump(model, 'model.pkl') 
    # 素性の重要度(RandomForestの分岐での重要度)
    parameters = {
      'n_estimators'      : [5, 10],
      'max_features'      : ['auto', 'sqrt', 'log2', None],
      'max_features'      : [3, 5, 10, 15, 20],
      'random_state'      : [0],
      'n_jobs'            : [1],
      'min_samples_split' : [3, 5, 10, 15, 20, 25, 30, 40, 50],
      'max_depth'         : [3, 5, 10, 15, 20, 25, 30, 40, 50]
    }
    model = GridSearchCV(RandomForestClassifier(), parameters, n_jobs=-1)
    model.fit(self.train_data[0 : self.test_row_no - 1], self.train_target[0 : self.test_row_no - 1])

    for params, mean_score, all_scores in model.grid_scores_:
      print("{:.3f} (+/- {:.3f}) for {}".format(mean_score, all_scores.std() / 2, params))

    ## best_estimator_でscoreを表示します。
    print(clf.best_estimator_)

if __name__ == "__main__":
  best_parameter = BestParameter()
  best_parameter.best_parameter()

プログラム実行 & 結果

※ 時間が掛かるので、パラメータを減らしてます。

$ python best_parameter.py 
0.335 (+/- 0.005) for {'n_estimators': 5}
0.380 (+/- 0.006) for {'n_estimators': 10}
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=10, n_jobs=1, oob_score=False, random_state=None,
            verbose=0, warm_start=False)

このように最適なパラメータを出せます。

競馬予想SIVA
[facebook]
https://www.facebook.com/AIkeiba/
[Twitter]
https://twitter.com/Siva_keiba
随時実況していきますので、いいね!フォローお願いします。

以上、また書きます!

8
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?