
MNIST Training with Julia 1.0 + Flux (CNN version)


Today, as an extension of

Julia 1.0 + FluxでMNIST学習

we will try turning the model into a CNN.

Implementation

Straight into the implementation. If your PC does not have a GPU, delete the using CuArrays line and the code will still run, because Flux's gpu function simply acts as the identity when CuArrays is not loaded.
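If you would rather not edit the source by hand, here is a minimal sketch (my addition; use_gpu is a made-up flag, not part of the article) that gates the import behind a top-level conditional:

const use_gpu = false  # hypothetical flag: set to true only if CuArrays is installed

if use_gpu
    using CuArrays  # a conditional `using` like this is allowed at the top level of a script
end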

#=
MNIST sample
Taken from the Flux model zoo and modified.
https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl
=#

using Printf

using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, @epochs
using Base.Iterators: partition
using BSON: @load, @save
using CuArrays
using Random

function prepare_dataset(;train=true)
    train_or_test = ifelse(train,:train,:test)
    imgs = MNIST.images(train_or_test)
    labels = MNIST.labels(train_or_test)
    Y = onehotbatch(labels,0:9)
    return imgs, Y
end


function define_model(;hidden::Int)
    model = Chain(Conv((3,3), 1=>hidden, relu, pad=1, stride=1),
                  MaxPool((2,2)),
                  x -> reshape(x, :, size(x, 4)),  # flatten WHCN to (features, batchsize)
                  Dense(14*14*hidden, 10),         # 14x14 spatial size after 2x2 pooling
                  softmax)
    return model
end

function split_dataset_random(X, Y)
    divide_ratio = 0.9
    shuffled_indices = shuffle(1:size(Y)[2])
    divide_idx = round(Int, divide_ratio*length(shuffled_indices))
    train_indices = shuffled_indices[1:divide_idx]
    val_indices = shuffled_indices[divide_idx+1:end]  # +1 so train and val do not share an index
    train_X = X[train_indices]
    train_Y = Y[:,train_indices]
    val_X = X[val_indices]
    val_Y = Y[:,val_indices]
    return train_X, train_Y, val_X, val_Y
end

function train()
    println("Start to train")
    epochs = 3
    X, Y = prepare_dataset(train=true)
    train_img, train_Y, val_img, val_Y = split_dataset_random(X, Y)
    model = define_model(hidden=8) |> gpu
    loss(x,y)= crossentropy(model(x),y)
    accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))
    batchsize = 64
    # permutedims.() transposes each image to WH order; cat(...; dims=4) then
    # stacks each minibatch along the 4th axis, giving WHCN
    train_X = [cat(permutedims.(float.(train_img[batch]))...,dims=4) for batch in partition(1:length(train_img),batchsize)]
    val_X   = [cat(permutedims.(float.(val_img[batch]))...,dims=4) for batch in partition(1:length(val_img),batchsize)]
    train_dataset = gpu.([(train_X[i] ,train_Y[:,batch]) for (i, batch) in enumerate(partition(1:size(train_Y)[2],batchsize))])
    val_dataset = gpu.([(val_X[i] ,val_Y[:,batch]) for (i,batch) in enumerate(partition(1:size(val_Y)[2],batchsize))])
    callback_count = 0
    eval_callback = function callback()
        callback_count += 1
        if callback_count == length(train_dataset)
            println("action for each epoch")
            total_loss = 0
            total_acc = 0
            for (vx, vy) in val_dataset
                total_loss += loss(vx, vy)
                total_acc += accuracy(vx, vy)
            end
            total_loss /= length(val_dataset)
            total_acc /= length(val_dataset)
            @show total_loss, total_acc
            pretrained = model |> cpu
            @save "pretrained.bson" pretrained
            callback_count = 0
        end
        if callback_count % 50 == 0
            progress = callback_count / length(train_dataset)
            @printf("%.3f\n", progress)
        end
    end
    optimizer = ADAM(params(model))

    @epochs epochs Flux.train!(loss, train_dataset, optimizer, cb = eval_callback)

    pretrained = model |> cpu
    weights = Tracker.data.(params(pretrained))
    @save "pretrained.bson" pretrained
    @save "weights.bson" weights
    println("Finished training")
end

function predict()
    println("Start to evaluate testset")
    println("loading pretrained model")
    @load "pretrained.bson" pretrained
    model = pretrained |> gpu
    accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))
    println("prepare dataset")
    test_img, Y = prepare_dataset(train=false)
    batchsize=64
    test_X = [cat(permutedims.(float.(test_img[batch]))...,dims=4) for batch in partition(1:length(test_img),batchsize)]
    test_label = gpu.([Y[:,batch] for (i,batch) in enumerate(partition(1:size(Y)[2],batchsize))])
    total_acc=0
    cnt = 0
    for (x,y) in zip(test_X,test_label)
        total_acc += accuracy(x|> gpu, y|>gpu)
        cnt += 1
    end
    println("acc = $(total_acc/cnt)")
    println("Done")
end

function predict2()
    println("Start to evaluate testset")
    println("loading pretrained model")
    @load "weights.bson" weights
    model = define_model(hidden=8)
    Flux.loadparams!(model, weights)
    model = model |> gpu
    accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))
    test_img, Y = prepare_dataset(train=false)
    batchsize=64
    test_X = [cat(permutedims.(float.(test_img[batch]))...,dims=4) for batch in partition(1:length(test_img),batchsize)]
    test_label = gpu.([Y[:,batch] for (i,batch) in enumerate(partition(1:size(Y)[2],batchsize))])
    total_acc=0
    cnt = 0
    for (x,y) in zip(test_X,test_label)
        total_acc += accuracy(x|> gpu, y|>gpu)
        cnt += 1
    end
    println("acc = $(total_acc/cnt)")
    println("Done")
end

function main()
    train()
    predict()
    predict2()
end

main()

What differs from the Dense-layer version

If this is your first time on this page,
Julia 1.0 + FluxでMNIST学習
has a much fuller explanation.
Here I only cover the points to watch out for when feeding data into CNN layers.

About the model

Conv((3,3), 1=>hidden, relu, pad=1, stride=1) convolves the single channel of a grayscale image with 3x3 kernels and outputs a tensor with hidden channels. The activation function is relu.
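As a quick sanity check (my addition, assuming the same Flux setup as above), you can push a dummy WHCN array through the first two layers and confirm the shape that Dense(14*14*hidden, 10) expects:

julia> conv = Conv((3,3), 1=>8, relu, pad=1, stride=1);

julia> size(conv(rand(28, 28, 1, 1)))  # pad=1 keeps the 28x28 spatial size
(28, 28, 8, 1)

julia> size(MaxPool((2,2))(conv(rand(28, 28, 1, 1))))
(14, 14, 8, 1)

So after flattening there are 14*14*8 features per image, which is where Dense(14*14*hidden, 10) with hidden=8 gets its input size.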

x -> reshape(x, :, size(x, 4)) flattens the intermediate dimensions, like tf.flatten in TensorFlow, so that the output can be passed to the following Dense layer. Chainer handles this sort of thing automatically inside the framework.
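For instance (my addition), flattening a batch of 32 pooled feature maps keeps the batch axis and merges everything else:

julia> x = rand(14, 14, 8, 32);

julia> size(reshape(x, :, size(x, 4)))
(1568, 32)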

About feeding in the data

train_X = [cat(permutedims.(float.(train_img[batch]))...,dims=4) for batch in partition(1:length(train_img),batchsize)]

Here the data is split into minibatches of batchsize images each.
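The chunking itself is done by partition; a tiny sketch of its behavior (my addition):

# partition splits an index range into batchsize-long chunks; the last chunk
# is shorter when the length is not divisible by the chunk size.
for batch in partition(1:10, 4)
    println(batch)  # the chunks are 1:4, 5:8 and 9:10
end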

cat(permutedims.(float.(test_img[batch]))...,dims=4)
has to be written this way because Julia's image data and the input format expected by Conv each have their quirks.

The contents of train_img[1] are an array whose elements have type Gray{N0f8}.

julia> train_img[1][10:15,10:15]
6×6 Array{Gray{N0f8},2} with eltype ColorTypes.Gray{FixedPointNumbers.Normed{UInt8,8}}:
 Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)  
 Gray{N0f8}(0.275)  Gray{N0f8}(0.341)  Gray{N0f8}(0.631)  Gray{N0f8}(0.898)  Gray{N0f8}(0.635)  Gray{N0f8}(0.765)
 Gray{N0f8}(0.945)  Gray{N0f8}(0.808)  Gray{N0f8}(0.643)  Gray{N0f8}(0.451)  Gray{N0f8}(0.451)  Gray{N0f8}(0.451)
 Gray{N0f8}(0.133)  Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)  
 Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)  
 Gray{N0f8}(0.102)  Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)    Gray{N0f8}(0.0)  

First we cast this to floating point with float:

julia> float(train_img[1][10:15,10:15])
6×6 Array{Float64,2}:
 0.0       0.0       0.0       0.0       0.0       0.0     
 0.27451   0.341176  0.631373  0.898039  0.635294  0.764706
 0.945098  0.807843  0.643137  0.45098   0.45098   0.45098 
 0.133333  0.0       0.0       0.0       0.0       0.0     
 0.0       0.0       0.0       0.0       0.0       0.0     
 0.101961  0.0       0.0       0.0       0.0       0.0   
julia> permutedims(float(train_img[1][10:15,10:15]))
6×6 Array{Float64,2}:
 0.0  0.27451   0.945098  0.133333  0.0  0.101961
 0.0  0.341176  0.807843  0.0       0.0  0.0     
 0.0  0.631373  0.643137  0.0       0.0  0.0     
 0.0  0.898039  0.45098   0.0       0.0  0.0     
 0.0  0.635294  0.45098   0.0       0.0  0.0     
 0.0  0.764706  0.45098   0.0       0.0  0.0    

and then transpose it. The point of the transpose is that Conv expects its input in WHCN (width, height, channel, batch) order.

cat then concatenates the images along the fourth axis.

julia> size(cat(permutedims.(float.(train_img[[i for i in 1:32]]))...,dims=4))
(28, 28, 1, 32) # WHCN

While writing this, I noticed that the third dimension apparently gets inserted automatically. For example, with dims=6:

julia> size(cat(permutedims.(float.(train_img[[i for i in 1:32]]))...,dims=6))
(28, 28, 1, 1, 1, 32) #WHAT??!!

That is the behavior you get: every skipped dimension is padded with a singleton axis.
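The same padding happens with plain arrays, not just images; a minimal check (my addition):

julia> size(cat(rand(2,2), rand(2,2), dims=4))
(2, 2, 1, 2)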

Once the data can be fed to the CNN in this form, the rest is just running training as usual.
