LoginSignup
2
2

More than 5 years have passed since last update.

【備忘録】sklearnを使ってmnistをランダムフォレストで学習させる

Posted at

まえがき

機械学習の学習アルゴリズムの一つとして、ランダムフォレストという学習方法があると知った。
sklearnのRandomForestClassifierを使うと簡単に実装出来そうだったので実装をしてみた。
今回はmnist(28 × 28の手書き文字)を学習データとして学習させる。

以下ソースコード

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from sklearn.ensemble import RandomForestClassifier
import time

start_time = time.time()
# データの読み込み
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 小数化
x_train = x_train / 255
x_test = x_test / 255
# データ数
m_train, m_test = x_train.shape[0], x_test.shape[0]
# ベクトル化
x_train, x_test = x_train.reshape(m_train, -1), x_test.reshape(m_test, -1)
# ランダムフォレストは標準化しない

# ランダムフォレスト
# max_depth=10 <-木の最大の深さ=10 深いほど学習に時間かかるが、ある程度精度が上がる
rf = RandomForestClassifier(max_depth=10)
rf.fit(x_train, y_train)

# 経過時間
print("time[s] : ", time.time() - start_time)

# 学習データの正解率
print("Train :", rf.score(x_train, y_train))

# テストデータの正解率
print("Test :", rf.score(x_test, y_test))

実行結果

time[s] :  4.7764482498168945
Train : 0.9466166666666667
Test : 0.9296

実行結果としてCore i5のノーパソ上で4.7秒の学習で92.9%の正答率が出た。

まとめ

実装したのはいいものの、ランダムフォレストの学習アルゴリズムについてまだよく理解出来てないので、今後勉強をして、少しずつ記事を編集していきたいと思います。

2
2
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2