はじめに
本稿では Udemy の 【4日で体験しよう!】 TensorFlow, Keras, Python 3 で学ぶディープラーニング体験講座 の内容を参考に記事を作成しました。
記事を2つに分けて投稿します。
その1 : SageMaker での環境構築と keras による MNIST の3層ニューラルネットワーク解析 ←(本稿)
その2 : keras による fashion MNIST の3層ニューラルネットワークとCNN(畳み込みニューラルネットワーク)解析
※ Udemy では Anaconda を使用してますが、自分は Amazon SageMaker を使用しました。
※環境構築についても Amazon SageMaker を使用した手順を記載します。
前提
- AWS のアカウントを用意します。
※AWSアカウント作成の流れはこちら
環境構築と準備
Amazon SageMaker でノートブックインスタンスを作成し、Jupyter を開きます。
- 右上のノートブックインスタンスの作成を選択。
- ノートブックインスタンス名を入力(インスタンス名は任意)
- 下にスライドして、上の赤枠の部分を選択して、新しいロールの作成を選択。
- 任意の S3 バケットを選択し、ロールの作成を選択します。
※今回は任意の S3 バケットを選択していますが、実際に運用する際は適切な IAM ロールを設定して下さい。
- 赤枠のように表示されたら IAM ロールの作成は成功です。
- 下にスライドして、ノートブックインスタンスの作成を選択。
- 画像上部の赤枠のように表示されたらノートブックインスタンスの作成は完了です。
ステータスが「 Pending 」 → 「 InService 」 になるまで待ちます。
- ステータスが InService になったら、開く Jupyter を選択して Jupyter を開きます。
- Jupyter が開けたら、この後ダウンロードしたデータを保存するフォルダを作成します。
Newを選択。
- 一覧から conda_tensorflow_p36 を選択して、ノートブックを作成してください。
- ノートブックが開けたら、ノートブックの名前を変更しておきます。Untitled を選択。
- 好きな名前を付けて、rename を選択。
- これで準備完了です。
MNISTとは
手書き数字画像60,000枚と、テスト画像10,000枚を集めた、画像データセットです。
さらに、手書きの数字「0〜9」に正解ラベルが与えられるデータセットでもあり、画像分類問題で人気の高いデータセットです。
3層のニューラルネットワーク解析
先ほど作成したノートブック上でコードを書いていきます。
モジュールのインポートと設定
- keras : オープンソースニューラルネットワークライブラリ
- matplotlib : グラフ描画ライブラリ
from tensorflow import keras
import matplotlib.pyplot as plt
%matplotlib inline
MNISTデータのダウンロード
(x_train,y_train),(x_test,y_test) = keras.datasets.mnist.load_data()
データの確認
- MNISTの画像は28x28の二次元配列になっています。
- 各要素は0~255の値を取っていて、色が黒に近いほうが数が大きくなります。
x_train[0]
実行結果
array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,
18, 18, 18, 126, 136, 175, 26, 166, 255, 247, 127, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170,
253, 253, 253, 253, 253, 225, 172, 253, 242, 195, 64, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253,
253, 253, 253, 253, 251, 93, 82, 82, 56, 39, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253,
253, 198, 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253,
205, 11, 0, 43, 154, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253,
90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253,
190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190,
253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35,
241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39,
148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221,
253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253,
253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253,
195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133,
11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]], dtype=uint8)
- x_train に60000個の画像データがあることを確認します。
x_train.shape
実行結果
(60000, 28, 28)
- 試しに画像を10個表示してみます。
for i in range(10):
plt.subplot(2,5,i+1)
plt.title("Label :"+ str(y_train[i]))
plt.imshow(x_train[i].reshape(28,28))
実行結果
データの加工
- 0~255の数を訓練向けに0~1の値に変換するために、各要素を255で割ります。
x_train, x_test = x_train/255.0, x_test/255.0
ニューラルネットワークの層を定義
- 2次元の入力データを1次元に変換する。
- 512個のノードに全結合する。活性化関数には relu を使用。
- 20%の入力をドロップアウト(無視する)。
- 10個のノードに全結合する。活性化関数には softmax を使用。
model = keras.models.Sequential([
keras.layers.Flatten(),
keras.layers.Dense(512,activation = "relu"),
keras.layers.Dropout(0.2),
keras.layers.Dense(10,activation = "softmax")
])
ニューラルネットワークの訓練課程の設定
- 最適化手法は Adam を使用。
- 損失関数は sparse_categorical_crossentropy を使用。
- metrics に accuracy を設定することで、評価を記録するようにします。
model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=["accracy"])
モデルの作成
- train データを使用して、モデルの作成を行います。
- epoch は訓練データを繰り返し学習する回数を指定します。
- バッチサイズは未設定で32となります。
model.fit(x_train,y_train,epochs = 10)
実行経過
Epoch 1/10
60000/60000 [==============================] - 11s 178us/sample - loss: 0.2173 - acc: 0.9362
Epoch 2/10
60000/60000 [==============================] - 10s 171us/sample - loss: 0.0970 - acc: 0.9706
Epoch 3/10
60000/60000 [==============================] - 10s 163us/sample - loss: 0.0687 - acc: 0.9781
Epoch 4/10
60000/60000 [==============================] - 10s 165us/sample - loss: 0.0530 - acc: 0.9832
Epoch 5/10
60000/60000 [==============================] - 10s 167us/sample - loss: 0.0431 - acc: 0.9858
Epoch 6/10
60000/60000 [==============================] - 11s 178us/sample - loss: 0.0360 - acc: 0.9877
Epoch 7/10
60000/60000 [==============================] - 10s 172us/sample - loss: 0.0317 - acc: 0.9896
Epoch 8/10
60000/60000 [==============================] - 11s 183us/sample - loss: 0.0276 - acc: 0.9909
Epoch 9/10
60000/60000 [==============================] - 11s 178us/sample - loss: 0.0245 - acc: 0.9911
Epoch 10/10
60000/60000 [==============================] - 10s 171us/sample - loss: 0.0222 - acc: 0.9927
精度の確認
- test データを使用して、モデルの精度を確認します。
model.evaluate(x_test,y_test)
実行結果
0000/10000 [==============================] - 1s 51us/sample - loss: 0.0759 - acc: 0.9810
[0.07591169230262894, 0.981]
- 精度は98.1%という、非常に高い数値が得られました。
予測値から結果を確認
- 実際の予測値から予測があっているか確認します。
pred = model.predict(x_test)
pred[0:10]
実行結果
array([[1.29736716e-12, 9.30834090e-11, 7.09911907e-10, 2.14337233e-05,
1.56587063e-15, 2.23363661e-10, 2.47713938e-18, 9.99978542e-01,
3.40327849e-10, 2.27434938e-09],
[3.55773938e-18, 1.30211094e-11, 1.00000000e+00, 1.87722322e-14,
1.91654797e-27, 4.13704762e-16, 7.62346064e-16, 4.54291151e-24,
9.36259100e-13, 1.04953153e-22],
[3.09756431e-12, 9.99991775e-01, 1.07295807e-06, 4.53352356e-10,
4.08643110e-08, 6.16563938e-08, 4.79950621e-08, 6.89366345e-07,
6.31294870e-06, 2.44365552e-11],
[9.99996305e-01, 7.95472019e-14, 1.61123907e-07, 1.72870496e-12,
2.32529231e-07, 7.89686705e-10, 3.30556077e-06, 2.74338019e-09,
3.86724398e-13, 2.41019738e-09],
[2.74608136e-12, 2.27801014e-16, 1.68423036e-11, 2.01783683e-15,
9.99999642e-01, 1.44000836e-14, 3.15989465e-12, 1.13154517e-08,
1.28165349e-12, 3.92654016e-07],
[1.35878001e-14, 9.99999523e-01, 3.68568474e-11, 1.20259558e-11,
2.99900527e-09, 1.75071763e-11, 2.74890570e-12, 3.23565473e-07,
6.75050700e-08, 1.53880847e-12],
[1.30031841e-18, 1.30200993e-12, 3.26199517e-13, 1.25722195e-14,
9.99999285e-01, 5.16328161e-11, 1.78035551e-13, 9.06056730e-10,
6.78496804e-07, 2.09780477e-08],
[2.88904469e-12, 5.34482569e-10, 4.60125329e-06, 2.10134713e-05,
3.55165030e-05, 3.45541840e-09, 4.37688724e-14, 2.32960105e-08,
1.15672283e-09, 9.99938846e-01],
[4.26845313e-17, 7.12464376e-13, 1.61978278e-05, 6.37089513e-12,
7.62532437e-09, 9.96215880e-01, 3.75244161e-03, 6.48174006e-14,
1.55818743e-05, 3.98022865e-10],
[8.06284459e-18, 1.69497513e-13, 6.81280322e-14, 2.45144662e-07,
1.10701889e-01, 2.98262873e-11, 3.33688947e-15, 6.34674914e-03,
4.75299629e-07, 8.82950664e-01]], dtype=float32)
y_test[0:10]
実行結果
array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9], dtype=uint8)
- x_test[0] の予測値に注目すると、およそ99%の確率で7であると予測されています。
- 正解ラベルを確認すると、y_test[0] が7であることから、予測ができていることがわかります。
おわりに
今回は MNIST データを使用して、3層ニューラルネットワークで解析を行いました。
次の その2 では MNISTより解析の難しい fashion MNIST データの解析を3層ニューラルネットワークとCNN(畳み込みニューラルネットワーク)で行います。