
Machine Learning in Julia (Flux.jl)

Posted at 2021-10-05

Overview

Julia is a new programming language that borrows the best parts of many other languages; the pitch is often summarized as "runs as fast as C, reads as simply as Python." It has recently been attracting attention in machine learning, a field that demands heavy computation.

I will leave a detailed introduction to Julia to other articles.

In this post, we build a neural network with Flux, Julia's machine learning library, and use it to classify handwritten digits. A detailed explanation of machine learning itself is omitted here.

Importing the required libraries and the MNIST dataset

Import the required libraries and fetch the MNIST dataset.

using Flux
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated

Fetching the dataset

# Fetch the MNIST dataset
train_images = Flux.Data.MNIST.images(:train)
train_labels = Flux.Data.MNIST.labels(:train)
test_images = Flux.Data.MNIST.images(:test)
test_labels = Flux.Data.MNIST.labels(:test)
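
As a quick sanity check, the raw data can be inspected like this (a minimal sketch; the exact element types depend on the Flux version):

length(train_images)        # 60000 training images
size(first(train_images))   # (28, 28): each image is a 28×28 grayscale matrix
length(train_labels)        # 60000 labels, integers in 0:9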

Preprocessing the data

Each 28×28 image is flattened into a length-784 vector with reshape (the counterpart of NumPy's reshape), and the resulting vectors are stacked into a single matrix with hcat. The labels are one-hot encoded with the onehotbatch function.

X_train = hcat(float.(reshape.(train_images, :))...)  # flatten each image, then stack columns
Y_train = onehotbatch(train_labels, 0:9)              # one-hot encode the labels 0 through 9
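
The result is one column per sample, which is the layout Flux's Dense layers expect. A quick shape check (a sketch, assuming the full 60,000-image training set):

size(X_train)   # (784, 60000): one flattened image per column
size(Y_train)   # (10, 60000): one one-hot column per label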

Defining the model

Define the neural network model. Here we stack two Dense layers (plus a final softmax), mapping the 784 (28×28) inputs down to 10 outputs.

model = Chain(Dense(28^2, 32, relu), Dense(32, 10), softmax)
# Output
Chain(
  Dense(784, 32, relu),                 # 25_120 parameters
  Dense(32, 10),                        # 330 parameters
  NNlib.softmax,
)                   # Total: 4 arrays, 25_450 parameters, 99.664 KiB.
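
The model can already be called like a function before training. Applying it to a single column gives a 10-element vector of class probabilities (a quick sketch):

p = model(X_train[:, 1])   # probabilities for the 10 digit classes
sum(p)                     # ≈ 1.0, since the final layer is softmax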

Setting up the loss function and optimizer

We use cross-entropy as the loss function and ADAM as the optimizer.

loss(x, y) = crossentropy(model(x), y)        # cross-entropy against the one-hot targets
optim = ADAM()
evalcb = () -> @show(loss(X_train, Y_train))  # callback that prints the current training loss
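
Since the untrained model's outputs are close to uniform over the 10 classes, the initial cross-entropy should be near log(10) ≈ 2.30, which matches the first line of the training output below. A quick sanity check (sketch):

loss(X_train, Y_train)   # ≈ 2.30 for the untrained model
log(10)                  # 2.302585...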

Specifying the dataset

Specify the dataset to use for training.

dataset = repeated((X_train, Y_train), 200)

For repeated((X, Y), n), see the code below: data1 and data2 mean exactly the same thing. In other words, the line above yields the same full-batch (X, Y) pair 200 times, so the model trains on the entire training set for 200 iterations.

data1 = [(x, y), (x, y), (x, y)]

data2 = Iterators.repeated((x, y), 3)
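
Note that repeated is lazy: it yields the same (X_train, Y_train) tuple 200 times without copying the underlying arrays. A quick check (sketch):

length(dataset)                 # 200 iterations
first(dataset)[1] === X_train   # true: no copy is made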

Training

Training is run with Flux.train!(loss function, model parameters, dataset, optimizer; cb = callback function).

Flux.train!(loss, params(model), dataset, optim, cb = throttle(evalcb, 10));
# Output
loss(X_train, Y_train) = 2.306777815490352
loss(X_train, Y_train) = 2.0586748841304185
loss(X_train, Y_train) = 1.8409637861168056
loss(X_train, Y_train) = 1.6375431419826942
loss(X_train, Y_train) = 1.445993643519817
loss(X_train, Y_train) = 1.2296568110673718
loss(X_train, Y_train) = 1.0502570108967053
loss(X_train, Y_train) = 0.9107898094222893
loss(X_train, Y_train) = 0.8218164747663838
loss(X_train, Y_train) = 0.7500740838675154
loss(X_train, Y_train) = 0.6914486287921706
loss(X_train, Y_train) = 0.6434973108945049
loss(X_train, Y_train) = 0.5950358836899775
loss(X_train, Y_train) = 0.5631648753189052
loss(X_train, Y_train) = 0.5297994058656267
loss(X_train, Y_train) = 0.5071319050375515
loss(X_train, Y_train) = 0.48271218625780954
loss(X_train, Y_train) = 0.46571299078246714
loss(X_train, Y_train) = 0.45056021837013205
loss(X_train, Y_train) = 0.4369715972522749
loss(X_train, Y_train) = 0.42182257993938366
loss(X_train, Y_train) = 0.41356730105547634
loss(X_train, Y_train) = 0.4033984460086085
loss(X_train, Y_train) = 0.3940627362778361
loss(X_train, Y_train) = 0.38340295361498883
loss(X_train, Y_train) = 0.37558235668831086
loss(X_train, Y_train) = 0.36656399693056557
loss(X_train, Y_train) = 0.35988036419559005
loss(X_train, Y_train) = 0.35362035413366855
loss(X_train, Y_train) = 0.34774049784963607
loss(X_train, Y_train) = 0.34220475892263397
loss(X_train, Y_train) = 0.3369876883884653
loss(X_train, Y_train) = 0.3320573378626823
loss(X_train, Y_train) = 0.3285298332916139
loss(X_train, Y_train) = 0.32403339593560665
loss(X_train, Y_train) = 0.31975565304929565
loss(X_train, Y_train) = 0.31568244659729194
loss(X_train, Y_train) = 0.3117944546753522
loss(X_train, Y_train) = 0.3080787518320183
loss(X_train, Y_train) = 0.30452654964434067
loss(X_train, Y_train) = 0.30111957889417673
loss(X_train, Y_train) = 0.29784627528038254
loss(X_train, Y_train) = 0.29469883863927654
loss(X_train, Y_train) = 0.2916642419823619
loss(X_train, Y_train) = 0.2880199533894845
loss(X_train, Y_train) = 0.28521705385216156
loss(X_train, Y_train) = 0.2825064992625212
loss(X_train, Y_train) = 0.2792328401482443
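
For reference, with this implicit-parameter API (Flux 0.12 era), Flux.train! amounts to roughly the following loop (a simplified sketch that omits the callback handling):

ps = params(model)
for (x, y) in dataset
    gs = gradient(() -> loss(x, y), ps)   # backpropagate the loss through the model
    Flux.Optimise.update!(optim, ps, gs)  # apply one ADAM step to all parameters
end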

Evaluating performance

The accuracy comes out to roughly 92%. This post used a simple fully connected network, but Flux also supports convolutional neural networks (CNNs), so further accuracy gains look feasible.

using Statistics: mean  # mean is not exported by Flux

accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))  # fraction of correct predictions
test_X = hcat(float.(reshape.(test_images, :))...)
test_Y = onehotbatch(test_labels, 0:9)

@show accuracy(test_X, test_Y)
# Output
accuracy(test_X, test_Y) = 0.9249
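
To inspect an individual prediction, onecold can also map a probability vector back to a digit label directly (a quick sketch):

p = model(test_X[:, 1])   # class probabilities for the first test image
onecold(p, 0:9)           # predicted digit
test_labels[1]            # ground-truth digit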