と言う質問があったので、気になったのでやってみました。
環境
- Julia 1.8.5
コード
ポイントは、出力を分けることですが、カスタムレイヤーを作って分けてみました。このカスタムレイヤーは、5x6行列のAをかけることで、上半分の5成分を、1x6行列をかけることで下半分の1成分を取り出しています。
using Flux
struct Splitlayer
m::Int64
n::Int64
A
B
end
Flux.trainable(a::Splitlayer) = ()
function Splitlayer(n,m)
A = zeros(Bool,m,n)
B = zeros(Bool,n-m,n)
for i = 1:m
A[i,i] = 1;
end
for i = 1:n-m
B[i,i+m] = 1;
end
return Splitlayer(m,n,A,B)
end
function (s::Splitlayer)(x)
return (s.A*x,s.B*x)
end
Flux.@functor Splitlayer
function test()
x = rand(Float32,5)
modeltest = Chain(Dense(5,8,relu),Dense(8,6))
y1= modeltest(x)
println(y1)
println(Splitlayer(6,5)(y1))
mini1 = Chain(Dense(5,5),Dense(5,5))
mini2 = Chain(x -> tanh.(x))
model = Chain(Dense(5,8,relu),Dense(8,6),Splitlayer(6,5),((x1,x2)::Tuple) -> (mini1(x1),mini2(x2)),((x1,x2)::Tuple) -> cat(x1,x2,dims=1),x -> sum(x))
y = model(x)
println(gradient(model,x))
end
test()
出力は
Float32[-0.73190856, 0.55760396, -0.2940486, -0.10590199, -0.122594506, 0.5887993]
(Float32[-0.73190856, 0.55760396, -0.2940486, -0.10590199, -0.122594506], Float32[0.5887993])
(Float32[0.39188048, 0.36202207, -0.066422, 0.034748882, -0.11328415],)
となります。ちゃんと微分できているようです。
トレーニング
訓練をしたい場合の例はこちらです。Optimisersを使って勾配を更新してみています。
using Flux
using Optimisers
struct Splitlayer
m::Int64
n::Int64
A
B
end
Flux.trainable(a::Splitlayer) = ()
function Splitlayer(n,m)
A = zeros(Bool,m,n)
B = zeros(Bool,n-m,n)
for i = 1:m
A[i,i] = 1;
end
for i = 1:n-m
B[i,i+m] = 1;
end
return Splitlayer(m,n,A,B)
end
function (s::Splitlayer)(x)
return (s.A*x,s.B*x)
end
Flux.@functor Splitlayer
function test()
x = rand(Float32,5)
modeltest = Chain(Dense(5,8,relu),Dense(8,6))
y1= modeltest(x)
println(y1)
println(Splitlayer(6,5)(y1))
mini1 = Chain(Dense(5,5),Dense(5,5))
mini2 = Chain(x -> tanh.(x))
model = Chain(Dense(5,8,relu),Dense(8,6),Splitlayer(6,5),((x1,x2)::Tuple) -> (mini1(x1),mini2(x2)),((x1,x2)::Tuple) -> cat(x1,x2,dims=1),x -> sum(x))
y = model(x)
θ, re = Flux.destructure(model)
function loss(θ,x,y)
model = re(θ)
k = model(x)
return (k - y)^2
end
state = Optimisers.setup(Optimisers.Adam(), θ)
println(gradient(model,x))
y0 = 0.3
for epoch in 1:1000
grads = gradient(p -> loss(p,x,y0), θ)[1]
state, θ = Optimisers.update(state, θ, grads)
println(model(x))
end
end