はじめに
前回に引き続き、ディープラーニング基礎講座の課題を少しアレンジしたものをTensorFlow 2.0で書き直します。
今回の課題
- Fashion-MNISTデータセットを畳み込みニューラルネットワーク(CNN)で学習する
- 制約:全体の実行時間は60分以内
- 元々の課題はMNISTが対象だったが、MNIST以外のデータも扱ってみたかったので、今回はFashion-MNISTにしてみた
- 元々の課題は精度も求められていたが、今回はTensorFlowの勉強に重きを置き、精度向上にはあまり取り組まない
参考
環境
- Google Colaboratory
- TensorFlow 2.0 Alpha
コード
こちらです。
コード解説
前回詳しめに書きましたので、今回はポイントだけ紹介します。
データセット取得
import tensorflow_datasets as tfds
import tensorflow as tf
#print(tfds.list_builders())
dataset, info = tfds.load('fashion_mnist', as_supervised = True, with_info = True)
dataset_test, dataset_train = dataset['test'], dataset['train']
#print(info)
-
tfds.list_builders()
は取り扱っているデータの一覧を返してくれます。'fashion-mnist'
が正しいか'fashion_mnist'
が正しいか悩んだときに呼び出すとよさそうです -
tfds.load
の引数にwith_info = True
を付けると、このデータセットに関する色々な情報(以下の通り)を取得できます。例えばshape、型、データ数などです。特にデータ数はこれ以外から取得するのが難しいように思いますので、初めて扱うデータの場合はこのオプションを付けておいた方がよさそうです
tfds.core.DatasetInfo(
name='fashion_mnist',
version=1.0.0,
description='Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.',
urls=['https://github.com/zalandoresearch/fashion-mnist'],
features=FeaturesDict({
'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10)
},
total_num_examples=70000,
splits={
'test': <tfds.core.SplitInfo num_examples=10000>,
'train': <tfds.core.SplitInfo num_examples=60000>
},
supervised_keys=('image', 'label'),
citation='"""
@article{DBLP:journals/corr/abs-1708-07747,
author = {Han Xiao and
Kashif Rasul and
Roland Vollgraf},
title = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
Algorithms},
journal = {CoRR},
volume = {abs/1708.07747},
year = {2017},
url = {http://arxiv.org/abs/1708.07747},
archivePrefix = {arXiv},
eprint = {1708.07747},
timestamp = {Mon, 13 Aug 2018 16:47:27 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1708-07747},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""',
redistribution_info=,
)
モデル定義
from tensorflow.keras.layers import Conv2D, Flatten, Dense, BatchNormalization, Dropout, Activation, MaxPool2D, GlobalAveragePooling2D
from tensorflow.keras import Model
class CNNModel(Model):
def __init__(self):
super(CNNModel, self).__init__()
drop_rate = 0.5
self._layers = ([
Conv2D(32, 3), # 28, 28, 1 -> 26, 26, 32
BatchNormalization(),
Activation(tf.nn.relu),
Conv2D(64, 3), # 26, 26, 32 -> 24, 24, 64
BatchNormalization(),
Activation(tf.nn.relu),
MaxPool2D(), # 24, 24, 64 -> 12, 12, 64
Conv2D(128, 3), # 12, 12, 64 -> 10, 10, 128
BatchNormalization(),
Activation(tf.nn.relu),
Conv2D(256, 3), # 10, 10, 128 -> 8, 8, 256
BatchNormalization(),
Activation(tf.nn.relu),
MaxPool2D(), # 8, 8, 256 -> 4, 4, 256
Flatten(), # 4, 4, 256 -> 4096
Dense(256), # 4096 -> 256
BatchNormalization(),
Dropout(drop_rate),
Activation(tf.nn.relu),
Dense(10, activation = 'softmax') # 256 -> 10
])
def call(self, x):
for layer in self._layers:
x = layer(x)
return x
model = CNNModel()
今回は4層のCNNにしてみました。self.layers
というプロパティに層を定義しようとしたのですが、読み取り専用となっており定義できませんでした。そのためself._layers
に定義しています(self._layers
も元々存在していたのですが、空だったし読み取り専用にもなっていなかったのでそのまま使っちゃいました)。
学習結果
Epoch 1, Loss: 0.4085262715816498, Accuracy: 84.93833923339844, Test Loss: 0.30398645997047424, Test Accuracy: 88.70000457763672, spent_time: 1.1197947025299073 min
Epoch 2, Loss: 0.3328215181827545, Accuracy: 87.7733383178711, Test Loss: 0.2793513834476471, Test Accuracy: 89.74000549316406, spent_time: 2.169168742497762 min
Epoch 3, Loss: 0.29196467995643616, Accuracy: 89.22777557373047, Test Loss: 0.2699553370475769, Test Accuracy: 90.25333404541016, spent_time: 3.2162970105806985 min
Epoch 4, Loss: 0.2630894184112549, Accuracy: 90.2874984741211, Test Loss: 0.2627233862876892, Test Accuracy: 90.58999633789062, spent_time: 4.259984481334686 min
Epoch 5, Loss: 0.2399316430091858, Accuracy: 91.12166595458984, Test Loss: 0.2603335380554199, Test Accuracy: 90.80599975585938, spent_time: 5.3003990888595585 min
Epoch 6, Loss: 0.22022312879562378, Accuracy: 91.83583068847656, Test Loss: 0.26244479417800903, Test Accuracy: 90.99666595458984, spent_time: 6.341218503316243 min
Epoch 7, Loss: 0.20297113060951233, Accuracy: 92.46737670898438, Test Loss: 0.26482954621315, Test Accuracy: 91.08856964111328, spent_time: 7.389575409889221 min
Epoch 8, Loss: 0.18778178095817566, Accuracy: 93.0304183959961, Test Loss: 0.2738790512084961, Test Accuracy: 91.20874786376953, spent_time: 8.4386532386144 min
Epoch 9, Loss: 0.17422623932361603, Accuracy: 93.52851867675781, Test Loss: 0.28297314047813416, Test Accuracy: 91.28999328613281, spent_time: 9.47910236120224 min
Epoch 10, Loss: 0.1628570407629013, Accuracy: 93.94750213623047, Test Loss: 0.2914654016494751, Test Accuracy: 91.32699584960938, spent_time: 10.521370196342469 min
こちらを見てみると、3 Conv+BN+poolingでTest Accuracyが0.921となっているので、そこそこよさそうな結果です。もう少し学習させると結果もよくなりそうです(既に過学習気味ですが)。
次の取り組み
- 画像の水増し
- ResNet
あたりに取り組んでみようと思います。