3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ゼロから作るDeep Learning 読書メモ#1

Last updated at Posted at 2018-07-29

#はじめに

  • これは個人的なメモである。
  • ゼロから作るDeep Learning を読んだが全く理解できない。
  • とりあえず、サンプルプログラムを実行してみるが、それでも理解できない。
  • サンプルプログラムを書き換えまくって、どうにか理解しようとしてみた。

#参考

ゼロから作るDeep Learning
本書のサンプルコードGit

※今回は 3.6.1 MINISTデータセット のメモ

#環境

neuralnet_mnist.py を読む

元となるサンプルソース

↑画像に書かれた0~9の数値を認識する処理

#書き換えまくったソースをメモとして残す

neuralnet_mnist_gebo.py
#オリジナルソースは https://github.com/oreilly-japan/deep-learning-from-scratch
#カレントディレクトリ deep-learning-from-scratch で実行する必要があります。

# coding: utf-8
import sys, os
sys.path.append(os.curdir)
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax
from PIL import Image
from tkinter import Tk, messagebox

def get_data():
    #(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=False, flatten=True)
    return x_test, t_test

def init_network():
    with open("ch03\sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network

def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    # debug
    #print(x.ndim)
    #print(x.shape)
    #print(W1.ndim)
    #print(W1.shape)
    #print(b1.ndim)
    #print(b1.shape)

    # input=784 output=50
    # ここがおまじない
    # 1次元配列x と 2次元配列W1 の 積 に b1 を 足した値を a1 とする
    # x=画像データ , W1 ,b1 は 外部ファイルから読み込んだパラメータ
    a1 = np.dot(x, W1) + b1
    # sigmoid 大小関係はそのまま、a1の値を0から1の範囲に押し込める
    z1 = sigmoid(a1)

    # input=50 output=100
    a2 = np.dot(z1, W2) + b2
    # sigmoid 大小関係はそのまま、a1の値を0から1の範囲に押し込める
    z2 = sigmoid(a2)

    # input=100 output=10
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    #a3とyの大乗関係は変わらないはず、そういう意味ではsoftmax()は不要では
    #print(a3)
    #print(y)

    # yではなく、a3でreturnしても結果は同じになる
    #return y
    return(a3)

def show_msgboxyesno(msg):
    root = Tk()
    root.withdraw()
    truefalse = messagebox.askyesno("確認",msg)
    root.quit()
    return(truefalse)

def show_msgbox(msg):
    root = Tk()
    root.withdraw()
    messagebox.showinfo("message",msg)
    root.quit()

#画像データとテキストを表示
def show_img_and_label(imgdata,textdata):
    #imgdataは28*28の1次元配列
    #debug
    print(imgdata.ndim)     # 1
    print(imgdata.shape)    # (784,)

    img = imgdata.reshape(28, 28)  # 形状を元の画像サイズに変形
    #debug
    print(img.ndim)         # 2
    print(img.shape)        # (28, 28)

    # PIL用のイメージobjectに変換する
    pil_img = Image.fromarray(np.uint8(img))
    # Windows表示
    pil_img.show()

    show_msgbox("表示されている画像は" + str(textdata) + "です")   

# 【ここから】
# テストデータの取得 X=画像データ、t=画像に書かれている文字(テキスト)
x ,t = get_data()

#print(x.ndim)
#print(x.shape)

#print(t.ndim)
#print(t.shape)

# パラメータを取得
network = init_network()
# print(type(network)
# print(network.ndim)
# print(network.shape)

predict_cnt = 0
accuracy_cnt = 0
for i in range(len(x)):

    # 認識させる画像を表示したい場合はここを実行
    # ※非同期処理なので注意
    # show_img_and_label(x[i],t[i])

    # 画像認識処理
    # 画像(x[i])でかかれている数字が何なのかを認識する
    # yは10個の配列で0~9である確率を示す
    y = predict(network, x[i])
    #print(y)

    predict_cnt += 1

    p= np.argmax(y) # 最も確率の高い要素のインデックスを取得
    if p == t[i]:
        # 正解だったらカウンタをプラスする
        accuracy_cnt += 1
    else:
        # 間違った時の情報を見たいとき
        show_img_and_label(x[i],t[i])
        show_msgbox("認識結果は"+ str(p) + "でした\n\n" + str(y))

        # まだ続けるかどうかここでチェックする
        if show_msgboxyesno("まだ続けますか?") == False :
            # もう続けない
            break


# 正解率をprint
print("Accuracy:" + str(float(accuracy_cnt) / predict_cnt))
msg = "回答数=" + str(predict_cnt) + "\n"
msg = msg + "〇 = " + str(accuracy_cnt) + "\n"
msg = msg + "× = " + str(predict_cnt-accuracy_cnt) + "\n"
msg = msg + "正解率は"+ str(float(accuracy_cnt) / predict_cnt) + "でした\n"
show_msgbox(msg)

#メモ

predict(network, x)

  • 引数
    • network : dict型 : ニューラルネットワークによる推論処理のパラメータ
    • x : numpy.ndarray型 : 認識させたい画像データ(28×28ピクセル)
  • 戻り値
    • 画像の認識結果、numpy.ndarray型
    • →10個の配列、画像が0~9のどれに該当するか、の確率
3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?