LoginSignup
39
35

More than 5 years have passed since last update.

scikit-learn を使った回帰モデルとその可視化

Posted at

昨日紹介した書籍のひとつ SciPy and NumPy
Optimizing & Boosting your Python Programming
から例を取り出し scikit-learn による回帰の話です。

3D モデルからの回帰平面算出

まずは 3D モデルを書くためのライブラリをインポートします。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# scikit-learn の Linear Regression を利用します
from sklearn import linear_model
# 回帰用のサンプルデータセットを使います
from sklearn.datasets.samples_generator import make_regression

サンプルデータからの訓練データ生成

サンプルデータがありますので、これを訓練データに基づいた分類をおこなってみましょう。

# 訓練と試験のための合成データを生成します
X, y = make_regression(n_samples=100, n_features=2, n_informative=1,
                       random_state=0, noise=50)
# =>
# [[ 1.05445173 -1.07075262]
#  [-0.36274117 -0.63432209]
# ...
#  [-0.17992484  1.17877957]
#  [-0.68481009  0.40234164]]
# [  -6.93224214   -4.12640648   29.47265153  -12.03166314 -121.67258636
#  -149.24989393  113.53496654   -7.83638906  190.00097568   49.48805247
# ...
#   246.92583786  171.84739934  -33.55917696   38.71008939  -28.23999523
#    39.5677481  -168.02196071 -201.18826919   69.07078178  -36.96534574]

訓練データと試験データの分割

生成したデータを訓練と試験について 80:20 の割合で分割します。

X_train, X_test = X[:80], X[-20:]
y_train, y_test = y[:80], y[-20:]

分類器の訓練

準備ができたら訓練を行います。まずは分類器のインスタンスを生成、次におなじみの .fit メソッドで分類器を訓練します。

regr = linear_model.LinearRegression()
# 訓練する
regr.fit(X_train, y_train)

# 推定値を表示する
print(regr.coef_)
#=> [-10.25691752 90.5463984 ]

値の予測

次に訓練データに基づいた y 値を予測します。

X1 = np.array([1.2, 4])
print(regr.predict(X1))
#=> 350.860363861

評価

結果を評価してみましょう。

print(regr.score(X_test, y_test))
#=> 0.949827492261

可視化

データだけでは直感的ではないので、最後に可視化をしてみます。

fig = plt.figure(figsize=(8, 5))
ax = fig.add_subplot(111, projection='3d')
# ax = Axes3D(fig)

# Data
ax.scatter(X_train[:, 0], X_train[:, 1], y_train, facecolor='#00CC00')
ax.scatter(X_test[:, 0], X_test[:, 1], y_test, facecolor='#FF7800')

coef = regr.coef_
line = lambda x1, x2: coef[0] * x1 + coef[1] * x2

grid_x1, grid_x2 = np.mgrid[-2:2:10j, -2:2:10j]
ax.plot_surface(grid_x1, grid_x2, line(grid_x1, grid_x2),
                alpha=0.1, color='k')
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
ax.zaxis.set_visible(False)

fig.savefig('image.png', bbox='tight')

image.png

平面が求められました。だいたいあっているようですね。

まとめ

サンプルデータセットを使うと綺麗な平面を描くことができました。現実の問題においてはなかなか綺麗に行かないこともありますが、理論を抑えておくと役に立つでしょう。

39
35
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
39
35