2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

JuliaAdvent Calendar 2024

Day 10

Lux.jlを使って制限ボルツマンマシンで量子スピン系の基底状態を求める

Last updated at Posted at 2024-12-09

Juliaでの機械学習ライブラリといえばFlux.jlが有名ですが、Lux.jlというものもあります。Lux.jlとFlux.jlの違いについては、Juliaで機械学習:Flux.jlではなくLux.jlを使ってみる
で述べました。

簡単に違いをまとめると、

  • Flux.jl: 機械学習パッケージ。これ一つで何でもできる。モデルを作ると、そのモデルには学習可能なパラメータが存在している(ユーザーからパラメータが見えない)。「モデルを更新」して機械学習を行う。

  • Lux.jl: 機械学習のモデルに関するパッケージ。他のパッケージと組み合わせて学習を行う。モデルを作ると、モデルの骨組みだけが作られ、モデルは学習可能なパラメータを持っていない。モデルにパラメータを流し込むことでモデルが値を出力する。「パラメータを更新」して機械学習を行う。

使い方の違いは、

  • Flux.jl: Flux.jlのデフォルトに入っている機能であれば簡単に使える。メジャーなものは何でも揃っている。パラメータを取り出して更新することも可能。基本的に何でもできる

  • Lux.jl: パラメータを別に扱うことで、更新方法を他のパッケージに丸投げすることができる。例えばNeuralODEであれば、微分方程式を数値的に解くパッケージを途中に入れることでパラメータを更新することができる。他のパッケージと組み合わせやすい形になっている

となっています。

さて、Lux.jlでの簡単な例は前の記事で書きましたので、今度はより複雑なモデルを考えてみます。

制限ボルツマンマシン

制限ボルツマンマシンについては、前の記事Juliaを使って制限ボルツマンマシンで量子スピン系の基底状態を求めるを参照してください。

波動関数$\psi(s_1,\cdots,s_N)$の構成を考えます。ここでは、制限ボルツマンマシン(RBM)を使います。RBMは、

\psi(s_1,\cdots,s_N) = \sum_{h_1=\pm 1,\cdots,h_M=\pm 1} \exp \left[ \sum_{j=1}^N a_j s_j + \sum_{i=1}^M b_i h_i + \sum_{i,j} W_{ij} h_i s_j \right]

という関数です(Carleo et al., Science 355,602–606 (2017))。ここで、$h_1,\cdots,h_M$という$M$個の変数は隠れ層と呼ばれるものです。この模型の良いところは、$h_1,\cdots,h_M$という和を解析的に実行できる点でして、

\psi(s_1,\cdots,s_N) = e^{\sum_{j=1}^N a_j s_j} \prod_{i=1}^M 2 \cosh \left( b_i + \sum_j W_{ij} s_j \right) 

という$h_1,\cdots,h_M$を含まない形に変形できます。この関数のパラメータは、$a_j,b_i,W_{ij}$です。
RBMで気をつけなければならない点としては、この関数は、パラメータが実数である限り、常に正であることです。もし波動関数が負になり得る場合には、パラメータを複素数にする必要があります。

Lux.jlでのモデリング

それでは、Lux.jlでRBMを作っていきましょう。Flux.jlとの違いは、モデルそのものにはデータを持たない、ということです。Lux.jlのモデルは、骨組みだけを持ちます。

まず、RBMという型を

struct RBM{F1,F2,F3,T} <: Lux.AbstractLuxLayer
    numspin::Int
    numhidden::Int
    init_a::F1
    init_b::F2
    init_W::F3
end

と定義します。ポイントは、学習するべきパラメータを含まない、ということです。スピンの数や、初期化の方法だけを保持しています。

次に、コンストラクタを

function RBM(numspin, numhidden, T=Float64,
    init_a=glorot_uniform, init_b=glorot_uniform, init_W=glorot_uniform)
    return RBM{typeof(init_a),typeof(init_b),typeof(init_W),T}(numspin, numhidden, init_a, init_b, init_W)
end

とします。ここで、初期化のデフォルトとしてglorot_uniformという関数を設定していますが、これはLux.jlに入っているイニシャライザーです。

次に、RBMの内部パラメータの初期値を決めます。

function LuxCore.initialparameters(rng::AbstractRNG, l::RBM{F1,F2,F3,T}) where {F1,F2,F3,T}

    return (a=l.init_a(rng, T, l.numspin),
        b=l.init_b(rng, T, l.numhidden),
        W=l.init_W(rng, T, l.numhidden, l.numspin))

end

LuxCore.initialstates(::AbstractRNG, ::RBM) = NamedTuple()

ここで、initialparametersは訓練可能なパラメータの設定、initialstatesは訓練しないパラメータの設定です。今回は訓練可能なパラメータしかありませんので、parametersだけ設定します。

あとは、パラメータの総数を

LuxCore.parameterlength(l::RBM) = l.numspin + l.numhidden + l.numspin * l.numhidden
LuxCore.statelength(::RBM) = 0

とします。

そして、制限ボルツマンマシンの挙動を定義します。

function (m::RBM)(x::AbstractVector, ps, st::NamedTuple)
    factor = exp(sum(x .* ps.a))
    Wx = ps.W * x
    for i = 1:m.numhidden
        factor *= 2 * cosh(ps.b[i] + Wx[i])
    end
    return real(factor)
end

これをテストしてみると、

using Lux
using LuxCore
using Random
using Zygote

function test()
    numspin = 2
    x = rand([-1, 1], numspin)
    rbm = RBM(numspin, 6, ComplexF64)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = LuxCore.setup(rng, rbm)
    display(rbm)
    println(rbm(x, ps, st))
    grad = gradient(ps -> rbm(x, ps, st), ps)[1]
    println(grad)
end
test()

を実行して、

RBM{typeof(glorot_uniform), typeof(glorot_uniform), typeof(glorot_uniform), ComplexF64}(2, 6, WeightInitializers.glorot_uniform, WeightInitializers.glorot_uniform, WeightInitializers.glorot_uniform)  # 20 parameters
-144.4147833615262
(a = ComplexF64[144.4147833615262 + 327.0677862848041im, 144.4147833615262 + 327.0677862848041im], b = ComplexF64[177.1774530611598 + 452.6824194197753im, -23.153466087078577 - 91.69340514608339im, 185.3465585202336 + 354.45093199925736im, 750.5387697492828 + 639.6269505220939im, 351.4306578539614 + 210.04766008630003im, 134.83807978877388 + 273.67545146379126im], W = ComplexF64[-177.1774530611598 - 452.6824194197753im -177.1774530611598 - 452.6824194197753im; 23.153466087078577 + 91.69340514608339im 23.153466087078577 + 91.69340514608339im; -185.3465585202336 - 354.45093199925736im -185.3465585202336 - 354.45093199925736im; -750.5387697492828 - 639.6269505220939im -750.5387697492828 - 639.6269505220939im; -351.4306578539614 - 210.04766008630003im -351.4306578539614 - 210.04766008630003im; -134.83807978877388 - 273.67545146379126im -134.83807978877388 - 273.67545146379126im])

となります。ここで、Flux.jlと違って面白いのは、モデルがパラメータを持っていないために、訓練パラメータを容易に複素数にすることができる、という点です。ComplexF64と指定したため、複素数でパラメータが初期化されています。ここをFloat32にすれば単精度実数になったりします。

エネルギー最小化

次に、量子スピン系のハミルトニアンを定義して、エネルギーを求めて最小化します。ここは前の記事と全く同じなので、解説はそちらをみてください。コードだけ書くと、

using Lux
using LuxCore
using Random
using Zygote
using Optimisers
using LinearAlgebra

const σx = [
    0 1
    1 0
]
const σy = [
    0 -im
    im 0
]
const σz = [
    1 0
    0 -1
]
const σ0 = [
    1 0
    0 1
]

function SziSzjmatrix(i, j, numspin)
    hi = zeros(Float64, 2, 2)
    hj = zeros(Float64, 2, 2)
    for isite = 1:numspin
        if isite == i
            hj .= σz
        elseif isite == j
            hj .= σz
        else
            hj .= σ0
        end
        if isite == 1
            hi .= hj
        else
            hi = kron(hi, hj)
        end
    end
    return hi
end


function Sximatrix(i, numspin)
    hi = zeros(Float64, 2, 2)
    hj = zeros(Float64, 2, 2)
    for isite = 1:numspin
        if isite == i
            hj .= σx
        else
            hj .= σ0
        end
        if isite == 1
            hi .= hj
        else
            hi = kron(hi, hj)
        end
    end
    return hi
end

function make_Hamiltonian(J, hx, numspin)
    H = zeros(Float64, 2^numspin, 2^numspin)
    for i = 1:numspin
        j = i + 1
        j += ifelse(j > numspin, -numspin, 0)
        H += SziSzjmatrix(i, j, numspin) * J
        H += Sximatrix(i, numspin) * hx
    end
    return H
end

function get_S(istate, numspin)
    Sj = zeros(Int8, numspin)
    k = lpad(string(istate - 1, base=2), numspin, "0")
    for ispin = 1:numspin
        Sj[ispin] = ifelse(k[ispin] == '0', -1, 1)
    end
    return Sj
end


function compute_energy(model, ps, st, (x,), H)
    vecψ = [model(x[i], ps, st) for i = 1:length(x)]
    #vecψ, st_ = model.(x, ps, st)
     = H * vecψ
    E = dot(vecψ, ) / dot(vecψ, vecψ)
    return E, st, (; y_pred=E)
end

function training_full!(H, rbm, ps, st, opt, vjp_rule, nt)
    numspin = rbm.numspin
    train_state = Training.TrainState(rbm, ps, st, opt)
    compute_energy_H(model, ps, st, (x,)) = compute_energy(model, ps, st, (x,), H)

    Ss = [get_S(istate, numspin) for istate in 1:2^numspin]

    for it = 1:nt
        (_, E, _, train_state) = Training.single_train_step!(
            vjp_rule, compute_energy_H, (Ss,), train_state)
        println("$it energy: $E")
    end
    return ps, st
end

struct Term{nspin}
    term::Vector{Int8}
    value::Ref{Float64}
end

function Term(nspin)
    term = Array{Int8}(undef, nspin)
    return Term{nspin}(term, 0)
end

function Sxiterm(nspin, i::T, value) where {T<:Integer}
    Sxi = Term(nspin)
    Sxi.value[] = value
    for isite = 1:nspin
        if isite == i
            Sxi.term[isite] = 1
        else
            Sxi.term[isite] = 0
        end
    end
    return Sxi
end

function SziSzjterm(nspin, i, j, value)
    SziSzj = Term(nspin)
    SziSzj.value[] = value
    for isite = 1:nspin
        if isite == i || isite == j
            SziSzj.term[isite] = 3
        else
            SziSzj.term[isite] = 0
        end
    end
    return SziSzj
end

struct Hamiltonian{nspin}
    terms::Vector{Term{nspin}}
end

function Base.length(h::Hamiltonian)
    return length(h.terms)
end

function Hamiltonian(nspin)
    terms = Array{Term{nspin}}(undef, 0)
    return Hamiltonian{nspin}(terms)
end

function Base.push!(h::Hamiltonian, term::Term)
    push!(h.terms, term)
end

function get_nonzero_index(term::Term{nspin}, S, Sj) where {nspin}
    Sj .= S
    value = term.value[]
    for isite = 1:nspin
        kind = term.term[isite]
        if kind == 0
        elseif kind == 1
            Sj[isite] *= -1
        elseif kind == 3
            value *= ifelse(Sj[isite] == 1, 1, -1)
        end
    end
    return value
end

function get_nonzero_indices!(h::Hamiltonian, Ss, values, S)
    numterms = length(h)
    for ikind = 1:numterms
        term = h.terms[ikind]
        values[ikind] = get_nonzero_index(term, S, Ss[ikind])
    end
    return numterms
end

function print_wavefunction(rbm, ps, st)
    numspin = rbm.numspin
    Sj = zeros(Int8, numspin)
    ψjs = []
    Sjs = []
    for i = 0:2^numspin-1
        k = lpad(string(i, base=2), numspin, "0")
        #println(k)
        for ispin = 1:numspin
            Sj[ispin] = ifelse(k[ispin] == '0', -1, 1)
        end
        ψj = rbm(Sj, ps, st)
        push!(ψjs, ψj)
        push!(Sjs, copy(Sj))
        #println("$Sj $ψj  ")
    end
    ψjs /= norm(ψjs)
    for i = 0:2^numspin-1
        Sj = Sjs[i+1]
        ψj = ψjs[i+1]
        println("$Sj $ψj  ")
    end
end

となります。新しく書いた関数は

function compute_energy(model, ps, st, (x,), H)
    vecψ = [model(x[i], ps, st) for i = 1:length(x)]
    #vecψ, st_ = model.(x, ps, st)
     = H * vecψ
    E = dot(vecψ, ) / dot(vecψ, vecψ)
    return E, st, (; y_pred=E)
end

です。モデルにはパラメータを流し込む必要があるため、引数がxとpsとstとなっています。

訓練ですが、

function training_full!(H, rbm, ps, st, opt, vjp_rule, nt)
    numspin = rbm.numspin
    train_state = Training.TrainState(rbm, ps, st, opt)
    compute_energy_H(model, ps, st, (x,)) = compute_energy(model, ps, st, (x,), H)

    Ss = [get_S(istate, numspin) for istate in 1:2^numspin]

    for it = 1:nt
        (_, E, _, train_state) = Training.single_train_step!(
            vjp_rule, compute_energy_H, (Ss,), train_state)
        println("$it energy: $E")
    end
    return ps, st
end

としました。Lux.jlの前の記事にありますように、rbmとpsとstと、最適化方法optと微分の方法vjp_ruleを指定することで、簡単に最小化できます。

出力すると、

3000 energy: -2.2360679774997894
Int8[-1, -1] 0.16245984811645353  
Int8[-1, 1] -0.6881909602355878  
Int8[1, -1] -0.6881909602355857  
Int8[1, 1] 0.16245984811645311  
energy [-2.236067977499788, -2.0000000000000004, 2.0000000000000004, 2.23606797749979]

のようになります。ちゃんと、最小エネルギーが出ていることがわかります。
なお、パラメータを実数に限ると、最小値にはならなかったりします。これは波動関数が正に限られてしまうからです。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?