1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

SageMakerを利用した、kerasによるMNISTの3層ニューラルネットワーク解析

Last updated at Posted at 2019-05-10

はじめに

本稿では Udemy の 【4日で体験しよう!】 TensorFlow, Keras, Python 3 で学ぶディープラーニング体験講座 の内容を参考に記事を作成しました。

記事を2つに分けて投稿します。

その1 : SageMaker での環境構築と keras による MNIST の3層ニューラルネットワーク解析 ←(本稿)
その2 : keras による fashion MNIST の3層ニューラルネットワークとCNN(畳み込みニューラルネットワーク)解析

※ Udemy では Anaconda を使用してますが、自分は Amazon SageMaker を使用しました。
※環境構築についても Amazon SageMaker を使用した手順を記載します。

前提

  • AWS のアカウントを用意します。
    ※AWSアカウント作成の流れはこちら

環境構築と準備

Amazon SageMaker でノートブックインスタンスを作成し、Jupyter を開きます。

  • AWS にログインし、Amazon SageMaker を検索。
    01.png

  • 左の一覧からノートブックインスタンスを選択。

02.png
  • 右上のノートブックインスタンスの作成を選択。
03.png
  • ノートブックインスタンス名を入力(インスタンス名は任意)
04.png
  • 下にスライドして、上の赤枠の部分を選択して、新しいロールの作成を選択。
05.png
  • 任意の S3 バケットを選択し、ロールの作成を選択します。
    ※今回は任意の S3 バケットを選択していますが、実際に運用する際は適切な IAM ロールを設定して下さい。
06.png
  • 赤枠のように表示されたら IAM ロールの作成は成功です。
07.png
  • 下にスライドして、ノートブックインスタンスの作成を選択。
08.png
  • 画像上部の赤枠のように表示されたらノートブックインスタンスの作成は完了です。
    ステータスが「 Pending 」 → 「 InService 」 になるまで待ちます。
09.png
  • ステータスが InService になったら、開く Jupyter を選択して Jupyter を開きます。
10.png
  • Jupyter が開けたら、この後ダウンロードしたデータを保存するフォルダを作成します。
    Newを選択。
11.png
  • 一覧から conda_tensorflow_p36 を選択して、ノートブックを作成してください。
012.png
  • ノートブックが開けたら、ノートブックの名前を変更しておきます。Untitled を選択。
13.png
  • 好きな名前を付けて、rename を選択。
14.png
  • これで準備完了です。

MNISTとは

手書き数字画像60,000枚と、テスト画像10,000枚を集めた、画像データセットです。
さらに、手書きの数字「0〜9」に正解ラベルが与えられるデータセットでもあり、画像分類問題で人気の高いデータセットです。

3層のニューラルネットワーク解析

先ほど作成したノートブック上でコードを書いていきます。

モジュールのインポートと設定

  • keras : オープンソースニューラルネットワークライブラリ
  • matplotlib : グラフ描画ライブラリ
from tensorflow import keras
import matplotlib.pyplot as plt
%matplotlib inline

MNISTデータのダウンロード

(x_train,y_train),(x_test,y_test) = keras.datasets.mnist.load_data()

データの確認

  • MNISTの画像は28x28の二次元配列になっています。
  • 各要素は0~255の値を取っていて、色が黒に近いほうが数が大きくなります。
x_train[0]

実行結果

array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,
         18,  18,  18, 126, 136, 175,  26, 166, 255, 247, 127,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  30,  36,  94, 154, 170,
        253, 253, 253, 253, 253, 225, 172, 253, 242, 195,  64,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  49, 238, 253, 253, 253, 253,
        253, 253, 253, 253, 251,  93,  82,  82,  56,  39,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  18, 219, 253, 253, 253, 253,
        253, 198, 182, 247, 241,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  80, 156, 107, 253, 253,
        205,  11,   0,  43, 154,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,  14,   1, 154, 253,
         90,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 139, 253,
        190,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  11, 190,
        253,  70,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  35,
        241, 225, 160, 108,   1,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         81, 240, 253, 253, 119,  25,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,  45, 186, 253, 253, 150,  27,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,  16,  93, 252, 253, 187,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0, 249, 253, 249,  64,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,  46, 130, 183, 253, 253, 207,   2,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  39,
        148, 229, 253, 253, 253, 250, 182,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  24, 114, 221,
        253, 253, 253, 253, 201,  78,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  23,  66, 213, 253, 253,
        253, 253, 198,  81,   2,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,  18, 171, 219, 253, 253, 253, 253,
        195,  80,   9,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,  55, 172, 226, 253, 253, 253, 253, 244, 133,
         11,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0, 136, 253, 253, 253, 212, 135, 132,  16,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0]], dtype=uint8)
  • x_train に60000個の画像データがあることを確認します。
x_train.shape

実行結果

(60000, 28, 28)
  • 試しに画像を10個表示してみます。
for i in range(10):
    plt.subplot(2,5,i+1)
    plt.title("Label :"+ str(y_train[i]))
    plt.imshow(x_train[i].reshape(28,28))

実行結果

15.png

データの加工

  • 0~255の数を訓練向けに0~1の値に変換するために、各要素を255で割ります。
x_train, x_test = x_train/255.0, x_test/255.0

ニューラルネットワークの層を定義

  • 2次元の入力データを1次元に変換する。
  • 512個のノードに全結合する。活性化関数には relu を使用。
  • 20%の入力をドロップアウト(無視する)。
  • 10個のノードに全結合する。活性化関数には softmax を使用。
model = keras.models.Sequential([
    keras.layers.Flatten(),
    keras.layers.Dense(512,activation = "relu"),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10,activation = "softmax")    
])

ニューラルネットワークの訓練課程の設定

  • 最適化手法は Adam を使用。
  • 損失関数は sparse_categorical_crossentropy を使用。
  • metrics に accuracy を設定することで、評価を記録するようにします。
model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=["accracy"])

モデルの作成

  • train データを使用して、モデルの作成を行います。
  • epoch は訓練データを繰り返し学習する回数を指定します。
  • バッチサイズは未設定で32となります。
model.fit(x_train,y_train,epochs = 10)

実行経過

Epoch 1/10
60000/60000 [==============================] - 11s 178us/sample - loss: 0.2173 - acc: 0.9362
Epoch 2/10
60000/60000 [==============================] - 10s 171us/sample - loss: 0.0970 - acc: 0.9706
Epoch 3/10
60000/60000 [==============================] - 10s 163us/sample - loss: 0.0687 - acc: 0.9781
Epoch 4/10
60000/60000 [==============================] - 10s 165us/sample - loss: 0.0530 - acc: 0.9832
Epoch 5/10
60000/60000 [==============================] - 10s 167us/sample - loss: 0.0431 - acc: 0.9858
Epoch 6/10
60000/60000 [==============================] - 11s 178us/sample - loss: 0.0360 - acc: 0.9877
Epoch 7/10
60000/60000 [==============================] - 10s 172us/sample - loss: 0.0317 - acc: 0.9896
Epoch 8/10
60000/60000 [==============================] - 11s 183us/sample - loss: 0.0276 - acc: 0.9909
Epoch 9/10
60000/60000 [==============================] - 11s 178us/sample - loss: 0.0245 - acc: 0.9911
Epoch 10/10
60000/60000 [==============================] - 10s 171us/sample - loss: 0.0222 - acc: 0.9927

精度の確認

  • test データを使用して、モデルの精度を確認します。
model.evaluate(x_test,y_test)

実行結果

0000/10000 [==============================] - 1s 51us/sample - loss: 0.0759 - acc: 0.9810

[0.07591169230262894, 0.981]
  • 精度は98.1%という、非常に高い数値が得られました。

予測値から結果を確認

  • 実際の予測値から予測があっているか確認します。
pred = model.predict(x_test)
pred[0:10]

実行結果

array([[1.29736716e-12, 9.30834090e-11, 7.09911907e-10, 2.14337233e-05,
        1.56587063e-15, 2.23363661e-10, 2.47713938e-18, 9.99978542e-01,
        3.40327849e-10, 2.27434938e-09],
       [3.55773938e-18, 1.30211094e-11, 1.00000000e+00, 1.87722322e-14,
        1.91654797e-27, 4.13704762e-16, 7.62346064e-16, 4.54291151e-24,
        9.36259100e-13, 1.04953153e-22],
       [3.09756431e-12, 9.99991775e-01, 1.07295807e-06, 4.53352356e-10,
        4.08643110e-08, 6.16563938e-08, 4.79950621e-08, 6.89366345e-07,
        6.31294870e-06, 2.44365552e-11],
       [9.99996305e-01, 7.95472019e-14, 1.61123907e-07, 1.72870496e-12,
        2.32529231e-07, 7.89686705e-10, 3.30556077e-06, 2.74338019e-09,
        3.86724398e-13, 2.41019738e-09],
       [2.74608136e-12, 2.27801014e-16, 1.68423036e-11, 2.01783683e-15,
        9.99999642e-01, 1.44000836e-14, 3.15989465e-12, 1.13154517e-08,
        1.28165349e-12, 3.92654016e-07],
       [1.35878001e-14, 9.99999523e-01, 3.68568474e-11, 1.20259558e-11,
        2.99900527e-09, 1.75071763e-11, 2.74890570e-12, 3.23565473e-07,
        6.75050700e-08, 1.53880847e-12],
       [1.30031841e-18, 1.30200993e-12, 3.26199517e-13, 1.25722195e-14,
        9.99999285e-01, 5.16328161e-11, 1.78035551e-13, 9.06056730e-10,
        6.78496804e-07, 2.09780477e-08],
       [2.88904469e-12, 5.34482569e-10, 4.60125329e-06, 2.10134713e-05,
        3.55165030e-05, 3.45541840e-09, 4.37688724e-14, 2.32960105e-08,
        1.15672283e-09, 9.99938846e-01],
       [4.26845313e-17, 7.12464376e-13, 1.61978278e-05, 6.37089513e-12,
        7.62532437e-09, 9.96215880e-01, 3.75244161e-03, 6.48174006e-14,
        1.55818743e-05, 3.98022865e-10],
       [8.06284459e-18, 1.69497513e-13, 6.81280322e-14, 2.45144662e-07,
        1.10701889e-01, 2.98262873e-11, 3.33688947e-15, 6.34674914e-03,
        4.75299629e-07, 8.82950664e-01]], dtype=float32)
y_test[0:10]

実行結果

array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9], dtype=uint8)
  • x_test[0] の予測値に注目すると、およそ99%の確率で7であると予測されています。
  • 正解ラベルを確認すると、y_test[0] が7であることから、予測ができていることがわかります。

おわりに

今回は MNIST データを使用して、3層ニューラルネットワークで解析を行いました。
次の その2 では MNISTより解析の難しい fashion MNIST データの解析を3層ニューラルネットワークとCNN(畳み込みニューラルネットワーク)で行います。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?