Help us understand the problem. What is going on with this article?

Pythonで重回帰分析

TL; DR

過去記事で重回帰分析についてやりたいと書いていた割に書かなかったなと思ったので書こうと思います。

重回帰分析について

重回帰分析は、単回帰分析と異なり、説明変数が1次元ではなく、多次元のベクトルになったものです。

公式の導出

説明変数をD次元ベクトル
$$\boldsymbol{x}^{T} = (x_1, x_2, ..., x_D)$$
単回帰分析における切片におけるバイアス項を
$$w_0$$
説明変数の各次元に対する重み係数を
$$\boldsymbol{w}^T = (w_1, w_2, ..., w_D)$$
と置きます。
便宜上、
$$x_0 = 1$$
と定義し、目的変数をyとおくと、バイアス項と重みベクトルw、説明変数xをあわせて

y = w_0x_0 + x_1x_1 + w_2x_2 + ... + w_Dx_D \\
y = \boldsymbol{w}^T\boldsymbol{x}

と書くことができる。
N個のサンプルについて同様に並べると

y_1 = \boldsymbol{x_{1}}^T\boldsymbol{w} \\
y_2 = \boldsymbol{x_{2}}^T\boldsymbol{w} \\
\vdots \\
y_N = \boldsymbol{x_{N}}^T\boldsymbol{w} \\

と書くことができるので、これを行列としてまとめると

\boldsymbol{y} = X\boldsymbol{w} \\
\\ ただし \\
\boldsymbol{y} = (y_1, y_2, ..., y_N)^T \\
X = \left(
\begin{array}{ccccc}
\boldsymbol{x_1}^T\\
\boldsymbol{x_2}^T  \\
\vdots  \\
\boldsymbol{x_N}^T 
\end{array}
\right)

このとき、目的変数と回帰した超平面の二乗誤差Eを最小にする重みwを求めます。

\begin{align}
E &= \sum_{k=1}^{N}(y_k - \boldsymbol{x_k}^T\boldsymbol{w})^2 \\
&= (\boldsymbol{y} - X\boldsymbol{w})^T(\boldsymbol{y} - X\boldsymbol{w})\\
&= \boldsymbol{y}^T\boldsymbol{y} - \boldsymbol{y}^TX\boldsymbol{w} - (X\boldsymbol{w})^T\boldsymbol{y} + (X\boldsymbol{w})^TX\boldsymbol{w} \\
&= \boldsymbol{y}^T\boldsymbol{y} - \boldsymbol{y}^TX\boldsymbol{w} - \boldsymbol{w}^TX^T\boldsymbol{y} + \boldsymbol{w}^TX^TX\boldsymbol{w} \\
&(\boldsymbol{w}の最小値を求めたいので、\boldsymbol{w}についてまとめる)\\
&= \boldsymbol{y}^T\boldsymbol{y} - \boldsymbol{w}^TX^T\boldsymbol{y} - \boldsymbol{w}^TX^T\boldsymbol{y} + \boldsymbol{w}^TX^TX\boldsymbol{w} \\
& = \boldsymbol{y}^T\boldsymbol{y} - 2\boldsymbol{w}^TX^T\boldsymbol{y} + \boldsymbol{w}^TX^TX\boldsymbol{w}
\end{align}

二次形式とベクトルの微分の公式は以下のようになる(証明略)

\frac{∂\boldsymbol{w}^T\boldsymbol{x}}{∂\boldsymbol{w}} = \boldsymbol{x} \\
\frac{∂\boldsymbol{w}^TX\boldsymbol{w}}{∂\boldsymbol{w}} = (X + X^T)\boldsymbol{w}

二乗誤差Eをwについて微分すると

\begin{align}
\frac{∂E}{∂\boldsymbol{w}} &= -2X^T\boldsymbol{y} + (XX^T + X^TX)\boldsymbol{w} \\
& XX^Tは対称行列であるので \\
&= -2X^T\boldsymbol{y} + 2X^TX\boldsymbol{w} \\
&ととける。そこで、 \\
\frac{∂E}{∂\boldsymbol{w}} &= 0 になるのは \\
X^T\boldsymbol{y} &= X^TX\boldsymbol{w} \\
&のときであり、重み係数\boldsymbol{w}の解は \\
\boldsymbol{w} &= (X^TX)^{-1}X^T\boldsymbol{y}
\end{align}

線形代数と微分はすごい、前に書いた線形回帰の導出もこれで楽々でした。まあ一回書き下すのは無駄ではなかったと思っていますが。ただ、これは当然のこととして、サンプル間の強い相関によるrank落ち(多重共線性)などによって逆行列を持たないこともあるため、必ずこの式で求められるわけではないことに注意が必要です。

使用するデータセット

今回も、ワインのデータセットについて試してみようと思います。
アルコール濃度以外のデータを使って、アルコール濃度を回帰で求めてみます。

import pandas as pd
from sklearn import datasets 

# load datasets
wine = datasets.load_wine()

# データフレームに変換
wine_df = pd.DataFrame(wine.data, columns=wine.feature_names)

# alcoholを除いたデータセットと、アルコールだけのデータセットを作成
wine_alcohol = wine_df['alcohol']
wine_df = wine_df.drop('alcohol', axis=1)

これで準備OKです。

scikit-learnによる実装

重回帰分析は、単回帰分析の時と同様に、LinearRegressorクラスを使うことで実装できます。
また、attributeとして

attribute contents
coef_ 重み係数
intercept_ バイアス項

を持ちます。
さて、実装してみます。

from sklearn.linear_model import LinearRegression

regressor = LinearRegression()
regressor.fit(wine_df, wine_alcohol)
print("score: ", regressor.score(wine_df, wine_alcohol))
print("intercept_: ", regressor.intercept_)
print("coef_: ", "\n", regressor.coef_)
print("predict:", "\n", regressor.predict(wine_df))

結果は以下のようになりました。predictはそのまま表示すると長いので折りたたんでます。0.6くらいなので精度としてはまあまあといったところでしょうか?

score:  0.5935573146395273
intercept_:  11.071849541591947
coef_:  
 [ 1.31636223e-01  1.37853612e-01 -3.77877101e-02  4.17911054e-06
  5.20835243e-02  9.12514513e-03 -2.07795701e-01 -1.52497193e-01
  1.63034871e-01  2.16879740e-01  1.60796319e-01  1.01585935e-03]


predictの結果
predict:
[13.66661305 13.64752402 13.59563623 14.4052741 12.97482915 14.15701754
13.85663207 13.89779774 13.4709283 13.84297084 14.02101275 13.61803463
13.82154853 13.63461237 14.42345393 14.18447268 13.71908271 13.60595157
14.70283505 13.57710193 13.35817881 13.33383144 13.41843559 13.32366916
13.04824697 12.79884478 13.65533596 13.45027096 13.29175806 13.41725635
13.48293341 14.09284127 13.11146165 13.66649247 13.30414209 13.24784155
13.30314255 13.29265567 13.18334915 13.63423042 13.32684749 13.458762
13.77941017 13.15429841 13.21852136 13.82152627 13.72213254 13.53142533
13.46094036 14.2264349 13.85794468 13.76673994 14.09331536 14.03032453
13.59846737 13.48538009 13.58934323 13.89144957 13.90264255 12.33551952
12.61359476 12.67111418 12.69102457 12.56493353 12.27242088 12.72326328
12.78345378 12.74724228 12.65455758 12.49852547 12.45526705 12.31592651
12.34293859 12.76645457 12.79760956 12.49715174 12.47441694 12.50076878
12.45205875 12.47810083 12.10828489 12.83470347 12.20214895 12.58767798
12.28763989 12.33634852 12.07605997 12.30388937 12.48399939 12.17067429
12.15083621 12.20561124 12.06011357 12.28848566 12.51293363 12.49238823
12.35407668 12.36653661 12.74888564 12.31812079 12.79821893 12.27456753
12.44095121 12.02631475 12.59972888 12.11742979 12.45221563 12.12000559
11.96407592 12.45100466 12.33021851 12.10957692 12.80162616 12.11439632
12.23441205 12.05074235 12.13838057 11.93515166 12.49734291 12.31230046
12.68034483 12.78528655 12.22540335 12.57538188 12.4551593 12.28688219
12.20360252 11.96657558 11.7762624 12.56485122 12.4734682 12.7509841
12.54745097 12.78985254 12.63279097 13.09236227 12.82015943 12.80478258
12.93068582 12.74768578 12.77899086 13.13475719 12.63590754 12.95336025
13.70179728 12.7361745 12.71471967 13.37247996 13.25055405 13.36895717
13.0352707 13.37067199 12.50773895 13.57637831 12.79881705 13.369749
13.31336471 13.15447604 13.58553634 13.44927794 13.12760049 13.12063534
12.86415386 12.97903986 13.36410603 12.95310457 13.71603328 13.75424839
13.25314658 13.49374775 12.80027482 13.32119359 13.46345337 13.62814343
13.21066122 13.87543436 13.50377703 13.31372158]

自分で実装

scikit-learnライクのAPIを継承するため、scikit-learnのbase.BaseEstimatorとbase.RegressorMixinを継承して自分のクラスを作ります。RegressorMixinを継承すると、coef_やintercept_アトリビュート、また、scoreメソッドなどを継承することができます。

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_is_fitted

class MyLinearRegression(BaseEstimator, RegressorMixin):
    def fit(self, X, y):
        # バイアス項用にx_0 = 1を代入
        X = np.insert(X, 0, 1, axis=1)
        # 解いた式に従って重み係数を算出
        w = (np.linalg.inv(X.T @ X) @ X.T) @ y
        # バイアス項と重み係数に分解
        self.coef_ = w[1:]
        self.intercept_ = w[0]
        return self

    def predict(self, X):
        # fitされてるかチェックしてくれるらしい
        check_is_fitted(self, 'coef_')
        # 予測を返す
        return X @ self.coef_ + self.intercept_

# 実行 pandasデータフレームの状態だと謎にエラーを吐かれるのでvaluesでnumpy arrayに変換している
my_regressor = MyLinearRegression()
my_regressor.fit(wine_df.values, wine_alcohol.values)
print("score: ", my_regressor.score(wine_df.values, wine_alcohol.values))
print("intercept_: ", my_regressor.intercept_)
print("coef_: ", "\n", my_regressor.coef_)
print("predict:", "\n", my_regressor.predict(wine_df.values))

完全に一致したので、自前で実装できた。

score:  0.5935573146395273
intercept_:  11.071849541591947
coef_:  
 [ 1.31636223e-01  1.37853612e-01 -3.77877101e-02  4.17911054e-06
  5.20835243e-02  9.12514513e-03 -2.07795701e-01 -1.52497193e-01
  1.63034871e-01  2.16879740e-01  1.60796319e-01  1.01585935e-03]


predictの結果
predict:
[13.66661305 13.64752402 13.59563623 14.4052741 12.97482915 14.15701754
13.85663207 13.89779774 13.4709283 13.84297084 14.02101275 13.61803463
13.82154853 13.63461237 14.42345393 14.18447268 13.71908271 13.60595157
14.70283505 13.57710193 13.35817881 13.33383144 13.41843559 13.32366916
13.04824697 12.79884478 13.65533596 13.45027096 13.29175806 13.41725635
13.48293341 14.09284127 13.11146165 13.66649247 13.30414209 13.24784155
13.30314255 13.29265567 13.18334915 13.63423042 13.32684749 13.458762
13.77941017 13.15429841 13.21852136 13.82152627 13.72213254 13.53142533
13.46094036 14.2264349 13.85794468 13.76673994 14.09331536 14.03032453
13.59846737 13.48538009 13.58934323 13.89144957 13.90264255 12.33551952
12.61359476 12.67111418 12.69102457 12.56493353 12.27242088 12.72326328
12.78345378 12.74724228 12.65455758 12.49852547 12.45526705 12.31592651
12.34293859 12.76645457 12.79760956 12.49715174 12.47441694 12.50076878
12.45205875 12.47810083 12.10828489 12.83470347 12.20214895 12.58767798
12.28763989 12.33634852 12.07605997 12.30388937 12.48399939 12.17067429
12.15083621 12.20561124 12.06011357 12.28848566 12.51293363 12.49238823
12.35407668 12.36653661 12.74888564 12.31812079 12.79821893 12.27456753
12.44095121 12.02631475 12.59972888 12.11742979 12.45221563 12.12000559
11.96407592 12.45100466 12.33021851 12.10957692 12.80162616 12.11439632
12.23441205 12.05074235 12.13838057 11.93515166 12.49734291 12.31230046
12.68034483 12.78528655 12.22540335 12.57538188 12.4551593 12.28688219
12.20360252 11.96657558 11.7762624 12.56485122 12.4734682 12.7509841
12.54745097 12.78985254 12.63279097 13.09236227 12.82015943 12.80478258
12.93068582 12.74768578 12.77899086 13.13475719 12.63590754 12.95336025
13.70179728 12.7361745 12.71471967 13.37247996 13.25055405 13.36895717
13.0352707 13.37067199 12.50773895 13.57637831 12.79881705 13.369749
13.31336471 13.15447604 13.58553634 13.44927794 13.12760049 13.12063534
12.86415386 12.97903986 13.36410603 12.95310457 13.71603328 13.75424839
13.25314658 13.49374775 12.80027482 13.32119359 13.46345337 13.62814343
13.21066122 13.87543436 13.50377703 13.31372158]

まとめ

今回の実装に当たって、scikit-learn本体の実装もみたけど色々前処理がされてたりと、やはり本家はすごいなあと。そういえば二乗誤差をとるならnumpy.linalg.lstsqを使えば良いという知見を得ました。次はPCAあたりですかね。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away