LoginSignup
1
2

More than 5 years have passed since last update.

Flux.jlでSin波を学習する

Posted at

背景

Juliaにはいくつかニューラルネットワークライブラリがある。
私は昔Mocha.jlを使っていたがJulia1.0以降だと今の所非推奨ということなのでFlux.jlに移民した。
自分の中で新しいニューラルネットワークのライブラリを使う時Sin波を学習するという儀式があるので、儀式を行った。

コード


using Plots
gr()

using Flux
using Statistics
using Flux.Tracker: TrackedReal, data
using Flux: mse
using Base.Iterators: repeated, flatten

# 訓練データ
N = 100
X = range(0, stop = pi, length = N)
Y = sin.(X)

# 訓練データをプロットしておく
plot(X, Y)

data_x = [[x] for x in X]
data_y = [[y] for y in Y]
# batch処理すべきだが、めんどうなのでrepeatedでごまかした。
# Model-Zooもrepeatedでなんとかしてる奴あるしいいよね。
data_xf = Iterators.flatten(repeated(data_x, 100))
data_yf = Iterators.flatten(repeated(data_y, 100))
入力データは[(入力の配列, 出力の配列)]な形式
dataset = zip(data_xf, data_yf)

# モデル
m = Chain(
  Dense(1, 20, relu),
  Dense(20, 1, σ))

loss(x, y) = mse(m(x), y) 

opt = Descent()
Flux.train!(loss, params(m), dataset, opt)

Nt = 100
Xt = range(0, stop = pi, length = Nt)
input_xt = [[x] for x in Xt]
expect_yt = m.(input_xt)

Yt = collect(Iterators.flatten(expect_yt))
# 結果はTrackedRealという型に入ってくるため数字だけ抜き出す。
Yt2 = data.(Yt)

plot!(Xt, Yt2)
png("result.png")

結果

result.png

どうでもいい話

彼女欲しいので女性紹介してください。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2