背景
Juliaにはいくつかニューラルネットワークライブラリがある。
私は昔Mocha.jlを使っていたがJulia1.0以降だと今の所非推奨ということなのでFlux.jlに移民した。
自分の中で新しいニューラルネットワークのライブラリを使う時Sin波を学習するという儀式があるので、儀式を行った。
コード
using Plots
gr()
using Flux
using Statistics
using Flux.Tracker: TrackedReal, data
using Flux: mse
using Base.Iterators: repeated, flatten
# 訓練データ
N = 100
X = range(0, stop = pi, length = N)
Y = sin.(X)
# 訓練データをプロットしておく
plot(X, Y)
data_x = [[x] for x in X]
data_y = [[y] for y in Y]
# batch処理すべきだが、めんどうなのでrepeatedでごまかした。
# Model-Zooもrepeatedでなんとかしてる奴あるしいいよね。
data_xf = Iterators.flatten(repeated(data_x, 100))
data_yf = Iterators.flatten(repeated(data_y, 100))
入力データは[(入力の配列, 出力の配列)]な形式
dataset = zip(data_xf, data_yf)
# モデル
m = Chain(
Dense(1, 20, relu),
Dense(20, 1, σ))
loss(x, y) = mse(m(x), y)
opt = Descent()
Flux.train!(loss, params(m), dataset, opt)
Nt = 100
Xt = range(0, stop = pi, length = Nt)
input_xt = [[x] for x in Xt]
expect_yt = m.(input_xt)
Yt = collect(Iterators.flatten(expect_yt))
# 結果はTrackedRealという型に入ってくるため数字だけ抜き出す。
Yt2 = data.(Yt)
plot!(Xt, Yt2)
png("result.png")
結果
どうでもいい話
彼女欲しいので女性紹介してください。