1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Juliaでニューラルネットワーク: Flux.jlのMinimal reproducible example

Last updated at Posted at 2023-11-08

Flux.jlはv0.15以降、これまでの標準であったImplicit-styleを用いた記述がサポートされなくなるようです(公式ドキュメント)。新しい記述スタイルに対応したサンプルコードはいまだに多くないため、とりあえず動く、汎用性の高いminimum exampleを参考として載せておきます。

モデル設定

よくあるパーセプトロンのteacher-student model. $\boldsymbol{x} \in \mathbb{R}^d$に対して、

y = W_{teacher} \boldsymbol{x}

によって出力$y\in \mathbb{R}$が定まる。生徒パーセプトロンは教師データを用いて$W_{teacher}$を推論する。

コード

#! teacher-student modelを解く1層パーセプトロン
using Flux, Distributions, LinearAlgebra, Plots, ProgressMeter

mu = 0.0f0
sigma = 1.0f0

d = 100 # dim
N = 1000 # train_num
M = 1000 # test_num

# teacher
W_teacher = rand(Normal(mu, sigma), d) 

# train
x_train = rand(Normal(mu, sigma), d, N) 
y_train = Float32.(W_teacher' * x_train)
train_data = [(x_train[:, i], y_train[i]) for i in 1:N] 
train_loader = Flux.DataLoader((x_train, y_train) |> gpu, batchsize=5, shuffle=true)

# test
x_test = rand(Normal(mu, sigma), d, M) 
y_test = Float32.(W_teacher' * x_test)
test_data = [(x_test[:, i], y_test[i]) for i in 1:M] 


# model and loss
model = Chain(Dense(d, 1, bias=false))  |> gpu
loss(y_hat, y) =  sum((y_hat .- y).^2)   # 全バッチでのlossのsumをとってgradを出す
opt = Flux.setup(Adam(), model) 

# training
n_epoch = 1000
train_loss = zeros(n_epoch)
test_loss = zeros(n_epoch)

function loss_total(x_batch::AbstractMatrix, y_batch::AbstractArray, model, loss)
    y_preds = model(x_batch) 
    return sum(loss.(y_preds, y_batch))
end 

@time @showprogress for i in 1:n_epoch
    for data in train_loader
        input, output = data

        grads = Flux.gradient(model) do m
            y_hat = m(input)
            
            return loss(y_hat, output) 
        end
        Flux.update!(opt, model, grads[1])
    end

    train_loss[i] = loss_total(x_train, y_train, model |> cpu, loss)
    test_loss[i] = loss_total(x_test, y_test, model |> cpu, loss)
end

# 結果の表示
plt = plot(train_loss, label="train")
plot!(plt, test_loss, label="test", title="Loss")
display(plt)

Remark

  • データは自分でFloat32に変更する必要があります。
  • モデルが小規模なためかGPUを使うと逆に遅くなります。
1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?