3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

JuliaでTensorFlow その3: 過学習について

Last updated at Posted at 2018-04-09

Julia 0.6.2でTensorFlowを使う。その3。
その1
https://qiita.com/cometscome_phys/items/358bc4a1feaec1c7fa14
その2:線形回帰をやってみる
https://qiita.com/cometscome_phys/items/244cfed8ab309156735c

今回は過学習について。

#多項式フィッティング
その2では線形回帰:
$$
y = ax+b
$$
を用いてデータをフィッティングした。
世の中の様々な関数はもちろん線形なものとは限らないので、関数をもう少し高級なものにしてみたい。すぐに思いつくのは、
$$
y = a_1 x + a_2 x^2 + b
$$

$$
y = a_1 x + a_2 x^2 + a_3 x^3 + b
$$
などの関数だろう。多項式を使う場合には、一般的には
$$
y = \sum_{i=1}^n a_i x^i + b
$$
としておけば、$n$次の多項式でのフィッティングができるはずである。

オリジナルデータ

今回のオリジナルデータは、
$$
y = a_1 x + a_2 x^2 + b + {\rm noise}
$$
という2次関数としておこう。$-2 \le x \le 2$の範囲内で、この関数は

original.jl
using Plots
ENV["PLOTS_TEST"] = "true"
gr()
n = 10
x0 = linspace(-2,2,n)
totalnoise = 3*rand(linspace(-1,1),2n)
noise = totalnoise[1:n]
a0 = 3.0
a1= 2.0
b0 = 1.0
y0 = a0.*x0 + a1.*x0.^2 + b0 + noise
pl=plot(x0,y0,marker=:circle,label="Data")
savefig("data2.png")
pl

data2.png

となっている。
その2と同じようにグラフを設計してみよう。

2次関数の場合

test1.jl
using TensorFlow

x = placeholder(Float64)
yout = placeholder(Float64)
a1 = Variable(2.0)
a2 = Variable(2.0)
a3 = Variable(2.0)
a4 = Variable(2.0)
b = Variable(0.0)
y = a1.*x+a2.*x.^2 + b 
diff = y-yout
loss = nn.l2_loss(diff)
optimizer = train.AdamOptimizer()
minimize = train.minimize(optimizer, loss)

これを使ってグラフを計算してみる。

test2.jl
sess = Session()
run(sess, global_variables_initializer())
nt = 20000
for i in 1:nt
    run(sess, minimize, Dict(x=>x0,yout=>y0))
    if i%1000==0
        println(i,"\t",run(sess, loss, Dict(x=>x0,yout=>y0)))
    end
end
ae1 = run(sess, a1)
ae2 = run(sess, a2)
be = run(sess, b)
ye = run(sess, y, Dict(x=>x0,yout=>y0))
println("a1 = ", ae1," a2 = ", ae2)
println(" b = ", be)
close(sess)
ple = plot(x0,[y0,ye],marker=:circle,label=["Data","Estimation"])
savefig("comparison2.png")
ple

その結果は

1000	12.846654779416948
2000	11.229496095926173
3000	11.17696202312561
4000	11.17684298519887
5000	11.176842982090793
6000	11.176842982090797
7000	11.176842982090797
8000	11.176842982090793
9000	11.176842982090797
10000	11.176842982090795
11000	11.176842982090793
12000	11.176842982090797
13000	11.17684298874014
14000	11.176842983351635
15000	11.176842982090907
16000	11.176842982090793
17000	11.176842982090793
18000	11.176842982090793
19000	11.176842982090797
20000	11.17684298209079

comparison2.png

となる。まあ悪くない気がする。

高次の多項式の場合

次に、元のデータが二次関数であると知らなかったとしよう。その場合、このデータはまだまだ良いフィッティング関数でフィッティングできるのではないか?と考えるかもしれない。
ということで、3次関数でフィッティングしてみると、

test3.jl
using TensorFlow

x = placeholder(Float64)
yout = placeholder(Float64)
a1 = Variable(2.0)
a2 = Variable(2.0)
a3 = Variable(2.0)
a4 = Variable(2.0)
b = Variable(0.0)
y = a1.*x+a2.*x.^2 +a3.*x.^3 + b
diff = y-yout
loss = nn.l2_loss(diff)
optimizer = train.AdamOptimizer()
minimize = train.minimize(optimizer, loss)

sess = Session()
run(sess, global_variables_initializer())
nt = 20000
for i in 1:nt
    run(sess, minimize, Dict(x=>x0,yout=>y0))
    if i%1000==0
        println(i,"\t",run(sess, loss, Dict(x=>x0,yout=>y0)))
    end
end
ae1 = run(sess, a1)
ae2 = run(sess, a2)
ae3 = run(sess, a3)
be = run(sess, b)
ye = run(sess, y, Dict(x=>x0,yout=>y0))
println("a1 = ", ae1," a2 = ", ae2)
println("a3 = ", ae3)
println(" b = ", be)
close(sess)
ple = plot(x0,[y0,ye],marker=:circle,label=["Data","Estimation"])
savefig("comparison3.png")
ple

となり、出力は

1000	38.72959372420636
2000	14.703627447113401
3000	13.432322808936945
4000	12.615366011127026
5000	11.748108002654657
6000	11.040931260961438
7000	10.645708109304147
8000	10.520118027859194
9000	10.503778303825083
10000	10.503275443619287
11000	10.50327398556771
12000	10.503273985489727
13000	10.503274102618724
14000	10.503274291560643
15000	10.50327404149707
16000	10.503273985586867
17000	10.503273985496698
18000	10.503273985489809
19000	10.503273985493022
20000	10.503274006524531

となって、lossが二次関数の時より減っているように見える。グラフも

comparison3.png

となり、いい感じに見える。しかし、まだうまく合うものがあるかもしれない。4次関数を使ってみよう。

test4.jl
using TensorFlow

x = placeholder(Float64)
yout = placeholder(Float64)
a1 = Variable(2.0)
a2 = Variable(2.0)
a3 = Variable(2.0)
a4 = Variable(2.0)
b = Variable(0.0)
y = a1.*x+a2.*x.^2 +a3.*x.^3 +a4.*x.^4+b
diff = y-yout
loss = nn.l2_loss(diff)
optimizer = train.AdamOptimizer()
minimize = train.minimize(optimizer, loss)

sess = Session()
run(sess, global_variables_initializer())
nt = 20000
for i in 1:nt
    run(sess, minimize, Dict(x=>x0,yout=>y0))
    if i%1000==0
        println(i,"\t",run(sess, loss, Dict(x=>x0,yout=>y0)))
    end
end
ae1 = run(sess, a1)
ae2 = run(sess, a2)
ae3 = run(sess, a3)
ae4 = run(sess, a4)
be = run(sess, b)
ye = run(sess, y, Dict(x=>x0,yout=>y0))
println("a1 = ", ae1," a2 = ", ae2)
println("a3 = ", ae3," a4 = ", ae4)
println(" b = ", be)
close(sess)
ple = plot(x0,[y0,ye],marker=:circle,label=["Data","Estimation"])
savefig("comparison4.png")
ple
1000	292.4012075333036
2000	54.91688158042079
3000	30.3129919905984
4000	22.260492197924318
5000	16.68915508888759
6000	13.581239652553963
7000	12.107794990192408
8000	11.180824199691886
9000	10.309787404583048
10000	9.52945773262936
11000	8.996510842589858
12000	8.753597234725168
13000	8.694117952080358
14000	8.688845495294744
15000	8.688757399168221
16000	8.688757306833963
17000	8.688757306833109
18000	8.688757306833107
19000	8.688757306833104
20000	8.688757306833105

となり、lossはさらに減っているように見える。グラフも
comparison4.png
となっていて、よくフィッティングできているように見える。
オリジナルの関数を知らなければ、この4次関数を採用しても良いように見える。

過学習

もちろん、4次関数ではうまくいかないことは、オリジナルの関数を知っていればわかる。
これを確かめてみよう。
これまでは、$-2 \le x \le 2$のデータしか持っていなかったが、新しく、$2 \le x \le 4$のデータを得られたとしよう。この新しいデータの分布を、4次関数はちゃんと再現できるだろうか?

test.jl
using Plots
ENV["PLOTS_TEST"] = "true"
gr()
n = 10
x0 = linspace(-2,2,n)
println(x0)
x0 = [x0
    linspace(2,4,n)]


a0 = 3.0
a1= 2.0
b0 = 1.0
y0 = a0.*x0 + a1.*x0.^2 + b0 + totalnoise
ys = ae1.*x0+ae2.*x0.^2+ae3.*x0.^3+ae4.*x0.^4+be 
pls = plot(x0,[y0,ys],marker=:circle,label=["Data","Estimation"])
savefig("comparison_2.png")
pls 

グラフは
comparison_2.png

となる。見てわかるように、$-2 \le x \le 2$までならよく合っているが、その先は一気にダメになっている。

一般的に、決めるべきパラメータ(ここでは、多項式の係数)が多すぎる場合、ノイズの部分も含めて「合わせにいってしまう」ので、本来の関数とは大きくかけ離れたものが得られてしまうことがある。これを「過学習」と呼ぶ。英語だとover fittingと呼ぶ。
機械学習の言葉で言うと、$-2 \le x \le 2$のデータを「トレーニングデータ」、$2 \le x \le 4$のデータを「テストデータ」と呼ぶ。学習(係数決め)に用いた「トレーニングデータ」のlossが低くなったからといって、「テストデータ」のlossが低いとは限らない。きちんとテストデータでのlossが低くなることを確かめなければ、学習モデルは常に過学習の危険にさらされている。そのため、テストデータとトレーニングデータを入れ替えて学習をしたり、ランダムにデータを取ってきてトレーニングデータにしたり、など、様々な工夫が存在する。

追記

「区間を増やせばどんな関数だって合わないので、過学習の例としてはあまり良くないのでは」という指摘をもらった。確かにその通りなので、次の記事で線形基底を使った線形回帰を用いて、過学習についてもう一度やってみることにする。

3
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?