この記事はJuliaで機械学習:Flux.jlではなくLux.jlを使ってみるの続編です。個人的には、
f(\vec{x}) = \sum_{i=1}^N f_i(\vec{x})
のように書かれた非線形関数のフィッティングを行いたいので、それをするためのネットワークを作ってみることにします。公式のサイトを参考にしました。
バージョン
- Julia 1.10.5
- Lux v1.1.0
カスタムなネットワーク
今回は、ニューラルネットワークを$f_i^{\rm eff}(\vec{x})$として、それを$N$個定義して、それらの和を取ったようなモデル:
f^{\rm eff}(\vec{x}) = \sum_{i=1}^N f_i^{\rm eff}(\vec{x})
を作ることを考えます。特に、$N$にどんな数が来ても対応できるようにしたいところです。
これを実現するには、
struct BPChain{L<:NamedTuple} <: Lux.AbstractLuxWrapperLayer{:layers}
layers::L
end
という型を定義します。このBPChain
は名前付きタプルをフィールドに持つ型です。そして、Lux.AbstractLuxWrapperLayer{:layers}
は、この型がラッパーレイヤーであることを意味しているらしいです。つまり、すでにLuxで定義してあるレイヤーのラッパーのようなものであるときには、これを指定します。
この型のフィールドは名前付きタプルですから、コンストラクタにも名前付きタプルを入れることになります。名前付きタプルとは、
A = (a = 3,b = 4)
のようなもので、aやbには
A.a
でアクセスできます。今回は名前付きタプルにはLuxのレイヤーを入れればよいので、
model1 = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 1))
model2 = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 1))
model = (Ti = model1,O=model2)
とします。これで名前はTiとOのレイヤーが定義されたことになります(TiとOという名前自体は気にしないでください)。名前はいつもTiとOとは限りませんので、配列から自動的に名前付きタプルを作れるように、
atomkinds = ["Ti", "O"]
keys = Tuple(Symbol.(atomkinds))
nt = NamedTuple{keys}((model1, model2))
として、ntと言う名前付きタプルを定義しました。これで、
model = BPChain(nt)
display(model)
は
BPChain(
Ti = Chain(
layer_1 = Dense(2 => 10, relu), # 30 parameters
layer_2 = Dense(10 => 10, relu), # 110 parameters
layer_3 = Dense(10 => 1), # 11 parameters
),
O = Chain(
layer_1 = Dense(2 => 10, relu), # 30 parameters
layer_2 = Dense(10 => 10, relu), # 110 parameters
layer_3 = Dense(10 => 1), # 11 parameters
),
) # Total: 302 parameters,
# plus 0 states.
という出力になります。このような書き方をすれば、TiとOの他に例えばHとかSとかがあっても、同じように名前付きタプルで定義すれば同じようにモデルを定義できます。
次に、このモデルの動作を定義します。まだ、TiとOのネットワークをどう使うのか定義していませんので。これは、
function (l::BPChain)(xin, ps, st::NamedTuple)
n1, n2 = size(xin)
x = zeros(1, n2)
for name in keys(l.layers)
model_i = getfield(l.layers, name)
ps_i = getfield(ps, name)
st_i = getfield(st, name)
x_i, st_ = Lux.apply(model_i, xin, ps_i, st_i)\
st = merge(st, NamedTuple{(name,)}((st_,)))
x += x_i
end
return x, st
end
としました。名前付きタプルのそれぞれ取り出してループしていまして、それぞれのモデルをmodel_iとして入力xin
から出力x_i
を作り、和をとって最終出力x
を返しています。
あとは、前の記事と同じように学習します。
フィッティング例
対象とする関数
対象とする関数は、
function main()
f_1(x) = x[1] * x[2] + cos(3 * x[1]) + exp(x[2] / 5) * x[1] + tanh(x[2]) * cos(3 * x[2])
f_2(x) = -x[1] * x[2] + sin(2 * x[2]) + exp(x[1] / 3) * x[2] + tanh(x[1]) * cos(7 * x[1])
f(x) = f_1(x) + f_2(x)
num = 30
x = range(-2, 2, length=num)
y = range(-2, 2, length=num)
z = [f([i, j]) for i in x, j in y]'
p = plot(x, y, z, st=:wireframe, zlims=(-3, 8))
savefig("bp_original.png")
end
とします。$N=2$ですね。形は
こんな感じです。
コード
それでは、これをフィッティングしてみましょう。コードは以下の通りです。
using Plots
using MLUtils
using Lux
#using Metal
using Optimisers
using Random
using Zygote
const loss_function = MSELoss()
const dev_cpu = cpu_device()
const dev_gpu = gpu_device()
function make_data(f, num)
#num = 47
#numt = 19
x = range(-2, 2, length=num)
y = range(-2, 2, length=num)
count = 0
z = Float32[]
for i = 1:num
for j = 1:num
count += 1
push!(z, f([x[i], y[j]]))
end
end
return x, y, z
end
function make_inputoutput(x, y, z)
count = 0
numx = length(x)
numy = length(y)
input = zeros(Float64, 2, numx * numy)
output = zeros(Float64, 1, numx * numy)
count = 0
for i = 1:numx
for j = 1:numy
count += 1
input[1, count] = x[i]
input[2, count] = y[j]
output[1, count] = z[count]
end
end
return input, output
end
function training!(tstate, train_loader, test_loader, vjp_rule, epochs)
for epoch in 1:epochs
for (x, y) in train_loader
data = (x, y) |> gpu_device()
_, loss, _, tstate = Training.single_train_step!(vjp_rule, loss_function, data, tstate)
end
if epoch % 100 == 1 || epoch == epochs
loss = 0.0
for (x, y) in test_loader
y = y |> dev_gpu
x = x |> dev_gpu
y_pred = Lux.apply(tstate.model, x, tstate.parameters, tstate.states)[1]
loss += loss_function(y_pred, y)
end
loss = loss / length(test_loader)
println("Epoch: $epoch \t Loss: $loss")
end
end
end
struct BPChain{L<:NamedTuple} <: Lux.AbstractLuxWrapperLayer{:layers}
layers::L
end
function (l::BPChain)(xin, ps, st::NamedTuple)
n1, n2 = size(xin)
x = zeros(1, n2)
for name in keys(l.layers)
#println(name)
model_i = getfield(l.layers, name)
ps_i = getfield(ps, name)
st_i = getfield(st, name)
x_i, st_ = Lux.apply(model_i, xin, ps_i, st_i)
#println(getfield(l.layers, name))
st = merge(st, NamedTuple{(name,)}((st_,)))
x += x_i
end
return x, st
end
function main()
f_1(x) = x[1] * x[2] + cos(3 * x[1]) + exp(x[2] / 5) * x[1] + tanh(x[2]) * cos(3 * x[2])
f_2(x) = -x[1] * x[2] + sin(2 * x[2]) + exp(x[1] / 3) * x[2] + tanh(x[1]) * cos(7 * x[1])
f(x) = f_1(x) + f_2(x)
num = 30
x = range(-2, 2, length=num)
y = range(-2, 2, length=num)
z = [f([i, j]) for i in x, j in y]'
p = plot(x, y, z, st=:wireframe, zlims=(-3, 8))
savefig("bp_original.png")
xdata, ydata, zdata = make_data(f, num)
input_data, output_data = make_inputoutput(xdata, ydata, zdata)
numtrain = 47
xtrain, ytrain, ztrain = make_data(f, numtrain)
numtest = 19
xtest, ytest, ztest = make_data(f, numtest)
input_train, output_train = make_inputoutput(xtrain, ytrain, ztrain)
#display(output_train)
input_test, output_test = make_inputoutput(xtest, ytest, ztest)
batchsize = 128
train_loader = DataLoader((input_train, output_train), batchsize=batchsize, shuffle=true) #MLUtils
test_loader = DataLoader((input_test, output_test), batchsize=1, shuffle=false) #MLUtils
for (x, y) in train_loader
#display(x)
#display(y)
end
atomkinds = ["Ti", "O"]
keys = Tuple(Symbol.(atomkinds))
display(keys)
model1 = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 1))
model2 = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 1))
nt = NamedTuple{keys}((model1, model2))
display(nt)
model = BPChain(nt)
display(model)
rng = MersenneTwister(0)
ps, st = Lux.setup(rng, model)
#y = model([0.2, 0.3], ps, st)[1]
display(ps)
#model = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 10, relu), Dense(10, 1))
#display(model)
opt = Adam() #Optimisers.jl
display(opt)
rng = MersenneTwister() #Random
Random.seed!(rng, 12345)
ps, st = Lux.setup(rng, model) |> dev_gpu
#display(ps)
#display(st)
tstate = Training.TrainState(model, ps, st, opt)
display(tstate)
loss = 0.0
for (x, y) in test_loader
y = y |> dev_gpu
x = x |> dev_gpu
y_pred = Lux.apply(tstate.model, x, tstate.parameters, tstate.states)[1]
#y_pred = tstate.model(x, tstate.parameters, tstate.states)
#display(y)
#display(y_pred)
loss += loss_function(y_pred, y)
end
println(length(test_loader))
println("initial loss $(loss/length(test_loader))")
vjp_rule = AutoZygote()
epochs = 10000
@time training!(tstate, train_loader, test_loader, vjp_rule, epochs)
output_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(input_data), tstate.parameters, tstate.states)[1])
z_pred = reshape(output_pred, num, num)
#display(y_pred)
p = plot(x, y, z_pred, st=:wireframe, zlims=(-3, 8))
savefig("bp_dense.png")
end
main()
実行して得られたグラフは
となります。ちゃんとフィッティングされています。