Juliaでの機械学習ライブラリといえばFlux.jlが有名ですが、Lux.jlというものもあります。Lux.jlとFlux.jlの違いについては、Juliaで機械学習:Flux.jlではなくLux.jlを使ってみる
で述べました。
簡単に違いをまとめると、
-
Flux.jl: 機械学習パッケージ。これ一つで何でもできる。モデルを作ると、そのモデルには学習可能なパラメータが存在している(ユーザーからパラメータが見えない)。「モデルを更新」して機械学習を行う。
-
Lux.jl: 機械学習のモデルに関するパッケージ。他のパッケージと組み合わせて学習を行う。モデルを作ると、モデルの骨組みだけが作られ、モデルは学習可能なパラメータを持っていない。モデルにパラメータを流し込むことでモデルが値を出力する。「パラメータを更新」して機械学習を行う。
使い方の違いは、
-
Flux.jl: Flux.jlのデフォルトに入っている機能であれば簡単に使える。メジャーなものは何でも揃っている。パラメータを取り出して更新することも可能。基本的に何でもできる
-
Lux.jl: パラメータを別に扱うことで、更新方法を他のパッケージに丸投げすることができる。例えばNeuralODEであれば、微分方程式を数値的に解くパッケージを途中に入れることでパラメータを更新することができる。他のパッケージと組み合わせやすい形になっている
となっています。
さて、Lux.jlでの簡単な例は前の記事で書きましたので、今度はより複雑なモデルを考えてみます。
制限ボルツマンマシン
制限ボルツマンマシンについては、前の記事Juliaを使って制限ボルツマンマシンで量子スピン系の基底状態を求めるを参照してください。
波動関数$\psi(s_1,\cdots,s_N)$の構成を考えます。ここでは、制限ボルツマンマシン(RBM)を使います。RBMは、
\psi(s_1,\cdots,s_N) = \sum_{h_1=\pm 1,\cdots,h_M=\pm 1} \exp \left[ \sum_{j=1}^N a_j s_j + \sum_{i=1}^M b_i h_i + \sum_{i,j} W_{ij} h_i s_j \right]
という関数です(Carleo et al., Science 355,602–606 (2017))。ここで、$h_1,\cdots,h_M$という$M$個の変数は隠れ層と呼ばれるものです。この模型の良いところは、$h_1,\cdots,h_M$という和を解析的に実行できる点でして、
\psi(s_1,\cdots,s_N) = e^{\sum_{j=1}^N a_j s_j} \prod_{i=1}^M 2 \cosh \left( b_i + \sum_j W_{ij} s_j \right)
という$h_1,\cdots,h_M$を含まない形に変形できます。この関数のパラメータは、$a_j,b_i,W_{ij}$です。
RBMで気をつけなければならない点としては、この関数は、パラメータが実数である限り、常に正であることです。もし波動関数が負になり得る場合には、パラメータを複素数にする必要があります。
Lux.jlでのモデリング
それでは、Lux.jlでRBMを作っていきましょう。Flux.jlとの違いは、モデルそのものにはデータを持たない、ということです。Lux.jlのモデルは、骨組みだけを持ちます。
まず、RBMという型を
struct RBM{F1,F2,F3,T} <: Lux.AbstractLuxLayer
numspin::Int
numhidden::Int
init_a::F1
init_b::F2
init_W::F3
end
と定義します。ポイントは、学習するべきパラメータを含まない、ということです。スピンの数や、初期化の方法だけを保持しています。
次に、コンストラクタを
function RBM(numspin, numhidden, T=Float64,
init_a=glorot_uniform, init_b=glorot_uniform, init_W=glorot_uniform)
return RBM{typeof(init_a),typeof(init_b),typeof(init_W),T}(numspin, numhidden, init_a, init_b, init_W)
end
とします。ここで、初期化のデフォルトとしてglorot_uniform
という関数を設定していますが、これはLux.jlに入っているイニシャライザーです。
次に、RBMの内部パラメータの初期値を決めます。
function LuxCore.initialparameters(rng::AbstractRNG, l::RBM{F1,F2,F3,T}) where {F1,F2,F3,T}
return (a=l.init_a(rng, T, l.numspin),
b=l.init_b(rng, T, l.numhidden),
W=l.init_W(rng, T, l.numhidden, l.numspin))
end
LuxCore.initialstates(::AbstractRNG, ::RBM) = NamedTuple()
ここで、initialparameters
は訓練可能なパラメータの設定、initialstates
は訓練しないパラメータの設定です。今回は訓練可能なパラメータしかありませんので、parametersだけ設定します。
あとは、パラメータの総数を
LuxCore.parameterlength(l::RBM) = l.numspin + l.numhidden + l.numspin * l.numhidden
LuxCore.statelength(::RBM) = 0
とします。
そして、制限ボルツマンマシンの挙動を定義します。
function (m::RBM)(x::AbstractVector, ps, st::NamedTuple)
factor = exp(sum(x .* ps.a))
Wx = ps.W * x
for i = 1:m.numhidden
factor *= 2 * cosh(ps.b[i] + Wx[i])
end
return real(factor)
end
これをテストしてみると、
using Lux
using LuxCore
using Random
using Zygote
function test()
numspin = 2
x = rand([-1, 1], numspin)
rbm = RBM(numspin, 6, ComplexF64)
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = LuxCore.setup(rng, rbm)
display(rbm)
println(rbm(x, ps, st))
grad = gradient(ps -> rbm(x, ps, st), ps)[1]
println(grad)
end
test()
を実行して、
RBM{typeof(glorot_uniform), typeof(glorot_uniform), typeof(glorot_uniform), ComplexF64}(2, 6, WeightInitializers.glorot_uniform, WeightInitializers.glorot_uniform, WeightInitializers.glorot_uniform) # 20 parameters
-144.4147833615262
(a = ComplexF64[144.4147833615262 + 327.0677862848041im, 144.4147833615262 + 327.0677862848041im], b = ComplexF64[177.1774530611598 + 452.6824194197753im, -23.153466087078577 - 91.69340514608339im, 185.3465585202336 + 354.45093199925736im, 750.5387697492828 + 639.6269505220939im, 351.4306578539614 + 210.04766008630003im, 134.83807978877388 + 273.67545146379126im], W = ComplexF64[-177.1774530611598 - 452.6824194197753im -177.1774530611598 - 452.6824194197753im; 23.153466087078577 + 91.69340514608339im 23.153466087078577 + 91.69340514608339im; -185.3465585202336 - 354.45093199925736im -185.3465585202336 - 354.45093199925736im; -750.5387697492828 - 639.6269505220939im -750.5387697492828 - 639.6269505220939im; -351.4306578539614 - 210.04766008630003im -351.4306578539614 - 210.04766008630003im; -134.83807978877388 - 273.67545146379126im -134.83807978877388 - 273.67545146379126im])
となります。ここで、Flux.jlと違って面白いのは、モデルがパラメータを持っていないために、訓練パラメータを容易に複素数にすることができる、という点です。ComplexF64
と指定したため、複素数でパラメータが初期化されています。ここをFloat32
にすれば単精度実数になったりします。
エネルギー最小化
次に、量子スピン系のハミルトニアンを定義して、エネルギーを求めて最小化します。ここは前の記事と全く同じなので、解説はそちらをみてください。コードだけ書くと、
using Lux
using LuxCore
using Random
using Zygote
using Optimisers
using LinearAlgebra
const σx = [
0 1
1 0
]
const σy = [
0 -im
im 0
]
const σz = [
1 0
0 -1
]
const σ0 = [
1 0
0 1
]
function SziSzjmatrix(i, j, numspin)
hi = zeros(Float64, 2, 2)
hj = zeros(Float64, 2, 2)
for isite = 1:numspin
if isite == i
hj .= σz
elseif isite == j
hj .= σz
else
hj .= σ0
end
if isite == 1
hi .= hj
else
hi = kron(hi, hj)
end
end
return hi
end
function Sximatrix(i, numspin)
hi = zeros(Float64, 2, 2)
hj = zeros(Float64, 2, 2)
for isite = 1:numspin
if isite == i
hj .= σx
else
hj .= σ0
end
if isite == 1
hi .= hj
else
hi = kron(hi, hj)
end
end
return hi
end
function make_Hamiltonian(J, hx, numspin)
H = zeros(Float64, 2^numspin, 2^numspin)
for i = 1:numspin
j = i + 1
j += ifelse(j > numspin, -numspin, 0)
H += SziSzjmatrix(i, j, numspin) * J
H += Sximatrix(i, numspin) * hx
end
return H
end
function get_S(istate, numspin)
Sj = zeros(Int8, numspin)
k = lpad(string(istate - 1, base=2), numspin, "0")
for ispin = 1:numspin
Sj[ispin] = ifelse(k[ispin] == '0', -1, 1)
end
return Sj
end
function compute_energy(model, ps, st, (x,), H)
vecψ = [model(x[i], ps, st) for i = 1:length(x)]
#vecψ, st_ = model.(x, ps, st)
Hψ = H * vecψ
E = dot(vecψ, Hψ) / dot(vecψ, vecψ)
return E, st, (; y_pred=E)
end
function training_full!(H, rbm, ps, st, opt, vjp_rule, nt)
numspin = rbm.numspin
train_state = Training.TrainState(rbm, ps, st, opt)
compute_energy_H(model, ps, st, (x,)) = compute_energy(model, ps, st, (x,), H)
Ss = [get_S(istate, numspin) for istate in 1:2^numspin]
for it = 1:nt
(_, E, _, train_state) = Training.single_train_step!(
vjp_rule, compute_energy_H, (Ss,), train_state)
println("$it energy: $E")
end
return ps, st
end
struct Term{nspin}
term::Vector{Int8}
value::Ref{Float64}
end
function Term(nspin)
term = Array{Int8}(undef, nspin)
return Term{nspin}(term, 0)
end
function Sxiterm(nspin, i::T, value) where {T<:Integer}
Sxi = Term(nspin)
Sxi.value[] = value
for isite = 1:nspin
if isite == i
Sxi.term[isite] = 1
else
Sxi.term[isite] = 0
end
end
return Sxi
end
function SziSzjterm(nspin, i, j, value)
SziSzj = Term(nspin)
SziSzj.value[] = value
for isite = 1:nspin
if isite == i || isite == j
SziSzj.term[isite] = 3
else
SziSzj.term[isite] = 0
end
end
return SziSzj
end
struct Hamiltonian{nspin}
terms::Vector{Term{nspin}}
end
function Base.length(h::Hamiltonian)
return length(h.terms)
end
function Hamiltonian(nspin)
terms = Array{Term{nspin}}(undef, 0)
return Hamiltonian{nspin}(terms)
end
function Base.push!(h::Hamiltonian, term::Term)
push!(h.terms, term)
end
function get_nonzero_index(term::Term{nspin}, S, Sj) where {nspin}
Sj .= S
value = term.value[]
for isite = 1:nspin
kind = term.term[isite]
if kind == 0
elseif kind == 1
Sj[isite] *= -1
elseif kind == 3
value *= ifelse(Sj[isite] == 1, 1, -1)
end
end
return value
end
function get_nonzero_indices!(h::Hamiltonian, Ss, values, S)
numterms = length(h)
for ikind = 1:numterms
term = h.terms[ikind]
values[ikind] = get_nonzero_index(term, S, Ss[ikind])
end
return numterms
end
function print_wavefunction(rbm, ps, st)
numspin = rbm.numspin
Sj = zeros(Int8, numspin)
ψjs = []
Sjs = []
for i = 0:2^numspin-1
k = lpad(string(i, base=2), numspin, "0")
#println(k)
for ispin = 1:numspin
Sj[ispin] = ifelse(k[ispin] == '0', -1, 1)
end
ψj = rbm(Sj, ps, st)
push!(ψjs, ψj)
push!(Sjs, copy(Sj))
#println("$Sj $ψj ")
end
ψjs /= norm(ψjs)
for i = 0:2^numspin-1
Sj = Sjs[i+1]
ψj = ψjs[i+1]
println("$Sj $ψj ")
end
end
となります。新しく書いた関数は
function compute_energy(model, ps, st, (x,), H)
vecψ = [model(x[i], ps, st) for i = 1:length(x)]
#vecψ, st_ = model.(x, ps, st)
Hψ = H * vecψ
E = dot(vecψ, Hψ) / dot(vecψ, vecψ)
return E, st, (; y_pred=E)
end
です。モデルにはパラメータを流し込む必要があるため、引数がxとpsとstとなっています。
訓練ですが、
function training_full!(H, rbm, ps, st, opt, vjp_rule, nt)
numspin = rbm.numspin
train_state = Training.TrainState(rbm, ps, st, opt)
compute_energy_H(model, ps, st, (x,)) = compute_energy(model, ps, st, (x,), H)
Ss = [get_S(istate, numspin) for istate in 1:2^numspin]
for it = 1:nt
(_, E, _, train_state) = Training.single_train_step!(
vjp_rule, compute_energy_H, (Ss,), train_state)
println("$it energy: $E")
end
return ps, st
end
としました。Lux.jlの前の記事にありますように、rbmとpsとstと、最適化方法optと微分の方法vjp_ruleを指定することで、簡単に最小化できます。
出力すると、
3000 energy: -2.2360679774997894
Int8[-1, -1] 0.16245984811645353
Int8[-1, 1] -0.6881909602355878
Int8[1, -1] -0.6881909602355857
Int8[1, 1] 0.16245984811645311
energy [-2.236067977499788, -2.0000000000000004, 2.0000000000000004, 2.23606797749979]
のようになります。ちゃんと、最小エネルギーが出ていることがわかります。
なお、パラメータを実数に限ると、最小値にはならなかったりします。これは波動関数が正に限られてしまうからです。