1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Juliaで機械学習:Lux.jlで少し複雑なネットワークを作ってみる

Posted at

この記事はJuliaで機械学習:Flux.jlではなくLux.jlを使ってみるの続編です。個人的には、

f(\vec{x}) = \sum_{i=1}^N f_i(\vec{x})

のように書かれた非線形関数のフィッティングを行いたいので、それをするためのネットワークを作ってみることにします。公式のサイトを参考にしました。

バージョン

  • Julia 1.10.5
  • Lux v1.1.0

カスタムなネットワーク

今回は、ニューラルネットワークを$f_i^{\rm eff}(\vec{x})$として、それを$N$個定義して、それらの和を取ったようなモデル:

f^{\rm eff}(\vec{x}) = \sum_{i=1}^N f_i^{\rm eff}(\vec{x})

を作ることを考えます。特に、$N$にどんな数が来ても対応できるようにしたいところです。

これを実現するには、

struct BPChain{L<:NamedTuple} <: Lux.AbstractLuxWrapperLayer{:layers}
    layers::L
end

という型を定義します。このBPChainは名前付きタプルをフィールドに持つ型です。そして、Lux.AbstractLuxWrapperLayer{:layers}は、この型がラッパーレイヤーであることを意味しているらしいです。つまり、すでにLuxで定義してあるレイヤーのラッパーのようなものであるときには、これを指定します。
この型のフィールドは名前付きタプルですから、コンストラクタにも名前付きタプルを入れることになります。名前付きタプルとは、

A = (a = 3,b = 4)

のようなもので、aやbには

A.a

でアクセスできます。今回は名前付きタプルにはLuxのレイヤーを入れればよいので、

    model1 = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 1))
    model2 = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 1))
    model = (Ti = model1,O=model2)

とします。これで名前はTiとOのレイヤーが定義されたことになります(TiとOという名前自体は気にしないでください)。名前はいつもTiとOとは限りませんので、配列から自動的に名前付きタプルを作れるように、

    atomkinds = ["Ti", "O"]
    keys = Tuple(Symbol.(atomkinds))
    nt = NamedTuple{keys}((model1, model2))

として、ntと言う名前付きタプルを定義しました。これで、

    model = BPChain(nt)
    display(model)

BPChain(
    Ti = Chain(
        layer_1 = Dense(2 => 10, relu),  # 30 parameters
        layer_2 = Dense(10 => 10, relu),  # 110 parameters
        layer_3 = Dense(10 => 1),       # 11 parameters
    ),
    O = Chain(
        layer_1 = Dense(2 => 10, relu),  # 30 parameters
        layer_2 = Dense(10 => 10, relu),  # 110 parameters
        layer_3 = Dense(10 => 1),       # 11 parameters
    ),
)         # Total: 302 parameters,
          #        plus 0 states.

という出力になります。このような書き方をすれば、TiとOの他に例えばHとかSとかがあっても、同じように名前付きタプルで定義すれば同じようにモデルを定義できます。

次に、このモデルの動作を定義します。まだ、TiとOのネットワークをどう使うのか定義していませんので。これは、

function (l::BPChain)(xin, ps, st::NamedTuple)
    n1, n2 = size(xin)
    x = zeros(1, n2)
    for name in keys(l.layers)
        model_i = getfield(l.layers, name)
        ps_i = getfield(ps, name)
        st_i = getfield(st, name)
        x_i, st_ = Lux.apply(model_i, xin, ps_i, st_i)\
        st = merge(st, NamedTuple{(name,)}((st_,)))
        x += x_i
    end
    return x, st
end

としました。名前付きタプルのそれぞれ取り出してループしていまして、それぞれのモデルをmodel_iとして入力xinから出力x_iを作り、和をとって最終出力xを返しています。
あとは、前の記事と同じように学習します。

フィッティング例

対象とする関数

対象とする関数は、

function main()
    f_1(x) = x[1] * x[2] + cos(3 * x[1]) + exp(x[2] / 5) * x[1] + tanh(x[2]) * cos(3 * x[2])
    f_2(x) = -x[1] * x[2] + sin(2 * x[2]) + exp(x[1] / 3) * x[2] + tanh(x[1]) * cos(7 * x[1])
    f(x) = f_1(x) + f_2(x)

    num = 30
    x = range(-2, 2, length=num)
    y = range(-2, 2, length=num)
    z = [f([i, j]) for i in x, j in y]'

    p = plot(x, y, z, st=:wireframe, zlims=(-3, 8))
    savefig("bp_original.png")
end

とします。$N=2$ですね。形は

bp_original.png

こんな感じです。

コード

それでは、これをフィッティングしてみましょう。コードは以下の通りです。

using Plots
using MLUtils
using Lux
#using Metal
using Optimisers
using Random
using Zygote

const loss_function = MSELoss()

const dev_cpu = cpu_device()
const dev_gpu = gpu_device()

function make_data(f, num)
    #num = 47
    #numt = 19
    x = range(-2, 2, length=num)
    y = range(-2, 2, length=num)

    count = 0
    z = Float32[]
    for i = 1:num
        for j = 1:num
            count += 1
            push!(z, f([x[i], y[j]]))
        end
    end

    return x, y, z
end



function make_inputoutput(x, y, z)
    count = 0
    numx = length(x)
    numy = length(y)
    input = zeros(Float64, 2, numx * numy)
    output = zeros(Float64, 1, numx * numy)
    count = 0
    for i = 1:numx
        for j = 1:numy
            count += 1
            input[1, count] = x[i]
            input[2, count] = y[j]
            output[1, count] = z[count]
        end
    end
    return input, output
end

function training!(tstate, train_loader, test_loader, vjp_rule, epochs)

    for epoch in 1:epochs
        for (x, y) in train_loader
            data = (x, y) |> gpu_device()
            _, loss, _, tstate = Training.single_train_step!(vjp_rule, loss_function, data, tstate)
        end
        if epoch % 100 == 1 || epoch == epochs
            loss = 0.0
            for (x, y) in test_loader
                y = y |> dev_gpu
                x = x |> dev_gpu
                y_pred = Lux.apply(tstate.model, x, tstate.parameters, tstate.states)[1]
                loss += loss_function(y_pred, y)
            end
            loss = loss / length(test_loader)

            println("Epoch: $epoch \t Loss: $loss")
        end
    end

end


struct BPChain{L<:NamedTuple} <: Lux.AbstractLuxWrapperLayer{:layers}
    layers::L
end

function (l::BPChain)(xin, ps, st::NamedTuple)
    n1, n2 = size(xin)
    x = zeros(1, n2)
    for name in keys(l.layers)
        #println(name)
        model_i = getfield(l.layers, name)
        ps_i = getfield(ps, name)
        st_i = getfield(st, name)
        x_i, st_ = Lux.apply(model_i, xin, ps_i, st_i)
        #println(getfield(l.layers, name))
        st = merge(st, NamedTuple{(name,)}((st_,)))
        x += x_i

    end
    return x, st
end


function main()
    f_1(x) = x[1] * x[2] + cos(3 * x[1]) + exp(x[2] / 5) * x[1] + tanh(x[2]) * cos(3 * x[2])
    f_2(x) = -x[1] * x[2] + sin(2 * x[2]) + exp(x[1] / 3) * x[2] + tanh(x[1]) * cos(7 * x[1])
    f(x) = f_1(x) + f_2(x)

    num = 30
    x = range(-2, 2, length=num)
    y = range(-2, 2, length=num)
    z = [f([i, j]) for i in x, j in y]'

    p = plot(x, y, z, st=:wireframe, zlims=(-3, 8))
    savefig("bp_original.png")


    xdata, ydata, zdata = make_data(f, num)
    input_data, output_data = make_inputoutput(xdata, ydata, zdata)

    numtrain = 47
    xtrain, ytrain, ztrain = make_data(f, numtrain)
    numtest = 19
    xtest, ytest, ztest = make_data(f, numtest)

    input_train, output_train = make_inputoutput(xtrain, ytrain, ztrain)
    #display(output_train)
    input_test, output_test = make_inputoutput(xtest, ytest, ztest)
    batchsize = 128

    train_loader = DataLoader((input_train, output_train), batchsize=batchsize, shuffle=true) #MLUtils
    test_loader = DataLoader((input_test, output_test), batchsize=1, shuffle=false) #MLUtils
    for (x, y) in train_loader
        #display(x)
        #display(y)
    end

    atomkinds = ["Ti", "O"]
    keys = Tuple(Symbol.(atomkinds))
    display(keys)
    model1 = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 1))
    model2 = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 1))
    nt = NamedTuple{keys}((model1, model2))
    display(nt)

    model = BPChain(nt)
    display(model)
    rng = MersenneTwister(0)
    ps, st = Lux.setup(rng, model)
    #y = model([0.2, 0.3], ps, st)[1]
    display(ps)



    #model = Chain(Dense(2, 10, relu), Dense(10, 10, relu), Dense(10, 10, relu), Dense(10, 1))
    #display(model)

    opt = Adam() #Optimisers.jl
    display(opt)

    rng = MersenneTwister() #Random
    Random.seed!(rng, 12345)

    ps, st = Lux.setup(rng, model) |> dev_gpu
    #display(ps)
    #display(st)

    tstate = Training.TrainState(model, ps, st, opt)
    display(tstate)


    loss = 0.0
    for (x, y) in test_loader
        y = y |> dev_gpu
        x = x |> dev_gpu
        y_pred = Lux.apply(tstate.model, x, tstate.parameters, tstate.states)[1]
        #y_pred = tstate.model(x, tstate.parameters, tstate.states)
        #display(y)
        #display(y_pred)
        loss += loss_function(y_pred, y)
    end
    println(length(test_loader))
    println("initial loss $(loss/length(test_loader))")

    vjp_rule = AutoZygote()

    epochs = 10000
    @time training!(tstate, train_loader, test_loader, vjp_rule, epochs)


    output_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(input_data), tstate.parameters, tstate.states)[1])
    z_pred = reshape(output_pred, num, num)
    #display(y_pred)

    p = plot(x, y, z_pred, st=:wireframe, zlims=(-3, 8))
    savefig("bp_dense.png")


end
main()

実行して得られたグラフは

bp_dense.png

となります。ちゃんとフィッティングされています。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?