Julia 0.6.2でTensorFlowを使う。その3。
その1
https://qiita.com/cometscome_phys/items/358bc4a1feaec1c7fa14
その2:線形回帰をやってみる
https://qiita.com/cometscome_phys/items/244cfed8ab309156735c
今回は過学習について。
#多項式フィッティング
その2では線形回帰:
$$
y = ax+b
$$
を用いてデータをフィッティングした。
世の中の様々な関数はもちろん線形なものとは限らないので、関数をもう少し高級なものにしてみたい。すぐに思いつくのは、
$$
y = a_1 x + a_2 x^2 + b
$$
や
$$
y = a_1 x + a_2 x^2 + a_3 x^3 + b
$$
などの関数だろう。多項式を使う場合には、一般的には
$$
y = \sum_{i=1}^n a_i x^i + b
$$
としておけば、$n$次の多項式でのフィッティングができるはずである。
オリジナルデータ
今回のオリジナルデータは、
$$
y = a_1 x + a_2 x^2 + b + {\rm noise}
$$
という2次関数としておこう。$-2 \le x \le 2$の範囲内で、この関数は
using Plots
ENV["PLOTS_TEST"] = "true"
gr()
n = 10
x0 = linspace(-2,2,n)
totalnoise = 3*rand(linspace(-1,1),2n)
noise = totalnoise[1:n]
a0 = 3.0
a1= 2.0
b0 = 1.0
y0 = a0.*x0 + a1.*x0.^2 + b0 + noise
pl=plot(x0,y0,marker=:circle,label="Data")
savefig("data2.png")
pl
となっている。
その2と同じようにグラフを設計してみよう。
2次関数の場合
using TensorFlow
x = placeholder(Float64)
yout = placeholder(Float64)
a1 = Variable(2.0)
a2 = Variable(2.0)
a3 = Variable(2.0)
a4 = Variable(2.0)
b = Variable(0.0)
y = a1.*x+a2.*x.^2 + b
diff = y-yout
loss = nn.l2_loss(diff)
optimizer = train.AdamOptimizer()
minimize = train.minimize(optimizer, loss)
これを使ってグラフを計算してみる。
sess = Session()
run(sess, global_variables_initializer())
nt = 20000
for i in 1:nt
run(sess, minimize, Dict(x=>x0,yout=>y0))
if i%1000==0
println(i,"\t",run(sess, loss, Dict(x=>x0,yout=>y0)))
end
end
ae1 = run(sess, a1)
ae2 = run(sess, a2)
be = run(sess, b)
ye = run(sess, y, Dict(x=>x0,yout=>y0))
println("a1 = ", ae1," a2 = ", ae2)
println(" b = ", be)
close(sess)
ple = plot(x0,[y0,ye],marker=:circle,label=["Data","Estimation"])
savefig("comparison2.png")
ple
その結果は
1000 12.846654779416948
2000 11.229496095926173
3000 11.17696202312561
4000 11.17684298519887
5000 11.176842982090793
6000 11.176842982090797
7000 11.176842982090797
8000 11.176842982090793
9000 11.176842982090797
10000 11.176842982090795
11000 11.176842982090793
12000 11.176842982090797
13000 11.17684298874014
14000 11.176842983351635
15000 11.176842982090907
16000 11.176842982090793
17000 11.176842982090793
18000 11.176842982090793
19000 11.176842982090797
20000 11.17684298209079
となる。まあ悪くない気がする。
高次の多項式の場合
次に、元のデータが二次関数であると知らなかったとしよう。その場合、このデータはまだまだ良いフィッティング関数でフィッティングできるのではないか?と考えるかもしれない。
ということで、3次関数でフィッティングしてみると、
using TensorFlow
x = placeholder(Float64)
yout = placeholder(Float64)
a1 = Variable(2.0)
a2 = Variable(2.0)
a3 = Variable(2.0)
a4 = Variable(2.0)
b = Variable(0.0)
y = a1.*x+a2.*x.^2 +a3.*x.^3 + b
diff = y-yout
loss = nn.l2_loss(diff)
optimizer = train.AdamOptimizer()
minimize = train.minimize(optimizer, loss)
sess = Session()
run(sess, global_variables_initializer())
nt = 20000
for i in 1:nt
run(sess, minimize, Dict(x=>x0,yout=>y0))
if i%1000==0
println(i,"\t",run(sess, loss, Dict(x=>x0,yout=>y0)))
end
end
ae1 = run(sess, a1)
ae2 = run(sess, a2)
ae3 = run(sess, a3)
be = run(sess, b)
ye = run(sess, y, Dict(x=>x0,yout=>y0))
println("a1 = ", ae1," a2 = ", ae2)
println("a3 = ", ae3)
println(" b = ", be)
close(sess)
ple = plot(x0,[y0,ye],marker=:circle,label=["Data","Estimation"])
savefig("comparison3.png")
ple
となり、出力は
1000 38.72959372420636
2000 14.703627447113401
3000 13.432322808936945
4000 12.615366011127026
5000 11.748108002654657
6000 11.040931260961438
7000 10.645708109304147
8000 10.520118027859194
9000 10.503778303825083
10000 10.503275443619287
11000 10.50327398556771
12000 10.503273985489727
13000 10.503274102618724
14000 10.503274291560643
15000 10.50327404149707
16000 10.503273985586867
17000 10.503273985496698
18000 10.503273985489809
19000 10.503273985493022
20000 10.503274006524531
となって、lossが二次関数の時より減っているように見える。グラフも
となり、いい感じに見える。しかし、まだうまく合うものがあるかもしれない。4次関数を使ってみよう。
using TensorFlow
x = placeholder(Float64)
yout = placeholder(Float64)
a1 = Variable(2.0)
a2 = Variable(2.0)
a3 = Variable(2.0)
a4 = Variable(2.0)
b = Variable(0.0)
y = a1.*x+a2.*x.^2 +a3.*x.^3 +a4.*x.^4+b
diff = y-yout
loss = nn.l2_loss(diff)
optimizer = train.AdamOptimizer()
minimize = train.minimize(optimizer, loss)
sess = Session()
run(sess, global_variables_initializer())
nt = 20000
for i in 1:nt
run(sess, minimize, Dict(x=>x0,yout=>y0))
if i%1000==0
println(i,"\t",run(sess, loss, Dict(x=>x0,yout=>y0)))
end
end
ae1 = run(sess, a1)
ae2 = run(sess, a2)
ae3 = run(sess, a3)
ae4 = run(sess, a4)
be = run(sess, b)
ye = run(sess, y, Dict(x=>x0,yout=>y0))
println("a1 = ", ae1," a2 = ", ae2)
println("a3 = ", ae3," a4 = ", ae4)
println(" b = ", be)
close(sess)
ple = plot(x0,[y0,ye],marker=:circle,label=["Data","Estimation"])
savefig("comparison4.png")
ple
1000 292.4012075333036
2000 54.91688158042079
3000 30.3129919905984
4000 22.260492197924318
5000 16.68915508888759
6000 13.581239652553963
7000 12.107794990192408
8000 11.180824199691886
9000 10.309787404583048
10000 9.52945773262936
11000 8.996510842589858
12000 8.753597234725168
13000 8.694117952080358
14000 8.688845495294744
15000 8.688757399168221
16000 8.688757306833963
17000 8.688757306833109
18000 8.688757306833107
19000 8.688757306833104
20000 8.688757306833105
となり、lossはさらに減っているように見える。グラフも
となっていて、よくフィッティングできているように見える。
オリジナルの関数を知らなければ、この4次関数を採用しても良いように見える。
過学習
もちろん、4次関数ではうまくいかないことは、オリジナルの関数を知っていればわかる。
これを確かめてみよう。
これまでは、$-2 \le x \le 2$のデータしか持っていなかったが、新しく、$2 \le x \le 4$のデータを得られたとしよう。この新しいデータの分布を、4次関数はちゃんと再現できるだろうか?
using Plots
ENV["PLOTS_TEST"] = "true"
gr()
n = 10
x0 = linspace(-2,2,n)
println(x0)
x0 = [x0
linspace(2,4,n)]
a0 = 3.0
a1= 2.0
b0 = 1.0
y0 = a0.*x0 + a1.*x0.^2 + b0 + totalnoise
ys = ae1.*x0+ae2.*x0.^2+ae3.*x0.^3+ae4.*x0.^4+be
pls = plot(x0,[y0,ys],marker=:circle,label=["Data","Estimation"])
savefig("comparison_2.png")
pls
となる。見てわかるように、$-2 \le x \le 2$までならよく合っているが、その先は一気にダメになっている。
一般的に、決めるべきパラメータ(ここでは、多項式の係数)が多すぎる場合、ノイズの部分も含めて「合わせにいってしまう」ので、本来の関数とは大きくかけ離れたものが得られてしまうことがある。これを「過学習」と呼ぶ。英語だとover fittingと呼ぶ。
機械学習の言葉で言うと、$-2 \le x \le 2$のデータを「トレーニングデータ」、$2 \le x \le 4$のデータを「テストデータ」と呼ぶ。学習(係数決め)に用いた「トレーニングデータ」のlossが低くなったからといって、「テストデータ」のlossが低いとは限らない。きちんとテストデータでのlossが低くなることを確かめなければ、学習モデルは常に過学習の危険にさらされている。そのため、テストデータとトレーニングデータを入れ替えて学習をしたり、ランダムにデータを取ってきてトレーニングデータにしたり、など、様々な工夫が存在する。
追記
「区間を増やせばどんな関数だって合わないので、過学習の例としてはあまり良くないのでは」という指摘をもらった。確かにその通りなので、次の記事で線形基底を使った線形回帰を用いて、過学習についてもう一度やってみることにする。