LoginSignup
1
0

JuliaでDeep Learningを理解する: 4.3 Deep neural networks

Posted at

Understanding Deep LearningノートブックをJuliaで確認する。

4.3 Deep neural networks

JuliaでDeep Learningを理解する: 4.2 Clipping functionsこれまでの記事で見てきたように、Deep Learningは線形関数の合成なので、行列の計算をとにかく行うことになる。また、層ごとに独立して計算できる。

行列で多層ニュートラルネットワークを計算する

以下のような多層ニュートラルネットワークを考える。
DeepKLayer.png
(画像はudlbookより引用)
Juliaは標準で行列の計算をサポートしているので、ライブラリ等の追加は必要ない。

ReLU(z) = ifelse(z < 0, zero(z), z)
D_i = 3
D_1 = 4
D_2 = 2
D_3 = 3
D_o = 2
n_data = 10
x = rand(Float32, (D_i, n_data))
beta_0 = rand(Float32, D_1)
Omega_0 = rand(Float32, (D_1, D_i))
h1 = ReLU.(beta_0 .+ Omega_0 * x)
beta_1 = rand(Float32, D_2)
Omega_1 = rand(Float32, (D_2, D_1))
h2 = ReLU.(beta_1 .+ Omega_1 * h1)
beta_2 = rand(Float32, D_3)
Omega_2 = rand(Float32, (D_3, D_2))
h3 = ReLU.(beta_2 .+ Omega_2 * h2)
beta_3 = rand(Float32, D_o)
Omega_3 = rand(Float32, (D_o, D_3))
y = ReLU.(beta_3 .+ Omega_3 * h3)

n_data = 10はバッチサイズであり、3次元のインプットを10セット一気に計算している。

julia> y = ReLU.(beta_3 .+ Omega_3 * h3)
2×10 Matrix{Float32}:
 3.08755  3.31505  3.89058  3.42552  3.03049  3.14105  3.44985  3.19263  3.23905  3.31655
 4.46442  4.8368   5.7742   5.01455  4.37189  4.55211  5.0533   4.63697  4.71287  4.8374

正しく計算できている。

Flux.jlで多層ニュートラルネットワークを計算する

Flux.jlはJulia純正の深層学習用パッケージである。中身がCということはないので、全てのコードを追うことができる。プログラマーとしては1つの言語で完結しているのは非常にありがたい。JuliaはLLVMにコンパイルされるので、速度も問題ない。

using Flux
model = Chain(
    Dense(D_i => D_1, relu),
    Dense(D_1 => D_2, relu),
    Dense(D_2 => D_3, relu),
    Dense(D_3 => D_o)
)
model(x)

Denseは全結合層であり、活性化関数をreluに指定している(ちなみに自分で今回作ったReLU関数を指定しても問題なく動作する)。Chainで各層を1まとめにしてmodel(x)で計算する。

julia> model(x)
2×10 Matrix{Float32}:
 0.0772406  0.00252384  0.0843285  0.145631  0.0692204  0.064958  0.190816  0.0247252  0.0205282  0.119318
 0.900612   0.942941    1.41732    1.20932   0.75046    0.86907   1.32434   0.835291   0.834123   1.06254

乱数を使っているので値は異なるが正しい大きさのyが得られている。
パッケージ内で行なっている計算は自前で行った実装

h1 = ReLU.(beta_0 .+ Omega_0 * x)

とほぼ違いはなく

σ.(a.weight * xT .+ a.bias)

となっている。σReLUbeta_0a.biasa.weightOmega_0xxTである。もちろんライブラリなので汎用性のためのコードや型安定性のためのコードはあるが、実装のコアの部分は変わらない。
Juliaの言語仕様に素直に短くコード書けばそれが速度的にも良い書き方になっているので安心してコードを書けるのがJuliaの好きなところだ。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0