155
159

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

線形回帰の Normal Equation(正規方程式)について

Last updated at Posted at 2015-07-23

前置き

オンライン機械学習コース の Linear Regression with Multiple Variables(多変量線形回帰)で出てきた、Normal Equation(正規方程式)について。

Andrew Ng 先生(以降、Ang先生 と略記)が「導出するのめんどい(意訳)」と言って結果だけ示されたので、ちょっとだけ掘り下げてみました。

その中で、疑問点も浮かんできたので共有してみます。
私自身、まだちゃんと分かってない部分もあるかもなので、ツッコミ大歓迎です。

【2015/07/24 23:10】検証コードを追加し、大幅に加筆修正しています。

Normal Equation(正規方程式)

まずはおさらい。

$X$ は、トレーニングデータの特徴量全体を表す $m \times (n+1)$ 行列($m$(行数)はデータの件数、$n$ は feature(特徴)の数)。
$y$ は、トレーニングデータの「正解の値」を並べた $m$次元ベクトル。
$\theta$ は、Hypothesis Function(仮説関数)のパラメータ($h_{\theta}(x) = \sum_i \theta_i x_i$、$0 \le i \le n$)を表す $(n+1)$次元ベクトル。
さらに $J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)})-y^{(i)})^2$ を(線形回帰の)Cost Function(目的関数)1と呼び、これを最小化することが線形回帰の目的。

最小化する方法として、最後の $J(\theta)$ を $\theta_0, \dots , \theta_n$ それぞれで偏微分して、それらが全て 0 となるように $\theta$ を求めていく。例えばある $\theta_j$ で偏微分すると:

$$\begin{eqnarray}
\frac{\partial}{\partial \theta_j}J(\theta) &=& \frac{1}{m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)})-y^{(i)})x_j^{(i)} \\
&=& \frac{1}{m}\{\sum_{k}\theta_k\sum_{i}x_j^{(i)}x_k^{(i)} - \sum_{i}x_j^{(i)}y^{(i)}\}
\end{eqnarray}$$

これを $m$ 行まとめて変形して行列のカタチで書き直すと2

$$
m\left[\begin{array}{cc}
\frac{\partial}{\partial \theta_0}J(\theta)\\
\frac{\partial}{\partial \theta_1}J(\theta)\\
\vdots\\
\frac{\partial}{\partial \theta_n}J(\theta)
\end{array}\right] =
X^TX \theta - X^Ty
$$

これを $= {\bf 0}$ とおいて $\theta$ について解く訳なので:

$$\begin{eqnarray}
X^TX \theta &=& X^Ty\\
\theta &=& (X^TX)^{-1} X^Ty
\end{eqnarray}$$

これが Ang先生 の仰る「Normal Equation(正規方程式)」です。
最後の式の右辺を計算すれば、(二乗誤差が最小となる)最適なパラメータ $\theta$ が求まる、という訳です。

一般化逆行列

$X$ に対して $(X^TX)^{-1}$ がいつも存在するかというと、そうとは限りません。
ただしこれも、Ang 先生は講義で「そうなることは非常に稀だから気にするな」と仰っています。
また万が一そういった場合が発生したとしても、例えば Octave なら、pinv(X'*X)*X'*yとすれば良い、とも説明しています。X' は行列 X の転置行列($X$ に対する $X^T$)です。

Octave の pinv は、"Pseudoinverse"(擬逆行列、一般化逆行列とも言う。以下、「一般化逆行列」を用いる)を返す関数。
引数の行列が可逆なら逆行列(に一致する行列)を返し、そうでなければある性質を満たす「一般化逆行列」を返す仕組みになっています。

一般化逆行列の詳細はここでは省略しますが、簡単に言えば、「逆行列と同じような性質を持つ行列」です。これは正則でない(=可逆でない)正方行列のみならず、正方でない(=$m \times n$ 行列で $m \ne n$ であるような)行列にも定義できちゃいます。

これによって、$X^TX$ が可逆ならば問題なく $(X^TX)^{-1}$ を返して計算するし、そうでなくても適切に算出された「一般化逆行列」によってエラーなく計算され、しかも妥当な計算結果が得られる仕組み、ということです。

さらに、一般化逆行列

先ほども書きましたが、一般化逆行列は $m \times n$ 行列($m \ne n$)にも定義されます。
特に $m > n$ の場合(=行が列より多い場合=縦長の行列の場合)、それは実はこう書けます:

$$ X^- = (X^TX)^{-1}X^T$$

今考えているのは、(多変量の)線形回帰。(一般に)feature(特徴)の数に比べて、データの数は非常に多くなる傾向にあります。つまり、行列は行の方が多い縦長の行列になるはず。

さて、ここで Normal Equation(正規方程式)をもう一度見てみましょう。

$$ \theta = (X^TX)^{-1} X^Ty $$

右辺の $y$ 以外の部分。同じカタチですよね?つまり。こう書いてしまっても良いのではないか?と。

$$ \theta = X^-y $$

Octave のコードで言えば、pinv(X)*y。これで良いんじゃないか?と。

さらに。
ちょっと以下の方程式を考えてみます。

$$ X\theta = y$$

$X$ が正則行列(正方行列で可逆)ならば、$\theta = X^{-1}y$ で求解出来ます。
そうでない場合は、そのような普通の方法は存在しません(そもそも充足解が存在しない、または解が一意に決まらない)。
ところが、逆行列を一般化逆行列に置き換えた $\theta = X^-y$ を考えると、誤差(またはノルム)が最小となる解を求めることができます。つまりこの式は、$X\theta = y$ に対する誤差最小解の求解を意味しているのです。

そして、Octave には $X\theta = y$ をそのまま求解する別の方法が用意されています。バックスラッシュ演算子を利用した X \ y という書き方です。
これは $X$ が正則行列なら $X^{-1}y$ の意味になり、そうでない場合は $X^-y$ と同じような計算をしてくれるようです。

ここで、Normal Equation(正規方程式)の元のカタチを思い出してみましょう。

$$\begin{eqnarray}
X^TX \theta &=& X^Ty\\
(X^TX) \theta &=& (X^Ty)
\end{eqnarray}$$

これも $\bigcirc\theta=\bigcirc$ の形をしていますよね。○ \ ○ とバックスラッシュ演算子で求解出来る形になっている。当てはめてみると、(X'*X) \ (X'*y)

つまり。
Normal Equation(正規方程式)を求解する Octave のコードは、以下の4種類が考えられうる、ということです:

  • pinv(X'*X)*X'*y
  • pinv(X)*y
  • X \ y
  • (X'*X) \ (X'*y)

これらのうちで、講義で紹介されているのは最前者の pinv(X'*X)*X'*y だけです。
なぜ、2番目や3番目・4番目が紹介されていないのでしょう?

検証1

ということで、実際のコードを書いて検証してみました。

solve_X_y.m
% solve_X_y.m
X = rand(10, 4)
# X =
#
#    0.033536   0.816107   0.996677   0.958327
#    0.683542   0.116498   0.614316   0.884338
#    0.734337   0.769245   0.696212   0.245270
#    0.216938   0.013297   0.885327   0.906086
#    0.630620   0.733668   0.820551   0.784664
#    0.138834   0.838178   0.216751   0.638286
#    0.100739   0.893597   0.891867   0.239482
#    0.362333   0.404999   0.018274   0.922847
#    0.102606   0.442110   0.744582   0.452299
#    0.590709   0.274452   0.459526   0.656588

y = rand(10, 1)
# y = 
# 
#    0.48518
#    0.13242
#    0.60525
#    0.31265
#    0.59250
#    0.47161
#    0.95971
#    0.44011
#    0.60115
#    0.75571

# calcuration
# [1]
pinv(X' * X) * X' * y
# ans =
# 
#    0.1861915
#    0.5484641
#    0.2473279
#   -0.0031611

# [2]
pinv(X) * y
# ans =
# 
#    0.1861915
#    0.5484641
#    0.2473279
#   -0.0031611

# [3]
X \ y
# ans =
# 
#    0.1861915
#    0.5484641
#    0.2473279
#   -0.0031611

# [4]
(X'*X) \ (X'*y)
# ans =
# 
#    0.1861915
#    0.5484641
#    0.2473279
#   -0.0031611

# time measurement (n = 10)
# [1]
tic();
for k=1:10000;
    X = rand(40, 10);
    y = rand(40, 1);
    pinv(X' * X) * X' * y;
end;
toc()
# Elapsed time is 1.26513 seconds.

# [2]
tic();
for k=1:10000;
    X = rand(40, 10);
    y = rand(40, 1);
    pinv(X) * y;
end;
toc()
# Elapsed time is 1.16283 seconds.

# [3]
tic();
for k=1:10000;
    X = rand(40, 10);
    y = rand(40, 1);
    X \ y;
end;
toc()
# Elapsed time is 0.902037 seconds.

# [4]
tic();
for k=1:10000;
    X = rand(40, 10);
    y = rand(40, 1);
    (X'*X) \ (X'*y);
end;
toc()
# Elapsed time is 0.689348 seconds.

# time measurement (n = 30)
# [1]
tic();
for k=1:10000;
    X = rand(100, 30);
    y = rand(100, 1);
    pinv(X' * X) * X' * y;
end;
toc()
# Elapsed time is 5.79588 seconds.

# [2]
tic();
for k=1:10000;
    X = rand(100, 30);
    y = rand(100, 1);
    pinv(X) * y;
end;
toc()
# Elapsed time is 7.11547 seconds.

# [3]
tic();
for k=1:10000;
    X = rand(100, 30);
    y = rand(100, 1);
    X \ y;
end;
toc()
# Elapsed time is 3.64188 seconds.

# [4]
tic();
for k=1:10000;
    X = rand(100, 30);
    y = rand(100, 1);
    (X'*X) \ (X'*y);
end;
toc()
# Elapsed time is 1.37039 seconds.

まず、計算結果を見る限りは、[1]、[2]、[3]、[4] いずれも確かに、同じ値を算出していますね3

また時間を計ってみると、$n$ の値(=行列の列数=feature(特徴)の数)が大きくなればなるほど時間がかかるのはまぁ当たり前として、その時 [1] の pinv(X' * X) * X' * y の方が [2] pinv(X) * y よりも実行時間が短いです。つまり、見た目の計算式はより複雑ですが、計算量は前者の方が少ない、ということですね。
これは、$X$ を $m \times n$行列($m > n$)とした場合、$X^TX$ は $n \times n$ の正方行列になり、さらにほぼ可逆になるので、高速に逆行列が計算される(その後 $n \times m$ の $X^T$ をかけるコストも割と低い)のに比較して、そもそも正方行列でない $X$ の一般化逆行列を(しかも $m \times n > n \times n$ だし)計算するコストが高くなっている、ということなのではないかと思います。

しかし、[3] の X \ y の計算時間の方がもっと速いという結果も出ています。
さらに、[4] の (X'*X) \ (X'*y) の方がもっともっと速いですね。
[3] より [4] の方が速いのは、先ほどと同じ理由によるものです。

ちなみにもちろん rand() を利用しているので実行するたびに X および y の値は変わりますが、計算結果および実行時間はほぼ全く同様でした。
このままだと、(X'*X) \ (X'*y) が一番良さそうに思えますけれど…。

検証2

先ほどはランダムな行列を使用したコードで検証してみました。
今度は、意図的に数値を並べた行列で試してみます。

rank_deficient.m
X = [1 1 2 3;
     1 2 3 5;
     1 3 5 8;
     1 5 8 13;
     1 8 13 21;
     1 13 21 34]

y = [8; 13; 21; 34; 55; 89]

# check rank of matrix
rank(X)
# => 3

# calcuration
# [1]
pinv(X'*X) * X'*y
# ans =
# 
#    3.1974e-13
#    3.3333e-01
#    1.3333e+00
#    1.6667e+00

# [2]
pinv(X) * y
# ans =
# 
#    0.00000
#    0.33333
#    1.33333
#    1.66667

# [3]
X \ y
# ans =
# 
#   -1.3628e-14
#    3.3333e-01
#    1.3333e+00
#    1.6667e+00

# [4]
(X'*X) \ (X'*y)
# warning: matrix singular to machine precision, rcond = 4.97057e-18
# ans =
# 
#   -1.8859e-13
#    3.3333e-01
#    1.3333e+00
#    1.6667e+00


# Square Matrix
X = X(1:4, 1:4)
# X =
# 
#     1    1    2    3
#     1    2    3    5
#     1    3    5    8
#     1    5    8   13

y = y(1:4)
# y =
# 
#     8
#    13
#    21
#    34

# calcuration
# [1]
pinv(X'*X) * X'*y
# ans =
# 
#    1.8119e-13
#    3.3333e-01
#    1.3333e+00
#    1.6667e+00

# [2]
pinv(X) * y
# ans =
# 
#   -7.1054e-15
#    3.3333e-01
#    1.3333e+00
#    1.6667e+00

# [3]
X \ y
# warning: matrix singular to machine precision, rcond = 0
# ans =
# 
#   -7.3807e-15
#    3.3333e-01
#    1.3333e+00
#    1.6667e+00

# [4]
(X'*X) \ (X'*y)
# warning: matrix singular to machine precision, rcond = 1.26207e-17
# ans =
# 
#    1.5742e-14
#    3.3333e-01
#    1.3333e+00
#    1.6667e+00

行列 $X$ の、2〜4列目は、連続するフィボナッチ数になっています。
つまり($i$列目を $x_i$ と書くと)$x_2 + x_3 = x_4$ になっています。
$m \times 4$ 行列($m \ge 4$)で、独立な列は3つだけ(4列目は他の列の線形結合で表されるため)なので、行列の階数(rank)は $3$ です4

このとき、$n \times n$ 行列である $X^TX$ は、正則行列ではありません(逆行列を持ちません)。また $m = n$ の場合、$X$ それ自身も正方行列にはなりますがやはり正則行列ではありません。

その結果、$m > n$ の時の [4]((X'*X) \ (X'*y))、および $m = n$ の時の [3], [4](X \ y, (X'*X) \ (X'*y))の計算時に、警告が出力されてしまっています。
\ 演算子による求解は、このような問題があるようです。

一方、一般化逆行列による求解の場合は、警告無く計算が行われている模様です。
初めから「正則行列でなくても(正方行列ですらなくても)妥当な計算をするようになっている」から、なのですね。

結果からの考察

Ang先生 が pinv(X'*X)*X'*y を紹介したのは、(本来の式変形からの)導出結果の式に最も近く、かつ pinv(X)*y よりも高速に計算できるから、ということなのではないかな、と。

また、X \ y(X'*X) \ (X'*y) に触れていないのは、これらの計算が、演算子の左側の行列が正則でない正方行列(もしくは、「ランク落ち」の行列)の場合に問題が起きるから(pinv() を用いた場合はその心配が無いから)ではないか、と思われます5

ただ、これはデータ(feature)の選び方にそもそも問題がある場合、とも言えるので、他の feature に依らず独立な features たちだけを選んでやれば、より簡潔かつ高速な (X'*X) \ (X'*y) を利用するのは悪く無いとも思います。

あとは、そのためのデータ解析に重きを置いて(その作業にコストを払って)計算を高速にする((X'*X) \ (X'*y))か、データには特別手を入れずにそれなりに高速で間違いの無い計算をする(pinv(X'*X)*X'*y)か。その駆け引き、ということになるのでしょう。

おまけ:他言語の場合

この手の(言語や処理系が指定されている)講義や教科書での勉強において、理解を深めるために、個人的に他言語でも記述してみる、ということを私はよくやります。
ということで、今回検証したコードを Julia(v0.3.x/0.4.0-dev) および Python(v2.7.x/3.x)+NumPy でも検証してみました。

Python には X \ y に該当する演算子がない代わりに、numpy.linalg.solve(X, y) という関数があるのですが、X が正方行列でないと動作しない模様です。

Julia 版

Julia v0.3.x/0.4.0 どちらでも動作します6

solve_X_y.jl
X = rand(10, 4)
# 10x4 Array{Float64,2}:
#  0.71148    0.968352  0.0952939  0.796324 
#  0.915128   0.128326  0.630086   0.0635579
#  0.351199   0.131409  0.934867   0.501701 
#  0.165645   0.874088  0.173725   0.976326 
#  0.765261   0.790716  0.760362   0.204496 
#  0.544099   0.156464  0.041718   0.507071 
#  0.764964   0.852837  0.230312   0.134783 
#  0.0738597  0.75529   0.693856   0.0107293
#  0.621861   0.56881   0.66972    0.163911 
#  0.9471     0.453841  0.466836   0.10744  

y = rand(10, 1)
# 10x1 Array{Float64,2}:
#  0.389321 
#  0.436261 
#  0.308189 
#  0.734617 
#  0.410237 
#  0.4969   
#  0.0708882
#  0.0840005
#  0.944711 
#  0.14718  

# calcuration
# [1]
pinv(X' * X) * X' * y
# 4x1 Array{Float64,2}:
#   0.169937 
#  -0.0365938
#   0.273122 
#   0.55004  

# [2]
pinv(X) * y
# 4x1 Array{Float64,2}:
#   0.169937 
#  -0.0365938
#   0.273122 
#   0.55004  

# [3]
X \ y
# 4x1 Array{Float64,2}:
#   0.169937 
#  -0.0365938
#   0.273122 
#   0.55004  

# [4]
(X'*X) \ (X'*y)
# 4x1 Array{Float64,2}:
#   0.169937 
#  -0.0365938
#   0.273122 
#   0.55004  

# time measurement (n = 10)
# [1]
@time for k=1:10000
    X = rand(40, 10)
    y = rand(40, 1)
    pinv(X' * X) * X' * y
end
# elapsed time: 1.087745051 seconds (283600016 bytes allocated, 17.28% gc time)

# [2]
@time for k=1:10000
    X = rand(40, 10)
    y = rand(40, 1)
    pinv(X) * y
end
# elapsed time: 1.278193773 seconds (334800016 bytes allocated, 17.29% gc time)

# [3]
@time for k=1:10000
    X = rand(40, 10)
    y = rand(40, 1)
    X \ y
end
# elapsed time: 1.014968744 seconds (324320000 bytes allocated, 20.29% gc time)

# [4]
@time for k=1:10000
    X = rand(100, 30)
    y = rand(100, 1)
    (X'*X) \ (X'*y)
end
# elapsed time: 0.163586767 seconds (62720032 bytes allocated, 41.51% gc time)

# time measurement (n = 30)
# [1]
@time for k=1:10000
    X = rand(100, 30)
    y = rand(100, 1)
    pinv(X' * X) * X' * y
end
# elapsed time: 5.820615493 seconds (1557840000 bytes allocated, 19.02% gc time)

# [2]
@time for k=1:10000
    X = rand(100, 30)
    y = rand(100, 1)
    pinv(X) * y
end
# elapsed time: 7.518744844 seconds (1914480016 bytes allocated, 16.51% gc time)

# [3]
@time for k=1:10000
    X = rand(100, 30)
    y = rand(100, 1)
    X \ y
end
# elapsed time: 3.455976006 seconds (1292000000 bytes allocated, 22.67% gc time)

# [4]
@time for k=1:10000
    X = rand(100, 30)
    y = rand(100, 1)
    (X'*X) \ (X'*y)
end
# elapsed time: 0.777771618 seconds (407840016 bytes allocated, 32.71% gc time)

思ったほど速くないという印象ですが、直接トップレベルに for を書いているから、だと思います。
関数化したりその他パフォーマンスに気をつけた書き方に変更すればきっともっと速くなるはず。
でも (X'*X) \ (X'*y) はやっぱり他を圧倒するくらい速いですね! でも…

rank_deficient.jl
X = [1 1 2 3;
     1 2 3 5;
     1 3 5 8;
     1 5 8 13;
     1 8 13 21;
     1 13 21 34]

y = [8; 13; 21; 34; 55; 89]

# check rank of matrix
rank(X)
# => 3

# calcuration
# [1]
pinv(X'*X) * X'*y
# 4-element Array{Float64,1}:
#  -7.10543e-15
#   0.333333   
#   1.33333    
#   1.66667    

# [2]
pinv(X) * y
# 4-element Array{Float64,1}:
#   3.55271e-15
#   0.333333   
#   1.33333    
#   1.66667    

# [3]
X \ y
# 4-element Array{Float64,1}:
#  -4.35117e-15
#   2.0        
#   3.0        
#   0.0        

# [4]
(X'*X) \ (X'*y)
# 4-element Array{Float64,1}:
#   3.22113e-13
#  -1.50024    
#  -0.500244   
#   3.50024    


# Square Matrix
X = X[1:4, 1:4]
# 4x4 Array{Int64,2}:
#  1  1  2   3
#  1  2  3   5
#  1  3  5   8
#  1  5  8  13

y = y[1:4]
# 4-element Array{Int64,1}:
#   8
#  13
#  21
#  34

# calcuration
# [1]
pinv(X'*X) * X'*y
# 4-element Array{Float64,1}:
#  8.52651e-14
#  0.333333   
#  1.33333    
#  1.66667    

# [2]
pinv(X) * y
# 4-element Array{Float64,1}:
#  3.55271e-15
#  0.333333   
#  1.33333    
#  1.66667    

# x[3]
# X \ y
# @> SingularException(4)

# x[4]
# (X'*X) \ (X'*y)
# @> SingularException(4)

ランク落ちの場合。
縦長行列の時は [3], [4](X \ y, (X'*X) \ (X'*y))が pinv() を使用したものと大きく異なる計算結果となっています。
また正方行列の場合は、「正則行列でない」と言う旨のエラーが出てしまい計算もできませn(>_<)
(ちなみに Julia v0.3.x と v0.4.0 とでも結果が少し異なっていたという点も軽く触れておきます)

Python + NumPy 版

同じく Python v2.7.x/3.x どちらでも動作、結果も同様:

solve_X_y.py
import numpy as np
X = np.random.rand(10, 4)
# array([[ 0.61009055,  0.71722947,  0.48465025,  0.15660522],
#        [ 0.02424431,  0.49947237,  0.60493258,  0.8988653 ],
#        [ 0.65048106,  0.69667863,  0.52860957,  0.65003537],
#        [ 0.56541266,  0.25463788,  0.74047536,  0.64691215],
#        [ 0.03052439,  0.47651739,  0.01667898,  0.7613639 ],
#        [ 0.87725831,  0.47684888,  0.44039111,  0.39706053],
#        [ 0.58302851,  0.20919564,  0.97598994,  0.19268083],
#        [ 0.35987338,  0.98331404,  0.06299533,  0.76193058],
#        [ 0.625453  ,  0.70985323,  0.62948802,  0.627458  ],
#        [ 0.64201569,  0.22264827,  0.71333221,  0.53305839]])

y = np.random.rand(10, 1)
# array([[ 0.99674247],
#        [ 0.66282312],
#        [ 0.68295932],
#        [ 0.14330449],
#        [ 0.17467666],
#        [ 0.90896029],
#        [ 0.65385071],
#        [ 0.00748736],
#        [ 0.93824979],
#        [ 0.91696375]])

# calcuration
# [1]
np.linalg.pinv(X.T.dot(X)).dot(X.T).dot(y)
# array([[ 0.32591078],
#        [ 0.46479763],
#        [ 0.6684976 ],
#        [-0.26695783]])

# [2]
np.linalg.pinv(X).dot(y)
# array([[ 0.32591078],
#        [ 0.46479763],
#        [ 0.6684976 ],
#        [-0.26695783]])

# x[3]
# np.linalg.solve(X, y)
# @> LinAlgError

# [4]
np.linalg.solve(X.T.dot(X), X.T.dot(y))
# array([[ 0.32591078],
#        [ 0.46479763],
#        [ 0.6684976 ],
#        [-0.26695783]])

# time measurement (n = 10)
from timeit import timeit
# [1]
def test_a():
    X = np.random.rand(40, 10)
    y = np.random.rand(40, 1)
    np.linalg.pinv(X.T.dot(X)).dot(X.T).dot(y)

timeit("test_a()", setup="from __main__ import test_a", number=10000)
# 1.1948060989379883

# [2]
def test_b():
    X = np.random.rand(40, 10)
    y = np.random.rand(40, 1)
    np.linalg.pinv(X).dot(y)

timeit("test_b()", setup="from __main__ import test_b", number=10000)
# 1.2698009014129639

# [4]
def test_c():
    X = np.random.rand(40, 10)
    y = np.random.rand(40, 1)
    np.linalg.solve(X.T.dot(X), X.T.dot(y))

timeit("test_c()", setup="from __main__ import test_c", number=10000)
# 0.4645709991455078

# time measurement (n = 30)
# [1]
def test_d():
    X = np.random.rand(100, 30)
    y = np.random.rand(100, 1)
    np.linalg.pinv(X.T.dot(X)).dot(X.T).dot(y)

timeit("test_d()", setup="from __main__ import test_d", number=10000)
# 4.615994930267334

# [2]
def test_e():
    X = np.random.rand(100, 30)
    y = np.random.rand(100, 1)
    np.linalg.pinv(X).dot(y)

timeit("test_e()", setup="from __main__ import test_e", number=10000)
# 5.413921117782593

# [4]
def test_f():
    X = np.random.rand(100, 30)
    y = np.random.rand(100, 1)
    np.linalg.solve(X.T.dot(X), X.T.dot(y))

timeit("test_f()", setup="from __main__ import test_f", number=10000)
# 0.9642360210418701

NumPy の ndarray は * が要素ごとの積になっており、行列の積は .dot() メソッドを利用しなければなりません。ただでさえ煩雑な式がさらに煩雑に…でも高速! solve() も高速! でも…。

rank_deficient.jl
import numpy as np
X = np.array([
        [1, 1, 2, 3],
        [1, 2, 3, 5],
        [1, 3, 5, 8],
        [1, 5, 8, 13],
        [1, 8, 13, 21],
        [1, 13, 21, 34]])

y = np.array([[8], [13], [21], [34], [55], [89]])

# check rank of matrix
np.linalg.matrix_rank(X)
# => 3

# calcuration
# [1]
np.linalg.pinv(X.T.dot(X)).dot(X.T.dot(y))
# array([[  2.27373675e-13],
#        [  3.33333333e-01],
#        [  1.33333333e+00],
#        [  1.66666667e+00]])

# [2]
np.linalg.pinv(X).dot(y)
# array([[  3.55271368e-15],
#        [  3.33333333e-01],
#        [  1.33333333e+00],
#        [  1.66666667e+00]])

# [4]
np.linalg.solve(X.T.dot(X), X.T.dot(y))
# array([[ -8.12048841e-14],
#        [ -5.00000000e+00],
#        [ -4.00000000e+00],
#        [  7.00000000e+00]])

# Square Matrix
X = X[0:4]
# array([[ 1,  1,  2,  3],
#        [ 1,  2,  3,  5],
#        [ 1,  3,  5,  8],
#        [ 1,  5,  8, 13]])

y = y[0:4]
# array([[ 8],
#        [13],
#        [21],
#        [34]])

# calcuration
# [1]
np.linalg.pinv(X.T.dot(X)).dot(X.T.dot(y))
# array([[ -1.13686838e-13],
#        [  3.33333333e-01],
#        [  1.33333333e+00],
#        [  1.66666667e+00]])

# [2]
np.linalg.pinv(X).dot(y)
# array([[  4.44089210e-15],
#        [  3.33333333e-01],
#        [  1.33333333e+00],
#        [  1.66666667e+00]])

# [4]
np.linalg.solve(X.T.dot(X), X.T.dot(y))
# array([[ -1.47008842e-14],
#        [  1.00000000e+00],
#        [  2.00000000e+00],
#        [  1.00000000e+00]])

ランク落ちの場合。
NumPy の場合、solve() を使用した場合でもエラーは出ていませんが、計算結果は pinv() を使用したものとやはり大きく異なるものとなってしまいました。

  1. この式、所謂二乗誤差関数ですね。つまり最小二乗法です。

  2. スミマセン、ここ端折りました。いちいち書くのめんどくさい(^-^;

  3. さらに詳しく検証すると、実際には 1e-14 くらいの誤差で互いに異なる値が算出されている模様です。計算過程・計算方法が異なるための浮動小数点の演算誤差、ですね。

  4. m行n列の行列 X で、rank(X) = min(m, n) の時、X は「フルランク (full rank)」と言います。rank(X) < min(m, n) の場合は、「ランク落ち (rank deficient)」と言います。

  5. 記事修正前に「問題が無ければ X \ y の方が良くない?どうでしょう?」と尋ねましたが、たぶんこの件が一番の問題なのでしょう。と、自己解決してしまったのですが、どうでしょう?

  6. Julia の行列/ベクトルの扱い方や書式は、MATLAB/Octave から取り入れたものなので、ほとんど書き換えることなくそのまま動作しました。

155
159
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
155
159

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?