概要
あえて重回帰を Chainer でやってみました。説明変数が2個だけの最小構成の重回帰ですが。
重回帰
この記事からデータをお借りしました。ただ、単位を直しました。
いま、下表のようなデータがあるとします。
身長 $x_1$ (m) | ウェスト $x_2$ (m) | 体重 $y$ (100kg) |
---|---|---|
1.65 | 0.65 | 0.50 |
1.70 | 0.68 | 0.60 |
1.72 | 0.70 | 0.65 |
1.75 | 0.65 | 0.65 |
1.70 | 0.80 | 0.70 |
1.72 | 0.85 | 0.75 |
1.83 | 0.78 | 0.80 |
1.87 | 0.79 | 0.85 |
1.80 | 0.95 | 0.90 |
1.85 | 0.97 | 0.95 |
身長 $x_1$ とウェスト $x_2$ を説明変数として、体重 $y$ を予測します。
コード
import numpy as np
import chainer
from chainer import Chain, Variable
import chainer.functions as F
import chainer.links as L
xs = np.array(
[[1.65, 0.65],
[1.70, 0.68],
[1.72, 0.70],
[1.75, 0.65],
[1.70, 0.80],
[1.72, 0.85],
[1.83, 0.78],
[1.87, 0.79],
[1.80, 0.95],
[1.85, 0.97]]
, "f")
ys = np.array(
[[0.50],
[0.60],
[0.65],
[0.65],
[0.70],
[0.75],
[0.80],
[0.85],
[0.90],
[0.95]]
, "f")
class MultiRegression(Chain):
def __init__(self):
super().__init__(
l1=L.Linear(2, 1)
)
def __call__(self, x):
return self.l1(x)
def normal_equation(xs, ys):
xs2 = np.hstack((np.ones((xs.shape[0], 1), "f"), xs))
theta = np.linalg.inv(xs2.T.dot(xs2)).dot(xs2.T).dot(ys)
return theta
def normalize(xs, ys):
xmeans = np.mean(xs, axis = 0)
print("xmeans = " + np.array_str(xmeans))
xstds = np.std(xs, axis = 0)
print("xstds = " + np.array_str(xstds))
ymean = np.mean(ys)
print("ymean = %f" % ymean)
ystd = np.std(ys)
print("ystd = %f" % ystd)
nxs = (xs - xmeans) / xstds
nys = (ys - ymean) / ystd
return nxs, nys, xmeans, xstds, ymean, ystd
def denormalize_params(nW, nb, xmeans, xstds, ymean, ystd):
W = ystd / xstds * nW
b = ymean - (sum(ystd / xstds * xmeans * nW) + ystd * nb)
return W, b
def train(xs, ys):
xs = Variable(xs)
ys = Variable(ys)
model = MultiRegression()
alpha = 0.5
optimizer = chainer.optimizers.SGD(lr = alpha)
optimizer.setup(model)
for i in range(100):
model.l1.zerograds()
yp = model(xs)
loss = F.mean_squared_error(yp, ys)
loss.backward()
print("=== Epoch %d ===" % (i + 1))
print("loss = %f" % loss.data)
print("model.l1.W.data = %s" % np.array_str(model.l1.W.data))
print("model.l1.W.grad = %s" % np.array_str(model.l1.W.grad))
print("model.l1.b.data = %s" % np.array_str(model.l1.b.data))
print("model.l1.b.grad = %s" % np.array_str(model.l1.b.grad))
print("")
optimizer.update()
return model.l1.W.data[0], model.l1.b.data[0]
nxs, nys, xmeans, xstds, ymean, ystd = normalize(xs, ys)
nW, nb = train(nxs, nys)
W, b = denormalize_params(nW, nb, xmeans, xstds, ymean, ystd)
print("result of training W = %s, b = %f" % (np.array_str(W), b))
theta = normal_equation(xs, ys)
print("result of normal equation W = [%f, %f], b = %f" % (theta[1][0], theta[2][0], theta[0][0]))
実行結果
xmeans = [ 1.75900006 0.78199995]
xstds = [ 0.07020684 0.10979983]
ymean = 0.735000
ystd = 0.134257
=== Epoch 1 ===
loss = 0.644343
model.l1.W.data = [[ 0.23022108 0.00259643]]
model.l1.W.grad = [[-1.29441714 -1.51606596]]
model.l1.b.data = [ 0.]
model.l1.b.grad = [ -8.94069672e-08]
=== Epoch 2 ===
loss = 0.234075
model.l1.W.data = [[ 0.87742966 0.76062942]]
model.l1.W.grad = [[ 0.90113986 0.76939297]]
model.l1.b.data = [ 4.47034836e-08]
model.l1.b.grad = [ -1.19209290e-07]
...
=== Epoch 100 ===
loss = 0.009937
model.l1.W.data = [[ 0.53428751 0.57989436]]
model.l1.W.grad = [[ 6.89178705e-08 1.86264515e-08]]
model.l1.b.data = [ -8.42846930e-08]
model.l1.b.grad = [ -5.40167093e-08]
result of training W = [ 1.02172303 0.70906305], b = -1.616698
result of normal equation W = [1.021702, 0.709063], b = -1.616678
解説
予測式は以下のとおりになります(係数は小数点4位で四捨五入)。
y = 1.0217x_1 + 0.7091x_2 - 1.6167
result of training
が最急降下法による推定、result of normal equation
が正規方程式による推定となっています。
感想
徒然なる感想を少し。
重回帰式を chainer.links.Linear
クラスで表現しています。これは Link
クラスから派生したクラスです。
上のコードではごちゃごちゃと正規化(Normalization)をやっているのがわかるかと思います。本当は、複雑になるのでやりたくなかったのですが、これをやらないと学習が遅々として進まない(10万回もループを回す必要がある)ためです。正規化(Normalization)の重要さを嫌というほど思い知らされました。
私は、Python もあまりわかっていない状態なのですが、今回のコードを書くのに、Chainer のソースコードを読み込みました。Variable
クラスと Function
クラスにおける backward
メソッドの挙動の概略を理解できたのは、大きな収穫でした。Chainer のソースコードはとても簡潔で読んでいて楽しいです。
Python では、インデントでコードのブロックが一目瞭然なのはよいですね。この点は、Haml に出会ってから大きく考えが変わりました。Ruby を愛する私ですが、ひょっとしたら(やはり、というべきか)、Ruby より Python のコードの方が読みやすいかもしれない…とか思ってしまいました。