LoginSignup
2
1

More than 5 years have passed since last update.

TensorFlow2.0を使ってFashion-MNISTを畳み込みニューラルネットワーク(CNN)で学習する

Last updated at Posted at 2019-04-07

はじめに

前回に引き続き、ディープラーニング基礎講座の課題を少しアレンジしたものをTensorFlow 2.0で書き直します。

今回の課題

  • Fashion-MNISTデータセットを畳み込みニューラルネットワーク(CNN)で学習する
    • 制約:全体の実行時間は60分以内
    • 元々の課題はMNISTが対象だったが、MNIST以外のデータも扱ってみたかったので、今回はFashion-MNISTにしてみた
    • 元々の課題は精度も求められていたが、今回はTensorFlowの勉強に重きを置き、精度向上にはあまり取り組まない

参考

環境

  • Google Colaboratory
  • TensorFlow 2.0 Alpha

コード

こちらです。

コード解説

前回詳しめに書きましたので、今回はポイントだけ紹介します。

データセット取得

import tensorflow_datasets as tfds
import tensorflow as tf

#print(tfds.list_builders())
dataset, info = tfds.load('fashion_mnist', as_supervised = True, with_info = True)
dataset_test, dataset_train = dataset['test'], dataset['train']
#print(info)
  • tfds.list_builders()は取り扱っているデータの一覧を返してくれます。'fashion-mnist'が正しいか'fashion_mnist'が正しいか悩んだときに呼び出すとよさそうです
  • tfds.loadの引数にwith_info = Trueを付けると、このデータセットに関する色々な情報(以下の通り)を取得できます。例えばshape、型、データ数などです。特にデータ数はこれ以外から取得するのが難しいように思いますので、初めて扱うデータの場合はこのオプションを付けておいた方がよさそうです
tfds.core.DatasetInfo(
    name='fashion_mnist',
    version=1.0.0,
    description='Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.',
    urls=['https://github.com/zalandoresearch/fashion-mnist'],
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10)
    },
    total_num_examples=70000,
    splits={
        'test': <tfds.core.SplitInfo num_examples=10000>,
        'train': <tfds.core.SplitInfo num_examples=60000>
    },
    supervised_keys=('image', 'label'),
    citation='"""
        @article{DBLP:journals/corr/abs-1708-07747,
          author    = {Han Xiao and
                       Kashif Rasul and
                       Roland Vollgraf},
          title     = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
                       Algorithms},
          journal   = {CoRR},
          volume    = {abs/1708.07747},
          year      = {2017},
          url       = {http://arxiv.org/abs/1708.07747},
          archivePrefix = {arXiv},
          eprint    = {1708.07747},
          timestamp = {Mon, 13 Aug 2018 16:47:27 +0200},
          biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1708-07747},
          bibsource = {dblp computer science bibliography, https://dblp.org}
        }

    """',
    redistribution_info=,
)

モデル定義

from tensorflow.keras.layers import Conv2D, Flatten, Dense, BatchNormalization, Dropout, Activation, MaxPool2D, GlobalAveragePooling2D
from tensorflow.keras import Model

class CNNModel(Model):
    def __init__(self):
        super(CNNModel, self).__init__()
        drop_rate = 0.5

        self._layers = ([
            Conv2D(32, 3), # 28, 28, 1 -> 26, 26, 32
            BatchNormalization(),
            Activation(tf.nn.relu),
            Conv2D(64, 3), # 26, 26, 32 -> 24, 24, 64
            BatchNormalization(),
            Activation(tf.nn.relu),
            MaxPool2D(), # 24, 24, 64 -> 12, 12, 64
            Conv2D(128, 3), # 12, 12, 64 -> 10, 10, 128
            BatchNormalization(),
            Activation(tf.nn.relu),
            Conv2D(256, 3), # 10, 10, 128 -> 8, 8, 256
            BatchNormalization(),
            Activation(tf.nn.relu),
            MaxPool2D(), # 8, 8, 256 -> 4, 4, 256
            Flatten(), # 4, 4, 256 -> 4096
            Dense(256), # 4096 -> 256
            BatchNormalization(),
            Dropout(drop_rate),
            Activation(tf.nn.relu),
            Dense(10, activation = 'softmax') # 256 -> 10                        
        ])                

    def call(self, x):
        for layer in self._layers:
            x = layer(x)
        return x


model = CNNModel()

今回は4層のCNNにしてみました。self.layersというプロパティに層を定義しようとしたのですが、読み取り専用となっており定義できませんでした。そのためself._layersに定義しています(self._layersも元々存在していたのですが、空だったし読み取り専用にもなっていなかったのでそのまま使っちゃいました)。

学習結果

Epoch 1, Loss: 0.4085262715816498, Accuracy: 84.93833923339844, Test Loss: 0.30398645997047424, Test Accuracy: 88.70000457763672, spent_time: 1.1197947025299073 min
Epoch 2, Loss: 0.3328215181827545, Accuracy: 87.7733383178711, Test Loss: 0.2793513834476471, Test Accuracy: 89.74000549316406, spent_time: 2.169168742497762 min
Epoch 3, Loss: 0.29196467995643616, Accuracy: 89.22777557373047, Test Loss: 0.2699553370475769, Test Accuracy: 90.25333404541016, spent_time: 3.2162970105806985 min
Epoch 4, Loss: 0.2630894184112549, Accuracy: 90.2874984741211, Test Loss: 0.2627233862876892, Test Accuracy: 90.58999633789062, spent_time: 4.259984481334686 min
Epoch 5, Loss: 0.2399316430091858, Accuracy: 91.12166595458984, Test Loss: 0.2603335380554199, Test Accuracy: 90.80599975585938, spent_time: 5.3003990888595585 min
Epoch 6, Loss: 0.22022312879562378, Accuracy: 91.83583068847656, Test Loss: 0.26244479417800903, Test Accuracy: 90.99666595458984, spent_time: 6.341218503316243 min
Epoch 7, Loss: 0.20297113060951233, Accuracy: 92.46737670898438, Test Loss: 0.26482954621315, Test Accuracy: 91.08856964111328, spent_time: 7.389575409889221 min
Epoch 8, Loss: 0.18778178095817566, Accuracy: 93.0304183959961, Test Loss: 0.2738790512084961, Test Accuracy: 91.20874786376953, spent_time: 8.4386532386144 min
Epoch 9, Loss: 0.17422623932361603, Accuracy: 93.52851867675781, Test Loss: 0.28297314047813416, Test Accuracy: 91.28999328613281, spent_time: 9.47910236120224 min
Epoch 10, Loss: 0.1628570407629013, Accuracy: 93.94750213623047, Test Loss: 0.2914654016494751, Test Accuracy: 91.32699584960938, spent_time: 10.521370196342469 min

こちらを見てみると、3 Conv+BN+poolingでTest Accuracyが0.921となっているので、そこそこよさそうな結果です。もう少し学習させると結果もよくなりそうです(既に過学習気味ですが)。

次の取り組み

  • 画像の水増し
  • ResNet

あたりに取り組んでみようと思います。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1