
More than 3 years have passed since last update.

Learning with Ruby, of all languages: MNIST inference with the Chapter 3 neural network of 「ゼロから作るDeep Learning」 (Deep Learning from Scratch)


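The book's trained weights (sample_weight.pkl) are loaded through PyCall as NumPy arrays and forcibly converted into Numo::NArray. All of the matrix computation is then done in Ruby.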

require 'numo/narray'
require 'datasets'        # Red Datasets, provides Datasets::MNIST
require 'pycall/import'
require './neuralnet'     # expected to define sigmoid and softmax (file not shown)
include PyCall::Import

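# Load the trained weights shipped with the book (sample_weight.pkl).
# pickle.load returns a Python dict of NumPy arrays ('W1', 'b1', ...).
def init_network
  pyimport :numpy
  pyimport :pickle
  # The block form closes the file once the weights are loaded
  File.open("sample_weight.pkl", "rb") { |pkl| pickle.load(pkl) }
end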

# Ad-hoc helper of my own:
# a brute-force conversion NumPy array -> Python list -> Ruby Array -> Numo::NArray
def numpy_to_narray(w)
  flat = Numo::NArray.concatenate(Array(w.tolist).flatten)
  return flat if w.shape.length == 1    # bias vectors are 1-D
  flat.reshape(w.shape[0], w.shape[1])  # weight matrices are 2-D (rows x cols)
end

def predict(network, x)
  # Convert every weight matrix and bias vector from NumPy to Numo::NArray
  w1 = numpy_to_narray(network['W1'])
  w2 = numpy_to_narray(network['W2'])
  w3 = numpy_to_narray(network['W3'])
  b1 = numpy_to_narray(network['b1'])
  b2 = numpy_to_narray(network['b2'])
  b3 = numpy_to_narray(network['b3'])

  # Three-layer forward pass: two sigmoid hidden layers, softmax output
  a1 = x.dot(w1) + b1
  z1 = sigmoid(a1)
  a2 = z1.dot(w2) + b2
  z2 = sigmoid(a2)
  a3 = z2.dot(w3) + b3
  softmax(a3)
end

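# Load the MNIST data (only the test set is needed for inference)
# train = Datasets::MNIST.new(type: :train)
test = Datasets::MNIST.new(type: :test)
# 10000 x 784 input matrix. Pixels are scaled to [0, 1] because the book's
# sample weights were trained on normalized input (load_mnist(normalize=True)).
x = Numo::DFloat.cast(test.map(&:pixels)) / 255.0
t = Numo::Int32.cast(test.map(&:label))
network = init_network
# Run inference on all 10000 images in one go for now (no batching)
y = predict(network, x)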
accuracy_cnt = 0
10000.times do |i|
  # Predicted class = index of the largest softmax output in row i
  row = y[i, true].to_a
  pred = row.index(row.max)
  accuracy_cnt += 1 if pred == t[i]
end

puts "Accuracy: #{accuracy_cnt / 10000.0}"