0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlow2.0を使ってFashion-MNISTを畳み込みニューラルネットワーク(CNN)で学習する(Data Augmentationの効果確認)

Posted at

はじめに

前回、Data Augmentationのやり方を理解したため、前々回のモデルに対してData Augmentationを行うことでどうなるかを見ていきたいと思います。

環境

  • Google Colaboratory
  • TensorFlow 2.0 Alpha

コード

こちらです。
なぜかGitHub上ではうまく開けませんでした。GitHubのURLはこちらです。

コード解説

Data Augmentationの設定

from tensorflow.keras.preprocessing.image import ImageDataGenerator

#datagen = ImageDataGenerator()
datagen = ImageDataGenerator(rotation_range = 10, horizontal_flip = True, zoom_range = 0.1)

今回は回転、左右反転、ズームを行うことにしました。コメント行はData Augmentationを行わない場合と精度比較をするためのもので、どちらかをコメントアウトし実行するようにしました。

学習

import time

num_epoch = 80
start_time = time.time()

#train_accuracies = []
#test_accuracies = []
train_accuracies_with_da = []
test_accuracies_with_da = []


for epoch in range(num_epoch):    
    for image, label in dataset_train:
        for _image, _label in datagen.flow(image, label, batch_size = batch_size):
            train_step(_image, _label)
            break
        
    for test_image, test_label in dataset_test:
        test_step(test_image, test_label)
        
    #train_accuracies.append(train_accuracy.result())
    #test_accuracies.append(test_accuracy.result())
    train_accuracies_with_da.append(train_accuracy.result())
    test_accuracies_with_da.append(test_accuracy.result())
    
    
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}, spent_time: {} min'
    spent_time = time.time() - start_time
    print(template.format(epoch + 1, train_loss.result(), train_accuracy.result() * 100, test_loss.result(), test_accuracy.result() * 100, spent_time / 60))

後ほどグラフ描画するため、train_accuracies_with_daやtest_accuracies_with_daに結果を保存しておきます。Data Augmentationしない場合は、コメントアウトを反転させ、train_accuraciesやtest_accuraciesに結果を保存します。

結果

image.png

80エポック訓練した結果、かろうじてData Augmentationした方が精度がよくなっており、Data Augmentationありの精度は92.2%でした。Train Accuracyの方はData Augmentationするかどうかで過学習度合いに差があるようです。
今回はそれほどData Augmentationの効果は大きくなかったのですが、どれだけの効果を発揮するのかは元のデータセットの数や特性に因ると思います。

次の取り組み

ResNetに取り組みたいと思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?