
Splitting intermediate outputs in Flux.jl

Posted at 2023-05-17

There was a question asking how to do this (split an intermediate output inside a Flux.jl model), and it caught my interest, so I gave it a try.

Environment

  • Julia 1.8.5

Code

The key point is splitting the output, and I did it by writing a custom layer. The custom layer multiplies by a 5x6 matrix A to extract the upper 5 components, and by a 1x6 matrix B to extract the lower 1 component.
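As a quick illustration of the idea, here is a standalone sketch, independent of the layer defined below: Boolean selection matrices built this way pick out row blocks of a vector.

y = collect(1.0:6.0)                  # a 6-component vector
A = [i == j for i in 1:5, j in 1:6]   # 5x6 Bool matrix: picks the first 5 components
B = [j == 6 for i in 1:1, j in 1:6]   # 1x6 Bool matrix: picks the last component
A * y   # [1.0, 2.0, 3.0, 4.0, 5.0]
B * y   # [6.0]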

using Flux

# A custom layer that splits an n-vector into its first m components
# and its remaining n-m components, via fixed Boolean selection matrices.
struct Splitlayer
    m::Int64
    n::Int64
    A
    B
end
Flux.trainable(a::Splitlayer) = ()   # A and B are fixed; nothing to train
function Splitlayer(n, m)
    A = zeros(Bool, m, n)       # m x n: selects the top m components
    B = zeros(Bool, n - m, n)   # (n-m) x n: selects the bottom n-m components
    for i = 1:m
        A[i, i] = 1
    end
    for i = 1:n-m
        B[i, i+m] = 1
    end
    return Splitlayer(m, n, A, B)
end
function (s::Splitlayer)(x)
    return (s.A * x, s.B * x)   # returns a tuple (top, bottom)
end
Flux.@functor Splitlayer

function test()
    x = rand(Float32, 5)
    modeltest = Chain(Dense(5, 8, relu), Dense(8, 6))
    y1 = modeltest(x)
    println(y1)
    println(Splitlayer(6, 5)(y1))

    mini1 = Chain(Dense(5, 5), Dense(5, 5))
    mini2 = Chain(x -> tanh.(x))
    model = Chain(Dense(5, 8, relu), Dense(8, 6), Splitlayer(6, 5),
        ((x1, x2)::Tuple) -> (mini1(x1), mini2(x2)),   # apply a sub-network to each half
        ((x1, x2)::Tuple) -> cat(x1, x2, dims=1),      # merge the halves back together
        x -> sum(x))
    y = model(x)
    println(gradient(model, x))
end

test()

The output is:

Float32[-0.73190856, 0.55760396, -0.2940486, -0.10590199, -0.122594506, 0.5887993]
(Float32[-0.73190856, 0.55760396, -0.2940486, -0.10590199, -0.122594506], Float32[0.5887993])
(Float32[0.39188048, 0.36202207, -0.066422, 0.034748882, -0.11328415],)

It seems the gradient is being computed correctly.
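As a sanity check (assuming the definitions above have been run), the layer agrees with plain indexing:

y = rand(Float32, 6)
s = Splitlayer(6, 5)
s(y) == (y[1:5], y[6:6])   # true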

Training

Here is an example for when you want to train the model. The parameters are flattened with Flux.destructure and updated using Optimisers (a short sketch of what destructure does follows the listing).

using Flux
using Optimisers

# Same custom layer as above, repeated so this listing is self-contained.
struct Splitlayer
    m::Int64
    n::Int64
    A
    B
end
Flux.trainable(a::Splitlayer) = ()   # A and B are fixed; nothing to train
function Splitlayer(n, m)
    A = zeros(Bool, m, n)       # m x n: selects the top m components
    B = zeros(Bool, n - m, n)   # (n-m) x n: selects the bottom n-m components
    for i = 1:m
        A[i, i] = 1
    end
    for i = 1:n-m
        B[i, i+m] = 1
    end
    return Splitlayer(m, n, A, B)
end
function (s::Splitlayer)(x)
    return (s.A * x, s.B * x)   # returns a tuple (top, bottom)
end
Flux.@functor Splitlayer

function test()
    x = rand(Float32, 5)
    modeltest = Chain(Dense(5, 8, relu), Dense(8, 6))
    y1 = modeltest(x)
    println(y1)
    println(Splitlayer(6, 5)(y1))

    mini1 = Chain(Dense(5, 5), Dense(5, 5))
    mini2 = Chain(x -> tanh.(x))
    model = Chain(Dense(5, 8, relu), Dense(8, 6), Splitlayer(6, 5),
        ((x1, x2)::Tuple) -> (mini1(x1), mini2(x2)),
        ((x1, x2)::Tuple) -> cat(x1, x2, dims=1),
        x -> sum(x))
    y = model(x)

    θ, re = Flux.destructure(model)   # θ: flat parameter vector, re: rebuilds the model
    function loss(θ, x, y)
        m = re(θ)                     # rebuild the model from the current parameters
        return (m(x) - y)^2
    end
    state = Optimisers.setup(Optimisers.Adam(), θ)
    println(gradient(model, x))

    y0 = 0.3
    for epoch in 1:1000
        grads = gradient(p -> loss(p, x, y0), θ)[1]
        state, θ = Optimisers.update(state, θ, grads)
        println(re(θ)(x))             # evaluate with the updated parameters
    end
end
test()
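For reference, Flux.destructure is what makes this loop possible: it flattens every trainable array in the model into one vector and returns a function that rebuilds the model from such a vector. A minimal sketch of the round trip (assuming Flux is loaded):

m = Dense(2, 3)
θ, re = Flux.destructure(m)   # θ is a flat Vector of all trainable parameters
length(θ)                     # 9 = 2*3 weights + 3 biases
m2 = re(θ .* 0)               # a Dense(2, 3) with all parameters set to zero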
