Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

MNISTに畳み込み処理を追加する

More than 3 years have passed since last update.

以前、TensorFlowの処理を理解するために、MNISTのコードを読んでいくことをしてみたが、やっぱり画像なんだからCNNしたいよねということで、畳み込み処理を追加してみました。
参考はもちろん「Deep MNIST for Experts」ですが、こちらも結構参考にさせてもらいました。

畳み込み処理

ズバリ、n×nのフィルターを掛けること。
画像処理をやっている人には、エッジを抽出したり、ぼかしたりしてきたアレですので、非常になじみ深いです。
20161226145014.png

使用する関数は「tensorflow.nn.conv2d()」で、設定する引数は主に4つ。
1つ目は入力データ。
2つ目はフィルター(縦サイズ、横サイズ、入力チャンネル数、出力チャンネル数)。
3つ目は移動ストライド。
4つ目はパディングの設定。

通常はこんな感じで設定します。

x = tensorflow.placeholder(tf.float32, [None, 784])
x_image = tensorflow.reshape(x, [-1, 28, 28, 1])
initial = tensorflow.truncated_normal([5, 5, 1, 32], stddev=0.1)
W_conv1 = tensorflow.Variable(initial)
h_conv1 = tensorflow.nn.conv2d(x_image,
                               W_conv1,
                               strides=[1, 1, 1, 1],
                               padding='SAME')

MNISTのデータは28×28の画像が1次元に伸びた状態なので、まず28×28に戻しています。
次に5×5のフィルターを設定しています。入力は白黒画像なので1チャンネルで、出力は32チャンネルにしています。

なお、畳み込み処理時には、活性化関数としてReLUを使ったりしますので、実際には最後の行は

initial = tensorflow.constant(0.1, [32])
b_conv1 = tensorflow.Variable(initial)
h_conv1 = tensorflow.nn.relu(tensorflow.nn.conv2d(x_image,
                                                  W_conv1,
                                                  strides=[1, 1, 1, 1],
                                                  padding='SAME')
                             + b_conv1)

といった感じになります。
relu.png
(「W_conv1」や「b_conv1」が、誤差逆伝播で調整される重みパラメータになります)

プーリング処理

通常、畳み込み処理と対になっているのがプーリング処理です。
こちらは画像サイズの縮小を行うイメージになります。
(n×nの代表値を求める)
filter.png
平均値だったり中央値だったり、いろいろ方法はありますが、Deep Learningでは、いわゆる「Max Poolong」といわれる最大値を使用することが多いです。
20150126055504.png

TensorFlowでは、「tensorflow.nn.max_pool()」を使用します。
設定する引数は、主に4つ。
1つ目は入力データ。
2つ目はフィルター。
3つ目が移動ストライド。
4つ目はパディングの設定。
畳み込み処理と似てますよね。

実際のコードはこんな感じになります。

h_pool1 = tensorflow.nn.max_pool(h_conv1,
                                 ksize=[1, 2, 2, 1],
                                 strides=[1, 2, 2, 1],
                                 padding='SAME')

入力は畳み込み処理の結果になります。
フィルターのサイズは2×2、ストライドも縦横2ピクセル移動することにしています。これにより、出来上がりの画像サイズは、元の半分になります。
(チャンネル数は変わりません)

全結合処理のために

なお、畳み込み処理やプーリング処理は2次元の画像として処理をしているため、全結合する際に1次元に変換します。
この時、ちゃんと画像サイズとチャンネル数が分かっていないと配列のサイズが設定できないため、注意してください。
関数は、1次元→2次元にした時と同じように、「tensorflow.reshape()」を使用します。

おまけ

まだ説明していない部分もありますが、動く(はず)のソースコードを載せておきます。
※2017/03/21現在、MNISTのデータセットが置かれているサイトにアクセスできなくなっていますのでご注意ください。

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# オリジナルの関数群
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')

# メイン関数
def main():
    # データセットの取得
    # (今は、あらかじめダウンロードしてあるzipのフォルダを指定)
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # 入出力データの準備
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    x_image = tf.reshape(x, [-1, 28, 28, 1])

    # 畳み込み処理(1)
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)

    # プーリング処理(1)
    h_pool1 = max_pool_2x2(h_conv1)

    # 畳み込み処理(2)
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

    # プーリング処理(2)
    h_pool2 = max_pool_2x2(h_conv2)

    # 全結合処理
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    # ドロップアウト
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # 識別(なぜsoftmaxがないのかは謎)
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])
    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

    # 評価処理(なぜかこっちにはsoftmaxがある)
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # セッション作成
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    # トレーニング
    for i in range(20000):
        # バッチサイズは50
        batch = mnist.train.next_batch(50)
        # 100回ごとに経過表示
        if i%100 == 0:
            train_accuracy = accuracy.eval(feed_dict={
                x:batch[0], y_: batch[1], keep_prob: 1.0})
            print("step %d, training accuracy %g"%(i, train_accuracy))
        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    # 結果表示
    print("test accuracy %g"%accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

# 処理開始
if __name__ == "__main__":
    main()

あとは、ドロップアウトと評価処理のところを書かないとね...
(評価処理、よくわからない ^^;)

mine820
もうプログラムを書かなくなって久しい、元アプリケーションエンジニアのおじさんです。 C言語万能教に侵されています。 OpenGLとOpenCVとDirectShowでメシ食ってました。 たまには趣味で書いていこうかと思っています。 ※発信の内容は全て個人の見解に基づくもので、所属する組織の見解ではありません
ark-info-sys
最新IT技術を利用したソフト開発、パッケージ製品の企画開発を中心に、お客様に満足いただけるサービスをご提供いたします。
https://www.ark-info-sys.co.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away