LoginSignup
3
9

More than 5 years have passed since last update.

MNISTのデータ構造をNumPyを使って理解する

Last updated at Posted at 2019-02-02

MINSTのデータセットの中身を知る

modelにどんなデータを渡しているのか理解するためにデータの中をみてます

[目的]
チュートリアルやサンプルなどで書かれている処理の内容を理解するにはデータの構造がどうなっているのか知る必要ある。
そのために、numpyの基本的な配列操作の扱い方を覚えること、そしてデータがどのようになっているかを知ることでデータの渡し方を理解できるようになりたい。

from keras.datasets import mnist
# データ取得
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# train_imagesの配列を見る
print("配列の軸数:" + str(train_images.ndim))
>>> 配列の軸数:3

# テンソルの形状
print("テンソルの形状:" + str(train_images.shape))
>>> テンソルの形状:(60000, 28, 28)

# データタイプ
print("テンソルのデータ型:" + str(train_images.dtype))
>>> テンソルのデータ型:uint8
# uint8 は 8bit 整数型

次に訓練データの画像をプロットしてみる

# 28px * 28pxの画像を表示してみる 
digit = train_images[4]

# 28 * 28の配列のデータをプロット
import matplotlib.pyplot as plt
plt.imshow(digit, cmap=plt.cm.binary)
plt.show()

68747470733a2f2f71696974612d696d6167652d73746f72652e73332e616d617a6f6e6177732e636f6d2f302f3230323539312f62303232636237612d653432332d333963342d653931302d3462613166316364633463662e706e67.png

配列操作でよく使われるやつ

# NumPyでのテンソルの操作
# train_images 60000の 10 - 1000までを取り出す※2次,3次軸は暗黙的に全て取得
my_slice = train_images[10:100]
my_slice.shape
>>> (90, 28, 28)

# train_images 60000の 10 - 1000までを取り出す※2次,3次軸は明示的に全て取得
my_slice = train_images[10:100, :, :]
my_slice.shape
>>> (90, 28, 28)
# 結果は同じ

# train_images 60000の 10 - 1000までを取り出す※2次,3次軸は0-28(全て)までを取得
my_slice = train_images[10:100, 0:28, 0:28]
my_slice.shape
>>> (90, 28, 28)
# これも結果は同じ

# 1次元目全て、 2次元、3次元目は14-全て
my_slice = train_images[:, 14:, 14:]
my_slice.shape
>>> (60000, 14, 14)

# 1次元目全て、 2次元、3次元目は7 から(28-7)までを取得
my_slice = train_images[:, 7:-7, 7:-7]
my_slice.shape
>>> (60000, 14, 14)

振り返り

MNISTで使われているデータの形が理解できた
1. MNISTでは学習データを3次元テンソルから2次元テンソルへ変換している
(60000, 28, 28) → (60000, (28 * 28))
※MNISTで利用するmodelはベクトルデータを扱うモデルを利用している
※現時点でKeras、Tensorflowのライブラリの中身を見てないので詳しいことはわかっていません

  1. MNISTのラベルデータ(分類データ)は keras.utils.to_categorical()を使ってデータを成形している ※MNISTで利用するmodelはラベルデータをone-hotで渡す仕様らしい ※ほかのデータを受け取るmodelもあるらしい
from keras.utils import to_categorical
train_labels = to_categorical(train_labels, 10)
test_labels = to_categorical(test_labels, 10)
# one-hotエンコーディングしてくれる関数

まとめ

MNISTの機械学習の流れの理解に役立った。
レイヤー層の扱いもいろいろルール(お作法)が存在するらしいのでその辺は改めて学びたい。

3
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
9