LoginSignup
1
3

More than 5 years have passed since last update.

[tensorflow] CNN model構築

Last updated at Posted at 2018-03-13

tensorflowでMNISTデータセットを使った数字認識のためのCNNモデル構築

参考 : https://github.com/tensorflow/models/tree/master/official/mnist

python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

class Model:
  """CNN for MNIST digit recognition.

  The network is two (Conv2D + ReLU -> 2x2 max-pool) stages followed by a
  1024-unit dense layer, dropout, and a 10-unit output layer (one per label).
  """

  def __init__(self, data_format):
    """Create the layer objects and remember the reshape target for inputs.

    Args:
      data_format: 'channels_first' (NCHW) or 'channels_last' (NHWC). Any
        other value trips the assertion below.
    """
    if data_format != 'channels_first':
      # NHWC: [batch, height, width, channels]; -1 stands for the batch size.
      # Original author's note: suited to CPU, and sliced data can be
      # inspected via sess.run(). A 512x512 RGB image would be
      # [-1, 512, 512, 3].
      assert data_format == 'channels_last'
      self._input_shape = [-1, 28, 28, 1]
    else:
      # NCHW: [batch, channels, height, width]; -1 stands for the batch size.
      # Original author's note: suited to GPU, but sliced data cannot be
      # inspected via sess.run(). A 512x512 RGB image would be
      # [-1, 3, 512, 512].
      self._input_shape = [-1, 1, 28, 28]

    # Building blocks of the network.
    # Both conv layers use a 5-pixel window, 'same' (zero) padding -- 'valid'
    # would mean no padding -- and ReLU, which clamps negative inputs to 0.
    conv_kwargs = dict(
        kernel_size=5,
        padding='same',
        data_format=data_format,
        activation=tf.nn.relu)
    # Hidden conv layer 1: 32 output channels.
    self.conv1 = tf.layers.Conv2D(32, **conv_kwargs)
    # Hidden conv layer 2: 64 output channels.
    self.conv2 = tf.layers.Conv2D(64, **conv_kwargs)
    # Fully connected layer with 1024 outputs.
    self.fc1 = tf.layers.Dense(units=1024, activation=tf.nn.relu)
    # Output layer: 10 values, one per label; no activation is applied.
    self.fc2 = tf.layers.Dense(units=10)
    # Zeroes inputs with rate 0.4.
    self.dropout = tf.layers.Dropout(rate=0.4)
    # Scans with a (2, 2) window and (2, 2) stride, keeping the max.
    self.max_pool2d = tf.layers.MaxPooling2D(
        pool_size=(2, 2),
        strides=(2, 2),
        padding='same',
        data_format=data_format)

  def __call__(self, inputs, training):
    """Run the forward pass and return the output of the final dense layer.

    The shape comments follow the original notes: they assume
    'channels_last' and a batch of 5000 flattened 28x28 images.

    Args:
      inputs: image batch, reshaped here to the stored input shape.
      training: forwarded to the dropout layer; per the original note,
        dropout is not applied when testing.

    Returns:
      The result of `fc2`, e.g. (5000, 10).
    """
    # (5000, 784) -> (5000, 28, 28, 1)
    net = tf.reshape(inputs, self._input_shape)
    # Stage 1: conv -> (5000, 28, 28, 32)  ['valid' would give (5000, 24, 24, 32)],
    #          pool -> (5000, 14, 14, 32)
    # Stage 2: conv -> (5000, 14, 14, 64), pool -> (5000, 7, 7, 64)
    for conv in (self.conv1, self.conv2):
      net = conv(net)
      net = self.max_pool2d(net)
    # (5000, 7, 7, 64) -> (5000, 3136)
    net = tf.layers.flatten(net)
    # (5000, 3136) -> (5000, 1024)
    hidden = self.fc1(net)
    # Shape stays (5000, 1024); some activations are zeroed when training.
    hidden = self.dropout(hidden, training=training)
    # (5000, 1024) -> (5000, 10)
    return self.fc2(hidden)

# Load MNIST through the tutorial helper, using /tmp/data/ as the data
# directory and one-hot encoded labels.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Validation split; per the original note the images are (5000, 784).
val_images, val_labels = mnist.validation.images, mnist.validation.labels

# Build the NHWC model and feed the validation images through it.
# NOTE(review): training=True is passed even though validation data is used,
# so the dropout layer runs in training mode -- confirm this is intended.
model = Model(data_format='channels_last')
logits = model(inputs=val_images, training=True)

途中でデータを見たいとき

python
# Evaluate a graph tensor inside a session and inspect the fetched value.
# NOTE(review): `y` and `cv2` are not defined anywhere in this snippet --
# presumably `y` is a tensor produced by the model and `cv2` is OpenCV
# (`import cv2`); confirm before running.
with tf.Session() as sess:
   # Run the global variable initializer before evaluating anything.
   sess.run(tf.global_variables_initializer())
   # Evaluate `y` and keep the fetched result.
   tmp = sess.run(y)
   print(tmp)
   # Show the fetched value in a window with an empty title, then wait for a
   # key press (presumably OpenCV semantics -- confirm `tmp` is displayable).
   cv2.imshow('', tmp)
   cv2.waitKey(0)
1
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
1
3