1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

scikit-learnを使ってみる(2)

Last updated at Posted at 2021-02-20

前の記事「scikit-learnを使ってみる」で三角関数の学習がうまくできなかった。トレーニングデータではそこそこ学習できても、テストデータの結果はボロボロだった。
今回は方法を変更して、sin関数で作成した連続した3つの値から次の値を学習させていく。下図で説明すると、緑の点3つから赤の点を学習させることになる。これを順次行う。
sin.png

三角関数(2)

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt

look_back = 3


def make_data(y, look_back):
    X = []
    Y = []
    for i in np.arange(len(y) - look_back):
        X.append(y[i:i+look_back])
        Y.append(y[i+look_back])
    return X, Y


# 三角関数 Y = sin(X)を求める。
x = np.linspace(0, 20, num=200)
y = np.sin(x)
X, Y = make_data(y, look_back=look_back)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, shuffle=False)

lr = LinearRegression()
lr.fit(X_train, Y_train)

print(f'回帰係数={lr.coef_}')
print(f'切片={lr.intercept_}')
print(f'決定係数={lr.score(X_train, Y_train)}')

plt.scatter(x, y, label='data')
plt.plot(x[look_back:look_back+len(X_train)],
         lr.predict(X_train), color='green', label='train')
plt.plot(x[look_back+len(X_train):],
         lr.predict(X_test), color='red', label='test')
plt.legend()
plt.savefig('linear_regression6.png')
plt.show()

さて結果は...

linear_regression6.png

うまく学習できたようだ。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?