0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

scikit_learn線形回帰

Posted at

線形回帰(単回帰)とは

線形回帰とは、説明変数xと目的変数 yの間に線形関係(y = αx + β)があると仮定し、実測値と予測値の誤差の2乗平均が最小となるような傾き α と切片 β を求める手法。

wimage.png

実測値 ( x1, x2, x3, ..., xn) の各値について、その予測値 ( y1, y2, y3, ..., yn)との距離の2乗の平均値(MSE)を最小とするような直線を求める。

\begin{align}
MSE&=\frac{(y_{1}-x_{1})^2 + (y_{2}-x_{2})^2 + (y_{3}-x_{3})^2 + \cdots + (y_{n}-x_{n})^2}{N} \\
&= \frac{1}{N}\sum_{n=1}^{N}(y_{n}-x_{n})^2\\
&= \frac{1}{N}\sum_{n=1}^{N}((αx_{n}+β)-x_{n})^2
\end{align}

scikit_learnとは

scikit-learnとは、回帰・分類・クラスタリングなどの機械学習や、データの前処理、モデルの評価などの機能を提供するPythonライブラリのこと。

線形回帰の流れ

wimage.png

ボストン住宅価格のデータセットを例に使って、住宅価格の線形回帰モデルを作成する。

0.データの取得

データは以下のGitHubへのリンクからCSVファイルをダウンロードして使う。

ボストン住宅価格のデータ(GitHub)にリンク

CSVファイルを同一フォルダ内に配置する。必要なモジュールをインポートし、CSVファイルからデータを読み込みDataFrameを作成する。

from sklearn import linear_model
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

df = pd.read_csv("BostonHousing.csv")
df.head()

実行すると、このようなdfが出力される。
スクリーンショット 2025-11-03 8.26.25.png

列名 内容
crim 人口一人当たりの犯罪数
zn 25,000フィード以上の住居価格の占める割合
indus 小売業以外の商業が占める割合
chas チャールズ川の周辺
nox NOx濃度
rm 平均部屋数
age 1940年より前に建てられた物件の割合
dis 5つのボストン市の雇用施設からの距離
rad 感情高速道路へのアクセスのしやすさ
tax $10,000あたりの不動産税率の総計
ptratio 町ごとの児童と教師の比率
b 町ごとの黒人の比率
lstat 給与の低い職業に従事する人口の割合
medv 価格

1. データの前処理

基本統計量の確認
df.describe()

スクリーンショット 2025-11-03 8.36.21.png

欠損値の確認
df.isnull().sum().any()

False

目的変数のヒストグラム作成

目的変数(medv:価格)のヒストグラムを表示することで、外れ値の存在を確認できる

plt.hist(df["medv"],bins=100)
plt.show()

image.png

2. 説明変数の選択

目的変数(medv:価格)を予想するための説明変数に適した列を選択するために、相関係数や散布図から値の相関関係を調べる。

相関係数の取得

DataFrameの各列間の相関係数を取得する。

df.corr()

スクリーンショット 2025-11-03 8.41.56.png

相関係数は絶対値が 0.7以上(0.7以上または-0.7以下)だと相関が強いとされる。

medv(価格)カラムについて、相関係数の絶対値を降順にソートする。

np.abs(df.corr()["medv"]).sort_values(ascending=False)

実行結果

medv       1.000000
lstat      0.737663
rm         0.695360
ptratio    0.507787
indus      0.483725
tax        0.468536
nox        0.427321
crim       0.388305
rad        0.381626
age        0.376955
zn         0.360445
b          0.333461
dis        0.249929
chas       0.175260
Name: medv, dtype: float64

ここからlstat(給与の低い職業に従事する人口の割合)とrm(平均部屋数)が説明変数の候補として考えられる。

相関係数の注意点
  • 直線的な関連のみを表す指標なので、非直線的な関連がある場合は適切な結果を得ることができない
  • 極端な外れ値の存在が結果に影響を与える場合がある
  • 相関があるからといって因果関係があるとは限らない

適切な説明変数を選択するためには、相関係数だけなく散布図などでデータの可視化を行う必要がある。

ヒートマップ
import seaborn as sns

sns.heatmap(df.corr(),square=True,vmax=1,vmin=-1,center=0)
plt.show()

image.png

散布図

散布図行列の作成

pd.plotting.scatter_matrix(df,figsize=(15,15))
plt.show()

image.png

rmmedvの散布図

df.plot.scatter(x="rm",y="medv",figsize=(5,5))
plt.show()

image.png

lstatmedvの散布図

df.plot.scatter(x="lstat",y="medv",figsize=(5,5))
plt.show()

image.png

相関係数ではlstatの方が高いが、散布図で見るとlstatは曲線的な関係に近いことがわかる。よって説明変数として選ぶべき列はrmと考えられる。

3. モデルの作成

説明変数を選択したら、それを訓練用データとして学習させて線形回帰モデルを作成する。線形回帰モデルはsklearnモジュール内のLinearRegression()クラスのオブジェクトを使って作成する。

LinearRegressionオブジェクトの生成
model = linear_model.LinearRegression()

LinearRegressionクラスの属性

  • coef_ : 回帰直線の傾き(回帰変数)
  • intercept_ : 回帰直線の切片
訓練用データの準備

次に訓練用データ(説明変数・目的変数の2つ)を準備する。

説明変数のデータをn行1列のNumpy配列として用意する。

rm_train = np.array(df["rm"]).reshape(-1,1)

目的変数のデータを1次元Numpy配列として用意する。

medv_train = np.array(df["medv"])
学習の実行

fit()メソッドに以下の引数を渡してモデルを完成させる。

  • 第1引数 : 説明変数のデータ(n行1列のNumpy配列)
  • 第2引数 : 目的変数のデータ(1次元Numpy配列)
model.fit(rm_train,medv_train)

これを実行したのち、以下のように回帰直線の傾きと切片を出力してみる。

print(f"傾き:{model.coef_}, 切片:{model.intercept_}")

傾き:[9.10210898], 切片:-34.67062077643851

LinearRegressionオブジェクトが訓練用データを学習し、回帰直線 y = α x + βαβを求めてくれたので、モデルが完成した。

4. モデルのテスト

完成したモデルを使って、説明変数から目的変数を予想して、その結果から妥当性を判断する。

テスト用データの準備

説明変数rmカラムの最小値と最大値から、テスト用データをn行1列のNumpy配列で用意する。

rm_test = np.arange(rm_train.min(),rm_train.max(),0.1).reshape(-1,1)
モデルから予想値を取得

LinearRegressionクラスのpredict()メソッドを使って、テスト用データから目的変数を予想する。
predict()の引数にはn行1列のNumpy配列を渡す。
戻り値は1次元Numpy配列になる。

medv_test = model.predict(rm_test)
妥当性の評価

モデルを使って取得した予想結果"medv_test"のグラフと、説明変数と目的変数の散布図を重ねてモデルの妥当性を評価する。

plt.plot(rm_test,medv_test,c="r")
plt.scatter(rm_train,medv_train)

plt.title("BostonHousingPrice")
plt.xlabel("rooms")
plt.ylabel("price")
plt.show()

image.png

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?