4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

skflowでMNIST(DCNN)

Last updated at Posted at 2016-07-02

skflowでMNIST

skflowとは

TensorFlowをScikit Learnのように簡単に使うためのインターフェース
TensorFlowに含まれている(TF Learn
TensorFlow0.9から独立して、書き方も以下の古いものより書きやすくなっている(TFLearn

実行環境

EC2(AWS)のg2.2xlargeインスタンス(オレゴン = 米国西部)
Python 2.7.6
TensorFlow 0.8.0
scipy 0.17.1(scikit-learnに必要)
scikit-learn 0.17.1

AWSのインスタンスは他人のAMIを使って初期化したが、自分で導入したい場合は以下を参考

EC2のGPU instanceでTensorFlow動かすのにもうソースからのビルドは必要ないっぽい?

ソースコード

今回、DCNN(Deep Convolutional Neural Network)で複雑なネットワークを組む際の参考にするためにやってみたので、パラメーターは適当だし、あまりMNIST用ではないと思う

mnist.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from sklearn import metrics
import tensorflow as tf
from tensorflow.contrib import learn as skflow
from tensorflow.contrib.learn.python.learn.datasets import mnist as source

mnist = source.load_mnist()

def max_pool_2x2(tensor_in):
    return tf.nn.max_pool(tensor_in, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')

def my_model(X, y):
    X = tf.reshape(X, [-1, 28, 28, 1])

    with tf.variable_scope('layer1'):
        with tf.variable_scope('conv1'):
            h_conv1 = skflow.ops.conv2d(X, n_filters = 16, filter_shape = [3, 3], bias = True, activation = tf.nn.relu)
        with tf.variable_scope('conv2'):
            h_conv2 = skflow.ops.conv2d(h_conv1, n_filters = 32, filter_shape = [3, 3], bias = True, activation = tf.nn.relu)
            h_pool1 = max_pool_2x2(h_conv2)

    with tf.variable_scope('layer2'):
        with tf.variable_scope('conv3'):
            h_conv3 = skflow.ops.conv2d(h_pool1, n_filters = 64, filter_shape = [3, 3], bias = True, activation = tf.nn.relu)
        with tf.variable_scope('conv4'):
            h_conv4 = skflow.ops.conv2d(h_conv3, n_filters = 128, filter_shape = [3, 3], bias = True, activation = tf.nn.relu)
            h_pool2 = max_pool_2x2(h_conv4)
            h_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 128])

    h_fc = skflow.ops.dnn(h_flat, [1024, 1024], activation = tf.nn.relu, dropout = 0.5)
    return skflow.models.logistic_regression(h_fc, y)

classifier = skflow.TensorFlowEstimator(model_fn = my_model, n_classes = 10, batch_size = 100, steps = 20000, learning_rate = 0.001, optimizer = 'Adam')
classifier.fit(mnist.train.images, mnist.train.labels)
score = metrics.accuracy_score(mnist.test.labels, classifier.predict(mnist.test.images))
print('Accuracy: {0:f}'.format(score))

参考にしたページ

skflow - mnist.py
Introduction to Scikit Flow
TensorFlow Python reference documentation

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?