4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

JuliaでTensorFlow その2: 線形回帰をやってみる

Last updated at Posted at 2018-04-09

Julia 0.6.2でTensorFlowを使う。その2。
その1は
https://qiita.com/cometscome_phys/items/358bc4a1feaec1c7fa14
ここでは、簡単な線形回帰をやってみる。

#元のデータの生成
まずはじめに、元のデータを
$$
y = ax+b +{\rm noise}
$$
として作る。ここで、noiseを適当に発生させておくことで、データが直線からずれることになる。
このようなズレがある場合に、元の直線を推測できるか($a$と$b$を推測できるか)、という問題。
元の直線は

original.jl
using Plots
ENV["PLOTS_TEST"] = "true"
gr()
n = 10
x0 = linspace(-1,1,n)
noise = 0.5*rand(linspace(-1,1),n)
a0 = 3.0
b0 = 1.0
y0 = a0.*x0 + b0 + noise
plot(x0,y0,marker=:circle,label="Data")
savefig("data.png")

で描くことができる。
これを実行すると
data.png
というグラフが得られる。
なお、Plotsのバックエンドのgrが6倍大きな図を作ってしまう問題があるので、ENV["PLOTS_TEST"] = "true" を使っている。詳しくは
http://nbviewer.jupyter.org/gist/genkuroki/3d6dbf52a3e52eb7c664bc88632c81d3
を参照。

グラフの設計

次に、その1と同じようにグラフを設計してみよう。
まず、$x_i$をインプットデータ、$y_i$をアウトプットデータとして、この$(x_i,y_i)$の組が複数ある状況を考える(上のコードでは10個)。
そして、
データが
$$
y_i = a x_i + b
$$
という直線に従っていると仮定する。
この仮定のもと、一番もっともらしい$a$と$b$を求めるのが線形回帰である。
そして、二乗誤差:
$$
{\rm loss} = \sum_i (y - y_i)^2
$$
が一番小さくなるような$a$と$b$が、一番もっともらしいとする。

TensorFlowでは、何かの関数を最小化することが簡単にできる。
何を最小化するか、どのように最小化するかの情報を設定しておけばよい。
その1の小人の例で例えると、小人Bに「最小化しろ」といえばあとはよしなにやってくれるのである。
まず、小人Aへの指令(グラフ作成)は

test1.jl
using TensorFlow

x = placeholder(Float64)
yout = placeholder(Float64)
a = Variable(2.0)
b = Variable(0.0)
y = a.*x+b
diff = y-yout
loss = nn.l2_loss(diff)
optimizer = train.AdamOptimizer()
minimize = train.minimize(optimizer, loss)

となる。ここでは、minimizeを小人Bに指定すれば、lossを最小化してくれることになる。

グラフの実行

次に、実際の計算を小人Bにやらせるコードは

test1.jl
sess = Session()
run(sess, global_variables_initializer())
nt = 10000
for i in 1:nt
    run(sess, minimize, Dict(x=>x0,yout=>y0))
    if i%1000==0
        println(i,"\t",run(sess, loss, Dict(x=>x0,yout=>y0)))
    end
end

となる。
出力結果は

1000	0.5525778742090897
2000	0.2062854331360348
3000	0.20391311222929595
4000	0.20391251120152726
5000	0.20391251120114107
6000	0.203912511201141
7000	0.2039125112011411
8000	0.20391251120114104
9000	0.20391251120114096
10000	0.20391251120114093

となる。
ここで、最小化は逐次的に行われるので、何回も繰り返すことで少しずつlossが小さくなっていくことに注意。ある程度まで回数をこなすと、lossの値は変わらなくなる。この時、今仮定している線形のモデルでの最良の$a$と$b$が得られている可能性が高い(今後言及する「過学習」が起きた場合には、その限りではない)。

最後に、$a$や$b$の値を表示させ、グラフをプロットするには、

test1.jl
ae = run(sess, a)
be = run(sess, b)
ye = run(sess, y, Dict(x=>x0,yout=>y0))
println("a = ", ae," b = ", be)
close(sess)
plot(x0,[y0,ye],label=["Data","Estimation"],marker=:circle)

とすればよい。そして、

a = 3.0662337662337653 b = 0.8816326530612243

という値が得られる。
グラフは

comparison.png

となる。
それなりにいい感じにフィッティングされている。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?