昨日紹介した書籍のひとつ SciPy and NumPy
Optimizing & Boosting your Python Programming から例を取り出し scikit-learn による回帰の話です。
3D モデルからの回帰平面算出
まずは 3D モデルを書くためのライブラリをインポートします。
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# scikit-learn の Linear Regression を利用します
from sklearn import linear_model
# 回帰用のサンプルデータセットを使います
from sklearn.datasets.samples_generator import make_regression
サンプルデータからの訓練データ生成
サンプルデータがありますので、これを訓練データに基づいた分類をおこなってみましょう。
# 訓練と試験のための合成データを生成します
X, y = make_regression(n_samples=100, n_features=2, n_informative=1,
random_state=0, noise=50)
# =>
# [[ 1.05445173 -1.07075262]
# [-0.36274117 -0.63432209]
# ...
# [-0.17992484 1.17877957]
# [-0.68481009 0.40234164]]
# [ -6.93224214 -4.12640648 29.47265153 -12.03166314 -121.67258636
# -149.24989393 113.53496654 -7.83638906 190.00097568 49.48805247
# ...
# 246.92583786 171.84739934 -33.55917696 38.71008939 -28.23999523
# 39.5677481 -168.02196071 -201.18826919 69.07078178 -36.96534574]
訓練データと試験データの分割
生成したデータを訓練と試験について 80:20 の割合で分割します。
X_train, X_test = X[:80], X[-20:]
y_train, y_test = y[:80], y[-20:]
分類器の訓練
準備ができたら訓練を行います。まずは分類器のインスタンスを生成、次におなじみの .fit メソッドで分類器を訓練します。
regr = linear_model.LinearRegression()
# 訓練する
regr.fit(X_train, y_train)
# 推定値を表示する
print(regr.coef_)
#=> [-10.25691752 90.5463984 ]
値の予測
次に訓練データに基づいた y 値を予測します。
X1 = np.array([1.2, 4])
print(regr.predict(X1))
#=> 350.860363861
評価
結果を評価してみましょう。
print(regr.score(X_test, y_test))
#=> 0.949827492261
可視化
データだけでは直感的ではないので、最後に可視化をしてみます。
fig = plt.figure(figsize=(8, 5))
ax = fig.add_subplot(111, projection='3d')
# ax = Axes3D(fig)
# Data
ax.scatter(X_train[:, 0], X_train[:, 1], y_train, facecolor='#00CC00')
ax.scatter(X_test[:, 0], X_test[:, 1], y_test, facecolor='#FF7800')
coef = regr.coef_
line = lambda x1, x2: coef[0] * x1 + coef[1] * x2
grid_x1, grid_x2 = np.mgrid[-2:2:10j, -2:2:10j]
ax.plot_surface(grid_x1, grid_x2, line(grid_x1, grid_x2),
alpha=0.1, color='k')
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
ax.zaxis.set_visible(False)
fig.savefig('image.png', bbox='tight')
平面が求められました。だいたいあっているようですね。
まとめ
サンプルデータセットを使うと綺麗な平面を描くことができました。現実の問題においてはなかなか綺麗に行かないこともありますが、理論を抑えておくと役に立つでしょう。