(目次はこちら)
はじめに
むかしむかしの記事で、Deep MNIST for Expertsに相当するCNNについて書いた。
数日前に、Fashion-MNISTというデータセットが公開されたので、同じモデルを試してみようかと。
Fashion-MNIST
A dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. Fashion-MNIST is intended to serve as a direct drop-in replacement of the original MNIST dataset for benchmarking machine learning algorithms.
要するに、MNISTとの違いは、内容が数字ではなく、ファッションアイテムとのこと。画像サイズも、画像数も同じ。
クラスは、
| Label | Description |
|---|---|
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
で、これを作った動機が、MNIST is too easy だからとのこと。
ベンチマークが公開されていて、MNIST / Fashion-MNISTともに、SVC / {"C":10,"kernel":"poly"} がもっとも高い精度で、
| Dataset | Accuracy |
|---|---|
| Fashion-MNIST | 0.897 |
| MNIST | 0.978 |
| らしい。 | |
| これがいささか疑問で、なぜ、SVM (SVC) ? | |
| CNNは?と。 | |
| まぁ、そのうち追加されるとは思いますが・・ |
コード
すごく昔に書いたものの流用なので、tensorflow 1.xでは動かないので、v0.12.1を使いました・・・
mnist_cnn_ml.py
をちょっとだけ改変して利用。
diff --git a/src/tensorflow/mnist_cnn_ml.py b/src/tensorflow/mnist_cnn_ml.py
index 6a7b678..c662bce 100644
--- a/src/tensorflow/mnist_cnn_ml.py
+++ b/src/tensorflow/mnist_cnn_ml.py
@@ -10,12 +10,12 @@ FILTER_NUM = 32
FILTER_NUM2 = 64
FEATURE_DIM = 1024
KEEP_PROB = 0.5
-TRAINING_LOOP = 20000
-BATCH_SIZE = 50
+TRAINING_LOOP = 100000
+BATCH_SIZE = 128
SUMMARY_DIR = 'log_cnn_ml'
SUMMARY_INTERVAL = 100
-mnist = input_data.read_data_sets('data', one_hot=True)
+mnist = input_data.read_data_sets('data/fashion', one_hot=True)
with tf.Graph().as_default():
with tf.name_scope('input'):
@@ -58,7 +58,7 @@ with tf.Graph().as_default():
with tf.name_scope('optimize'):
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
- train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy)
+ train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
with tf.Session() as sess:
train_writer = tf.train.SummaryWriter(SUMMARY_DIR + '/train', sess.graph)
結果
まだ、精度は上がると思いますが、このモデルで精度を追う意味はあまりない(そして眠い)ので途中で止めました。
SVMの0.897を明らかに超えて、0.929という結果になりました。MNISTの場合、0.993だったので、Fashion-MNISTの方が確かに識別問題としては難しいと言えます。

おわりに
MNISTのFashionデータ版である、Fashion-MNISTに対して、Deep MNIST for Experts相当のCNNを適用して、ベンチマークにあるSVMの0.897を明らかに超えることがわかった。(やる前から想像できていたことですが・・。) そして、MNISTに比べて、Fashion-MNISTのほうが識別問題としては難易度が上がっていることも確認できた。
==追記==
はやくもいくつか結果がアップデートされてますね
WRN40-4で、0.967だそうです。
