LoginSignup
19
10

More than 3 years have passed since last update.

Nx,Axonで始めるゼロから作るディープラーニング 7章畳み込みニューラルネットワーク

Last updated at Posted at 2021-04-11

はじめに

本記事はElixirで機械学習/ディープラーニングができるようになるnumpy likeなライブラリ Nxを使って
ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装
をElixirで書いていこうという記事になります。

今回は心が折れたのと、Nxで実装されたDeepLearningフレームワークのAxonが発表されたのでそれを使ってシンプルなCNNを実装していきます

書籍もtrainerとかブラックボックスなim2colとか出してきてるしもういいよね・・・?

準備編
exla setup
1章 pythonの基本 -> とばします
2章 パーセプトロン -> とばします
3章 ニューラルネットワーク
with exla
4章 ニューラルネットワークの学習
5章 誤差逆伝播法
Nx.Defn.Kernel.grad
6章 学習に関するテクニック -> とばします
7章 畳み込みニューラルネットワーク

Axon とは

Nxで書かれたNNライブラリでKerasのようなインターフェースを持っています

model構築

denseのweightの初期値はglorot_uniformで生成されます
https://github.com/elixir-nx/axon/blob/e450f32416179baf818c2948e082932b475d8ed9/lib/axon.ex#L126
第2引数はニューロンの数(units)です

model =
  Axon.input({nil, 784})
  |> Axon.dense(128)
  |> Axon.dense(10, activation: :softmax)

学習

trainerがあるので2行で簡単!
callbackやmetricsオプションはまだみたいですがroadmapに含まれているので気長に待ちましょう
それかTraining.trainのコードをコピーしてmetricsやcallbackを独自に組み込みましょう!

model =
  Axon.input({nil, 784})
  |> Axon.dense(128)
  |> Axon.dense(10, activation: :softmax)

trained_params =
  model
  |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005))
  |> Axon.Training.train(x_test, t_test, epochs: 10, compiler: EXLA)

推論

学習したパラメーターとモデルを渡して推論します

Axon.predict(model, trained_params, input, compiler: EXLA)

import/export

roadmapに含まれてますが未実装ですので、学習済みパラメータをNx.to_flat_listでlistにしてcsvに書き込むなりetsにぶっこむなりするといいでしょう

Layer

Axon.{Dence, Conv}等重み初期化等がついているのもありますが、単純に値の動きだけを見たい場合は
Axon.Layers以下にあるのでそちらを使いましょう
https://github.com/elixir-nx/axon/blob/main/lib/axon/layers.ex

Axon のインストール

nx,exlaがbuildできているなら問題なくインストールできます

deps.exs
defmodule NxDl.MixProject do
  ...
  def deps do
    [
      {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"}, #追加
      {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla", override: true},
      {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
      {:expyplot, "~> 1.1.2"},
      {:erlport, "~> 0.9.8" },
      {:benchee, "~> 1.0", only: :dev}
    ]
  end
end

まだ0.1.0-devで更新が頻繁に起こっているので、github経由のはdeps.update --allを実行しましょう

mix deps.get
mix deps.update --all
mix deps.compile

7.5 CNNの実装

学習データ

cnnのinputはデータ数,チャンネル数、幅、高さの4次元データである必要があるのでreshpeでチャンネル数の項目を追加します

simple_conv_net.ex
  def train_inputs do
    x_train =
      Dataset.train_image
      |> Nx.tensor
      |> Nx.reshape({60000, 1, 28, 28})
      |> (& Nx.divide(&1, Nx.reduce_max(&1))).()
      |> Nx.to_batched_list(100)

    t_train =
      Dataset.train_label
      |> Dataset.to_one_hot
      |> Nx.tensor
      |> Nx.to_batched_list(100)
    {x_train, t_train}
  end

  def test_inputs do
    x_test =
      Dataset.test_image
      |> Nx.tensor
      |> Nx.reshape({10000,1,28,28})
      |> (& Nx.divide(&1, Nx.reduce_max(&1))).()

    t_test = Dataset.test_label |> Dataset.to_one_hot |> Nx.tensor
    {x_test, t_test}
  end

model構築

畳み込み層1、Pooling層1のCNNなので以下のようになります
Axon.conv はunits数の箇所がfilterの数に当たります
他のパラメーターはデフォルトでstride 1, padding valid(0)となっています
書籍とは違ってPooling層のあとにFlatten層でcnnの形式のshapeをDenseで扱えるようにしています

simple_conv_net.ex
  def model do
    Axon.input({nil,1,28,28})
    |> Axon.conv(30, kernel_size: {5, 5}, activation: :relu)
    |> Axon.max_pool(kernel_size: {2, 2})
    |> Axon.flatten()
    |> Axon.dense(100, activation: :relu)
    |> Axon.dense(10, activation: :softmax)
  end

train & test

trainはstepでlossとoptimizerをセットして
trainでepoch数とcompilerをセットしています

testはまだtrainにmetricsの実行が組み込まれていないので、trained_paramsを引数にpredictを行ってその結果を正解ラベルと比較しています
またpredictを実行する場合はrequire Axonを書く必要があります
https://github.com/elixir-nx/axon/blob/main/examples/xor.exs

simple_conv_net.ex
  require Axon

  def train do
    {x_train, t_train} = train_inputs()
    {trained_params, _optmizer} =
      model()
      |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005))
      |> Axon.Training.train(x_train, t_train, epochs: 10, compiler: EXLA)

    trained_params
  end

  def test(params) do
    {x_test, t_test} = test_inputs()
    Axon.predict(model(), params, x_test, compiler: EXLA)
    |> Axon.Metrics.accuracy(t_test)
  end
iex(1)> params = SimpleConvNet.train                                                        
Epoch 1, batch 600 of 600 - Average Loss: 4.9356760978698735

Epoch 1 Time: 9.746866s
Epoch 1 Loss: 4.935676097869873


Epoch 2, batch 600 of 600 - Average Loss: 4.6949691772460945

Epoch 2 Time: 9.08468s
Epoch 2 Loss: 4.694969177246094


Epoch 3, batch 600 of 600 - Average Loss: 4.6642417907714845

Epoch 3 Time: 9.071895s
Epoch 3 Loss: 4.664241790771484


Epoch 4, batch 600 of 600 - Average Loss: 4.6500945091247565

Epoch 4 Time: 9.131318s
Epoch 4 Loss: 4.650094509124756


Epoch 5, batch 600 of 600 - Average Loss: 4.6417756080627445

Epoch 5 Time: 9.053126s
Epoch 5 Loss: 4.641775608062744


Epoch 6, batch 600 of 600 - Average Loss: 4.6354269981384285

Epoch 6 Time: 9.057393s
Epoch 6 Loss: 4.635426998138428


Epoch 7, batch 600 of 600 - Average Loss: 4.6314291954040535

Epoch 7 Time: 9.047488s
Epoch 7 Loss: 4.631429195404053


Epoch 8, batch 600 of 600 - Average Loss: 4.6279087066650395

Epoch 8 Time: 9.081987s
Epoch 8 Loss: 4.627908706665039


Epoch 9, batch 600 of 600 - Average Loss: 4.6265153884887695

Epoch 9 Time: 9.098026s
Epoch 9 Loss: 4.6265153884887695


Epoch 10, batch 600 of 600 - Average Loss: 4.6234269142150885

Epoch 10 Time: 9.087951s
Epoch 10 Loss: 4.623426914215088

iex(2)> SimpleConvNet.test(params)  
#Nx.Tensor<
  f32
  0.9879000186920166
>

7.6 CNNの可視化

Nx.to_heatmapでみれますが、matrexと違って出力を貼っても表示れなくてお見せできないのが残念です

trained_params |> Tuple.to_list |> List.first |> Nx.to_heatmap

EXLAをGPU modeで動かす

config.exsでexlaでcudaを使うように指定すると hostからcudaを使うようになります

config.exs
use Mix.Config

config :exla, :clients, default: [platform: :cuda], cuda: [platform: :cuda]

最後に

実装は以上になります
Axonはとてもスッキリしてて書いていて気持ちいいです
examplesもresnet50やfashionmnist autoencoder,mnist GANなど結構充実していて参考になります
RNN系やAttentionのlayerが実装されてTransformerとかが使えるようになるのが楽しみです!(使いこなせるかどうかは置いておいて)

最後まで読んでいただきありがとうございました

今回のコード
https://github.com/thehaigo/nx_dl/commit/65b920796c55fbaf7db397515ce3be3ff3b8b750

参考ページ

https://twitter.com/sean_moriarity/status/1380124787318665218
https://seanmoriarity.com/2021/04/08/axon-deep-learning-in-elixir/
https://github.com/elixir-nx/axon
https://github.com/elixir-nx/axon/blob/e450f32416179baf818c2948e082932b475d8ed9/lib/axon.ex#L126
https://github.com/elixir-nx/axon/blob/main/lib/axon/layers.ex
https://github.com/elixir-nx/axon/blob/main/examples/xor.exs
https://github.com/elixir-nx/axon/blob/main/examples/cifar10.exs

19
10
4

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
10