0
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 5 years have passed since the last update.

てすと

Posted at
"""
Created on Mon Nov 26 23:40:35 2018

@author: jin
"""

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.learn as learn
import copy
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec

# Module-level graph placeholders: float32 inputs of width 784 and float32
# targets of width 10, both with an unspecified batch dimension.
# NOTE(review): nothing in the visible code feeds or reads these, and the
# `x`/`y` parameters of cnn_model shadow them — confirm they are still needed.
x = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, 10])


def cnn_model(x,y):
    """Build a two-convolution CNN digit classifier (``tf.contrib.learn`` model_fn).

    Architecture: conv 5x5x32 -> 2x2 max-pool -> conv 5x5x64 -> 2x2 max-pool
    -> fully connected 1024 -> fully connected 10 (logits). No dropout is
    applied between the fully connected layers.

    Args:
        x: Batch of flattened images. It is cast to float32 and reshaped to
            ``[-1, 28, 28, 1]``, so each example must hold 784 values.
        y: Batch of integer class ids in ``[0, 10)``. Any numeric dtype is
            accepted, with shape ``[batch]`` or ``[batch, 1]``. May be
            ``None`` when only predictions are needed.

    Returns:
        A ``(predictions, loss, train_op)`` tuple. ``predictions`` is a dict
        with ``"class"`` (arg-max class id per example) and ``"prob"``
        (softmax probabilities). ``loss`` is the softmax cross-entropy and
        ``train_op`` an Adam update step; both are ``None`` when ``y`` is
        ``None``.
    """
    # Cast first so float64 input (NumPy's default when parsing CSV) does not
    # make every layer float64.
    images = tf.reshape(tf.cast(x, tf.float32), [-1, 28, 28, 1])

    with slim.arg_scope(
        [slim.conv2d],
        # Activation function
        activation_fn=tf.nn.relu,
        # Weight initializer
        weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
        # Bias initializer
        biases_initializer=tf.constant_initializer(0.1)
    ):
        with slim.arg_scope([slim.max_pool2d], padding='SAME'):
            # Convolution layer 1
            c1 = slim.conv2d(images, 32, [5,5])
            # Pooling layer 1
            p1 = slim.max_pool2d(c1, [2,2])
            # Convolution layer 2
            c2 = slim.conv2d(p1, 64, [5,5])
            # Pooling layer 2
            p2 = slim.max_pool2d(c2, [2,2])
            # Fully connected layer
            p2_flat = slim.flatten(p2)
            fc1 = slim.fully_connected(p2_flat, 1024)
            # Readout layer: raw, un-normalised class scores (logits)
            logits = slim.fully_connected(fc1, 10, activation_fn=None)

    prob = slim.softmax(logits)
    # Expose a "class" entry so metrics configured with
    # prediction_key="class" can find the predicted label.
    predictions = {
        "class": tf.argmax(logits, 1),
        "prob": prob,
    }

    if y is None:
        # No labels (prediction only): there is nothing to optimise.
        return predictions, None, None

    # One-hot encoding needs integer indices of shape [batch]; labels split
    # out of a float matrix arrive as float values of shape [batch, 1].
    labels = tf.reshape(tf.cast(y, tf.int64), [-1])
    onehot_labels = slim.one_hot_encoding(labels, 10)

    # The loss applies softmax itself, so it must receive the logits; passing
    # the probabilities would apply softmax twice and flatten the gradients.
    cross_entropy = slim.losses.softmax_cross_entropy(logits, onehot_labels)

    # The estimator counts training steps through the global step, so the
    # train op has to increment it or `fit(steps=...)` never terminates.
    train_step = tf.train.AdamOptimizer(1e-4).minimize(
        cross_entropy, global_step=tf.train.get_or_create_global_step())

    return predictions, cross_entropy, train_step


# Load the digits table from a comma-separated file (path is machine-specific).
mnist_o = np.genfromtxt(r'C:\Users\jin\Desktop\digits.csv', delimiter=",")
# Drop the first row (index 0 along axis 0) — presumably the CSV header row,
# TODO confirm against the data file.
mnist = mnist_o[1:]
# Separate the label column (column 0) from the remaining data columns.
mnist_label = mnist[:, :1]
mnist_data = mnist[:, 1:]
# The first 8000 rows are used for training, the remainder for testing.
x_train = mnist_data[:8000]
x_test = mnist_data[8000:]
y_train = mnist_label[:8000]
y_test = mnist_label[8000:]

tf.logging.set_verbosity(tf.logging.INFO)

# Accuracy is computed from the "class" entry of the model's predictions.
accuracy_spec = MetricSpec(
    metric_fn=tf.contrib.metrics.streaming_accuracy,
    prediction_key="class",
)
validation_metrics = {"accuracy": accuracy_spec}

# Evaluate on the held-out rows every 100 training steps.
validation_monitor = learn.monitors.ValidationMonitor(
    x_test,
    y_test,
    metrics=validation_metrics,
    every_n_steps=100,
)

# Checkpoint every 10 seconds into /tmp/cnn_log.
run_config = learn.RunConfig(save_checkpoints_secs=10)
classifier = learn.Estimator(
    model_fn=cnn_model,
    model_dir='/tmp/cnn_log',
    config=run_config,
)
classifier.fit(
    x=x_train,
    y=y_train,
    steps=3200,
    batch_size=64,
    monitors=[validation_monitor],
)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
0
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?