Edited at

TensorFlow2.0を使ってFashion-MNISTを畳み込みニューラルネットワーク(CNN)で学習する


はじめに

前回に引き続き、ディープラーニング基礎講座の課題を少しアレンジしたものをTensorFlow 2.0で書き直します。


今回の課題


  • Fashion-MNISTデータセットを畳み込みニューラルネットワーク(CNN)で学習する


    • 制約:全体の実行時間は60分以内

    • 元々の課題はMNISTが対象だったが、MNIST以外のデータも扱ってみたかったので、今回はFashion-MNISTにしてみた

    • 元々の課題は精度も求められていたが、今回はTensorFlowの勉強に重きを置き、精度向上にはあまり取り組まない




参考


環境


  • Google Colaboratory

  • TensorFlow 2.0 Alpha


コード

こちらです。


コード解説

前回詳しめに書きましたので、今回はポイントだけ紹介します。


データセット取得

import tensorflow_datasets as tfds

import tensorflow as tf

#print(tfds.list_builders())
dataset, info = tfds.load('fashion_mnist', as_supervised = True, with_info = True)
dataset_test, dataset_train = dataset['test'], dataset['train']
#print(info)



  • tfds.list_builders()は取り扱っているデータの一覧を返してくれます。'fashion-mnist'が正しいか'fashion_mnist'が正しいか悩んだときに呼び出すとよさそうです


  • tfds.loadの引数にwith_info = Trueを付けると、このデータセットに関する色々な情報(以下の通り)を取得できます。例えばshape、型、データ数などです。特にデータ数はこれ以外から取得するのが難しいように思いますので、初めて扱うデータの場合はこのオプションを付けておいた方がよさそうです

tfds.core.DatasetInfo(

name='fashion_mnist',
version=1.0.0,
description='Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.',
urls=['https://github.com/zalandoresearch/fashion-mnist'],
features=FeaturesDict({
'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10)
},
total_num_examples=70000,
splits={
'test': <tfds.core.SplitInfo num_examples=10000>,
'train': <tfds.core.SplitInfo num_examples=60000>
},
supervised_keys=('image', 'label'),
citation='"""
@article{DBLP:journals/corr/abs-1708-07747,
author = {Han Xiao and
Kashif Rasul and
Roland Vollgraf},
title = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
Algorithms},
journal = {CoRR},
volume = {abs/1708.07747},
year = {2017},
url = {http://arxiv.org/abs/1708.07747},
archivePrefix = {arXiv},
eprint = {1708.07747},
timestamp = {Mon, 13 Aug 2018 16:47:27 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1708-07747},
bibsource = {dblp computer science bibliography, https://dblp.org}
}

"""',
redistribution_info=,
)


モデル定義

from tensorflow.keras.layers import Conv2D, Flatten, Dense, BatchNormalization, Dropout, Activation, MaxPool2D, GlobalAveragePooling2D

from tensorflow.keras import Model

class CNNModel(Model):
def __init__(self):
super(CNNModel, self).__init__()
drop_rate = 0.5

self._layers = ([
Conv2D(32, 3), # 28, 28, 1 -> 26, 26, 32
BatchNormalization(),
Activation(tf.nn.relu),
Conv2D(64, 3), # 26, 26, 32 -> 24, 24, 64
BatchNormalization(),
Activation(tf.nn.relu),
MaxPool2D(), # 24, 24, 64 -> 12, 12, 64
Conv2D(128, 3), # 12, 12, 64 -> 10, 10, 128
BatchNormalization(),
Activation(tf.nn.relu),
Conv2D(256, 3), # 10, 10, 128 -> 8, 8, 256
BatchNormalization(),
Activation(tf.nn.relu),
MaxPool2D(), # 8, 8, 256 -> 4, 4, 256
Flatten(), # 4, 4, 256 -> 4096
Dense(256), # 4096 -> 256
BatchNormalization(),
Dropout(drop_rate),
Activation(tf.nn.relu),
Dense(10, activation = 'softmax') # 256 -> 10
])

def call(self, x):
for layer in self._layers:
x = layer(x)
return x

model = CNNModel()

今回は4層のCNNにしてみました。self.layersというプロパティに層を定義しようとしたのですが、読み取り専用となっており定義できませんでした。そのためself._layersに定義しています(self._layersも元々存在していたのですが、空だったし読み取り専用にもなっていなかったのでそのまま使っちゃいました)。


学習結果

Epoch 1, Loss: 0.4085262715816498, Accuracy: 84.93833923339844, Test Loss: 0.30398645997047424, Test Accuracy: 88.70000457763672, spent_time: 1.1197947025299073 min

Epoch 2, Loss: 0.3328215181827545, Accuracy: 87.7733383178711, Test Loss: 0.2793513834476471, Test Accuracy: 89.74000549316406, spent_time: 2.169168742497762 min
Epoch 3, Loss: 0.29196467995643616, Accuracy: 89.22777557373047, Test Loss: 0.2699553370475769, Test Accuracy: 90.25333404541016, spent_time: 3.2162970105806985 min
Epoch 4, Loss: 0.2630894184112549, Accuracy: 90.2874984741211, Test Loss: 0.2627233862876892, Test Accuracy: 90.58999633789062, spent_time: 4.259984481334686 min
Epoch 5, Loss: 0.2399316430091858, Accuracy: 91.12166595458984, Test Loss: 0.2603335380554199, Test Accuracy: 90.80599975585938, spent_time: 5.3003990888595585 min
Epoch 6, Loss: 0.22022312879562378, Accuracy: 91.83583068847656, Test Loss: 0.26244479417800903, Test Accuracy: 90.99666595458984, spent_time: 6.341218503316243 min
Epoch 7, Loss: 0.20297113060951233, Accuracy: 92.46737670898438, Test Loss: 0.26482954621315, Test Accuracy: 91.08856964111328, spent_time: 7.389575409889221 min
Epoch 8, Loss: 0.18778178095817566, Accuracy: 93.0304183959961, Test Loss: 0.2738790512084961, Test Accuracy: 91.20874786376953, spent_time: 8.4386532386144 min
Epoch 9, Loss: 0.17422623932361603, Accuracy: 93.52851867675781, Test Loss: 0.28297314047813416, Test Accuracy: 91.28999328613281, spent_time: 9.47910236120224 min
Epoch 10, Loss: 0.1628570407629013, Accuracy: 93.94750213623047, Test Loss: 0.2914654016494751, Test Accuracy: 91.32699584960938, spent_time: 10.521370196342469 min

こちらを見てみると、3 Conv+BN+poolingでTest Accuracyが0.921となっているので、そこそこよさそうな結果です。もう少し学習させると結果もよくなりそうです(既に過学習気味ですが)。


次の取り組み


  • 画像の水増し

  • ResNet


あたりに取り組んでみようと思います。