2
3

More than 1 year has passed since last update.

JuliaでPyCall.jlを用いてPyTorchを使う

Posted at

すでに以下のようなわかりやすい記事がありますが、流石に月日が流れ、PyCall.jlやPyTorchの方にも仕様変更があったようなので、少し自分でも見てみようと思いました。

(個人的にFlux.jl使ってたんですが、ちょっとZygote.jlではどうにもならない自動微分の案件がありまして......)
↑こういう経緯なので、基本自分用のメモみたいなもので、個人的に必要ないと感じたことは飛ばしますので、ご容赦を。

環境構築は本家https://github.com/JuliaPy/PyCall.jl を見たほうが良いです。僕はすでに入れてあったAnaconda3のPyTorchが入った仮想環境を使っています。

PyCall.jlのキホン

pythonのライブラリはpyimport()で出来ます。

input
using PyCall
torch = pyimport("torch")

ここで定義した変数torchはjuliaではPyObjectという構造体になっています。
pythonライブラリ内で定義されたクラスも

input
nn = torch.nn
F = torch.nn.functional

のように使う事ができます1。これらもPyObjectです。

PyCall.jlでは、一部の代表的な型、例えばjuliaのArrayとpythonのnumpy.ndarrayを自動的に変換してくれたりするのですが、それ以外は一括でPyObjectで扱っています。
例えば

input
A = torch.zeros((2, 3))

とすれば

output
PyObject tensor([[0., 0., 0.],
        [0., 0., 0.]])

となります。

input
A.cuda()

とすればもちろんGPUにも転送できます。

output
PyObject tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')

torch.tensorをjuliaで扱えるようにするにはnumpy.ndarrayに変換すれば、先程書いたようにjulia上では自動的にArrayとして変換されます。

input
A.numpy()
output
2×3 Matrix{Float32}:
 0.0  0.0  0.0
 0.0  0.0  0.0

ここまではかなり直感的だと思います。

pythonはオブジェクト指向なのでクラスも扱える必要があります。ここは少しだけ違っていて、

input
@pydef mutable struct Test
    function __init__(self) 
        self.x = 0
    end
    function add1(self)
        self.x += 1
    end
end

のようにmutable structに@pydefマクロをつける2ことでpythonのクラスに対応するものを定義できます。メソッドは関数として定義すればOKです。一応対応するpythonコードも示しておきます。

python
class Test:
    def __init__(self):
        self.x = 0
    
    def add1(self):
        self.x += 1

実際に確かめると、

input
test = Test()
println(test.x)
test.add1()
println(test.x)
output
0
1

きちんと動いています。

ちなみに継承は<:を使います。

input
@pydef mutable struct Test2 <: Test
    function __init__(self) 
        pybuiltin("super")(Test2, self).__init__()
        # pybuiltin(:super)(Test2, self).__init__() でもOK
    end
    function mult3(self)
        self.x *= 3
    end
end
python
class Test2(Test):
    def __init__(self):
        super().__init__() 
        # same as super(Test2, self).__init__()
    
    def mult3(self):
        self.x *= 3

ここでpythonのsuper().__init__() pybuiltin("super")(Test2, self).__init__()に変わっています。
pybuiltin()は引数の文字列(もしくはシンボル)に対応するpythonの組み込み関数を呼び出す関数です。
pythonのsuper()は引数を省略するとよしなにやってくれるのですが、pybuiltin()は引数を明示しないといけないようです。

input
test2 = Test2()
println(test2.x)
test2.add1()
println(test2.x)
test2.mult3()
println(test2.x)
output
0
1
3

ニューラルネットワーク

本題に入ることにします。難しい例をやっても仕方ないので、ここではレナードジョーンズ型のポテンシャル関数 $U(r) = \epsilon[(\sigma/r)^{12}-(\sigma/r)^6]$をニューラルネットワークで回帰することにします。
上記の内容とPyTorchの扱いがわかっていればほとんど直感的にできると思いますので、説明はあっさり風味にします。

input
using PyCall
using Plots
using Zygote
torch = pyimport("torch")
nn = torch.nn
F = torch.nn.functional
device = torch.device(ifelse(torch.cuda.is_available(), "cuda:0", "cpu"))
output
PyObject device(type='cuda', index=0)

続いてレナードジョーンズポテンシャルを定義します。
データはポテンシャル$U(r)$だけでなくその勾配の符号反転である力$F(r) = -\frac{d}{dr}U(r)$も含めることにします。
力はZygote.jlの微分で出しています。

input
mutable struct LJ
    σ
    ϵ
end

(m::LJ)(r) = m.σ[1] .* (m.ϵ[1] ./ r).^12 - m.σ[1] .* (m.ϵ[1] ./ r).^6
σ_true = [1.0]
ϵ_true = [2.0]
lj = LJ(σ_true,ϵ_true)
lj_force(r) = gradient(r -> -lj(r)[1], r)[1]

(パラメータを配列にしたり、call overloadingしているのはFlux.jl使ってたときの名残なので気にしない)

データ作ってtorch.tensorにしてgpuに流します。

input
r_train = reshape(range(2.0,5.0,31),:,1)
U_train = lj(r_train)
F_train = lj_force.(r_train)

r_train = torch.tensor(r_train,requires_grad=true,dtype=torch.float32).to(device)
F_train = torch.tensor(F_train,dtype=torch.float32).to(device)
U_train = torch.tensor(U_train,dtype=torch.float32).to(device)

ニューラルネットワークの定義をします。

input
@pydef mutable struct Net <: torch.nn.Module
    function __init__(self)
        pybuiltin("super")(Net, self).__init__()
        self.net1 = torch.nn.Linear(1, 5)
        self.net2 = torch.nn.Linear(5, 1)
    end

    function forward(self, x)
        x = self.net1(x)  
        x = torch.sigmoid(x)
        x = self.net2(x)
        return x
    end
end

学習します。損失関数として、ポテンシャルと力の二乗誤差の混合を使います。

input
net = Net()
net.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.99), eps=1e-07)
loss_list = []
epoch = 2000
for i = 1:epoch
    U_pred = net(r_train)
    F_pred = []
    for i = 1:U_pred.shape[1]
        push!(F_pred, - torch.autograd.grad(U_pred[i], r_train, create_graph=true)[1][i])
    end
    F_pred = torch.cat(F_pred).reshape(-1, 1)
    
    loss = criterion(F_pred, F_train) + 10 * criterion(U_pred, U_train)
    loss_item = loss.item()
    println("Epoch: $i, Loss: $loss_item")
    push!(loss_list,loss_item)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
end

学習したモデルを示します。

input
r_pred = reshape(range(2.0,5.0,301),:,1)
r_pred = torch.tensor(r_pred,requires_grad=true,dtype=torch.float32).to(device)
U_pred = net(r_pred)

F_pred = zeros(U_pred.shape[1])
for i = 1:U_pred.shape[1]
    F_pred[i,1] = - torch.autograd.grad(U_pred[i], r_pred, create_graph=true)[1][i]
end

U_pred = U_pred.cpu().detach().numpy()
r_pred = r_pred.cpu().detach().numpy()
U_train = U_train.cpu().detach().numpy()
r_train = r_train.cpu().detach().numpy()
F_train = F_train.cpu().detach().numpy()

表示してみます。

input
scatter(r_train[:], U_train[:],legend=:none)
plot!(r_pred[:], U_pred[:],legend=:none,title="potential")

LJ_U.png

input
scatter(r_train[:], F_train[:],legend=:none)
plot!(r_pred[:], F_pred[:],legend=:none,title="force")

LJ_F.png
よく一致していますね。汎化性能については検証していないですが、学習は出来ていそうです。

追記:DataLoaderの使い方

pythonで扱う場合はDataLoaderは楽に使えるのですが、PyCall.jlを使って呼び出す場合は少し癖があります。
python形式のDataLoaderはjuliaから見ると謎のPyObjectでしかないので、for文で回せないからです。
なので原始的にDataLoaderをiter()でiteratorにしてからnext()で進める、という操作をループさせます3。最後まで行った状態でさらにnext()するとエラーとなるので、try~~catch構文でループを止めます。

input
torch_data_set = data.TensorDataset(r_train, U_train, F_train)
loader = data.DataLoader(dataset=torch_data_set, batch_size=10,shuffle=true,drop_last=true)
loader_iter = pybuiltin("iter")(loader)

while true 
    try
        r_batch, U_batch, F_batch = pybuiltin("next")(loader_iter)
        ### ミニバッチへの処理 ###
    catch
        break
    end
end

という感じです。
わかってしまえば難しくはないですね。

一応DataLoader使ったバージョンも記載しておきます(このレベルのデータ数、次元、関数の複雑さでは、ミニバッチSGDは最適化の単なるノイズでしかないので性能は落ちます)。

input
# dataloader
r_train = reshape(range(2.0,5.0,151),:,1)
U_train = lj(r_train)
F_train = lj_force.(r_train)
r_train = torch.tensor(r_train,requires_grad=true,dtype=torch.float32)
F_train = torch.tensor(F_train,dtype=torch.float32)
U_train = torch.tensor(U_train,dtype=torch.float32)
# r_train = torch.tensor(r_train,requires_grad=true,dtype=torch.float32).to(device)
# F_train = torch.tensor(F_train,dtype=torch.float32).to(device)
# U_train = torch.tensor(U_train,dtype=torch.float32).to(device)
torch_data_set = data.TensorDataset(r_train, U_train, F_train)
loader = data.DataLoader(dataset=torch_data_set, batch_size=30, shuffle=true,drop_last=true)


@pydef mutable struct Net <: torch.nn.Module
    function __init__(self)
        pybuiltin("super")(Net, self).__init__()
        self.net1 = torch.nn.Linear(1, 5)
        self.net2 = torch.nn.Linear(5, 1)
    end

    function forward(self, x)
        x = self.net1(x)  
        x = torch.sigmoid(x)
        x = self.net2(x)
        return x
    end
end
net = Net()
net.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.99), eps=1e-07)
loss_list = []
epoch = 10000
for i = 1:epoch
    loader_iter = pybuiltin("iter")(loader)
    batch_i = 0
    loss_item = 0.0
    while true  
        try
            r_batch, U_batch, F_batch = pybuiltin("next")(loader_iter)
            r_batch = r_batch.requires_grad_(true).to(device)
            U_batch = U_batch.to(device)
            F_batch = F_batch.to(device)
            U_pred = net(r_batch)
            F_pred = []
            for i = 1:U_pred.shape[1]
                push!(F_pred, - torch.autograd.grad(U_pred[i], r_batch, create_graph=true)[1][i])
            end
            F_pred = torch.cat(F_pred).reshape(-1, 1)
            loss = criterion(F_pred, F_batch) + 10 * criterion(U_pred, U_batch)
            loss_item += loss.item()
            batch_i += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step() 
        catch
            break
        end
    end    
    loss_item /= batch_i
    println("Epoch: $i, Loss: $loss_item")
    push!(loss_list,loss_item)
end

  1. 由緒正しい書き方として上の記事ではnn = torch[:nn]が使われており、これでもOKです(というより前までこれでないといけなかったようです)。現在はどうやらpythonライクにかけるように改良されたようです。

  2. 上の記事では@pydef type ~~となっていますが、僕の環境ではこれでは動かないです(おそらくこれもPyCall.jlの仕様変更でしょう)。

  3. もっとスマートなやり方があるかもしれません。

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3