3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Kaggle / MNIST で予測精度向上を目指す (1. チュートリアル通りにCNN を作る)

Last updated at Posted at 2021-01-03

要約

はじめに

Kaggle の Digit Recognizer というコンペティションは,MNIST という有名な手書き数字の画像を使って学習・予測するタスクである。ここでは学習に畳み込みニューラルネットワーク (Convolutional Neural Network, CNN) を使って,予測精度向上を目指す。

MNIST 向けの CNN の事例はネットなどで数多く見つけることができるが,ネットワーク構造やパラメータの選択理由がよく分からなかった。そこで,私自身がシンプルな CNN からスタートし,予測精度向上のために何を考えてネットワーク構造やパラメータを変えていったのかを書き残していく。「おっ!」という凄いことは書いていないけど,私自身の考えの振り返りと,他の方にとって多少の参考になれば幸いである。

もし誤りや勘違いなどあれば,指摘してもらえるとありがたいです。

対象読者

  • CNN の基本事項 (CNN の畳み込み計算や,max pooling,batch normalization など) を一応知っている方

基本の CNN

データの準備

Kaggle で用意されたデータを読み込んで,学習用に整形していく。ポイントは…

  1. CSV ファイルを pandas.DataFrame で読み込み,TensorFlow で処理するためにそれを numpy.ndarray に変換する
  2. 数字の範囲を 0 ~ 1 に変換するために,255.0 で割る
digit-recognition_CNN1a.py
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session


# データ読み込み
train_data = pd.read_csv("/kaggle/input/digit-recognizer/train.csv")
test_data = pd.read_csv("/kaggle/input/digit-recognizer/test.csv")

# データ数を確認
train_data_len = len(train_data)
test_data_len = len(test_data)
print("Length of train_data ; {}".format(train_data_len))
print("Length of test_data ; {}".format(test_data_len))

# Length of train_data ; 42000
# Length of test_data ; 28000

# ラベルとデータを分離
train_data_y = train_data["label"]
train_data_x = train_data.drop(columns="label")

# TensorFlow で処理するため,panda.DataFrame を numpy.ndarray に変換する
# 意図的にデータ型を float64 に変換する
train_data_x = train_data_x.astype('float64').values.reshape((train_data_len, 28, 28, 1))
test_data = test_data.astype('float64').values.reshape((test_data_len, 28, 28, 1))

# データを 0 ~ 1 の範囲にする
train_data_x /= 255.0
test_data /= 255.0

CNN を作る

CNN は TensorFlow の CNN のチュートリアル をそのまま活用する。

digit-recognition_CNN1a.py
import tensorflow as tf
from tensorflow.keras import layers, models

model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.summary()

できあがった CNN は下記の通り。

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                36928     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________

コンパイルして学習を実行する

引き続き TensorFlow のチュートリアル の通りに,できたモデルをコンパイルして,学習を実行する。

digit-recognition_CNN1a.py
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_data_x, train_data_y, epochs=5)

ちなみに,ここではラベルの one-hot encoding は行っていない。この場合,コンパイル時のオプションは loss='sparse_categorical_crossentropy' を指定する。

もし one-hot encoding する場合は,loss='categorical_crossentropy' を指定する。(参考 ; 目的関数の利用方法)

予測と結果保存

学習したモデルを使って,テストデータの予測をして,結果を保存する。tensorflow.keras.models.predict_classes を使って,予測結果のラベルを取得する。

もし各ラベルの確率を知りたいなら,tensorflow.keras.models.predict_proba を使う。

得られた結果から pandas.DataFrame を作って,CSV ファイルで保存する。

digit-recognition_CNN1a.py
prediction = model.predict_classes(test_data, verbose=0)
output = pd.DataFrame({"ImageId" : np.arange(1, 28000+1), "Label":prediction})

output.to_csv('digit_recognizer_CNN1a.csv', index=False)
print("Your submission was successfully saved!")

結果

No 説明 スコア
Ref SVM 0.98375
01 チュートリアル通り 0.98792

前回,Kaggle / MNIST をサポートベクターマシンで頑張る の結果は 0.98375 だったが,それをあっさり越えてきた。さすが CNN。

今後,このスクリプトをベースにして,予測精度向上を目指す。

参考

web site

サンプルスクリプト

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?