TensorFlow > learning a sine curve > v0.5 > using placeholder, but loss centers around 0.4

Posted at 2016-11-19
Environment
GeForce GTX 1070 (8GB)
ASRock Z170M Pro4S [Intel Z170chipset]
Ubuntu 14.04 LTS desktop amd64
TensorFlow v0.11
cuDNN v5.1 for Linux
CUDA v8.0
Python 2.7.6
IPython 5.1.0

v0.1 http://qiita.com/7of9/items/b364d897b95476a30754

The sine-curve learning code has been changed to use a placeholder, but the loss stays centered around 0.4.

Generating input.csv
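
The generation of input.csv is covered in the v0.1 article linked above. For reference, here is a minimal sketch of one possible generator, assuming input.csv holds (x, sin(x)) pairs rescaled into [0,1] to match the sigmoid output layer (the file name make_input.py and the scaling are assumptions, not taken from v0.1):

make_input.py (hypothetical)
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

# 100 random sample points in [0,1)
xs = np.random.random(100)
# sin over one period, rescaled from [-1,1] into [0,1] for the sigmoid output
ys = 0.5 + 0.5 * np.sin(2.0 * np.pi * xs)

with open('input.csv', 'w') as f:
    for x, y in zip(xs, ys):
        f.write('%f,%f\n' % (x, y))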

original (without placeholder)

linreg2.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import tensorflow as tf
import tensorflow.contrib.slim as slim

# create a queue of file names
filename_queue = tf.train.string_input_producer(["input.csv"])

# parse the CSV
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
input1, output = tf.decode_csv(value, record_defaults=[[0.], [0.]])
inputs = tf.pack([input1])
output = tf.pack([output])

batch_size=4 # [4]
inputs_batch, output_batch = tf.train.shuffle_batch([inputs, output], batch_size, capacity=40, min_after_dequeue=batch_size)

## build the NN graph
hiddens = slim.stack(inputs_batch, slim.fully_connected, [1,7,7,7], 
  activation_fn=tf.nn.sigmoid, scope="hidden")
prediction = slim.fully_connected(hiddens, 1, activation_fn=tf.nn.sigmoid, scope="output")

loss = tf.contrib.losses.mean_squared_error(prediction, output_batch)
#train_op = slim.learning.create_train_op(loss, tf.train.AdamOptimizer(0.01))
train_op = slim.learning.create_train_op(loss, tf.train.AdamOptimizer(0.001))

init_op = tf.initialize_all_variables()

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  try:
    sess.run(init_op)
    for i in range(30000): #[10000]
      _, t_loss = sess.run([train_op, loss])
      if (i+1) % 100 == 0:
        print("%d,%f" % (i+1, t_loss))
#        print("%d,%f,#step, loss" % (i+1, t_loss))
  finally:
    coord.request_stop()

  coord.join(threads)
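
Note: this script targets TensorFlow v0.11 as listed above. In TensorFlow 1.0 and later, tf.pack was renamed to tf.stack, tf.initialize_all_variables() was replaced by tf.global_variables_initializer(), and tf.contrib.losses.mean_squared_error moved to tf.losses.mean_squared_error.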

with placeholder

linreg2_feeddict.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import tensorflow as tf
import tensorflow.contrib.slim as slim

# create a queue of file names
filename_queue = tf.train.string_input_producer(["input.csv"])

# parse the CSV
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
input1, output = tf.decode_csv(value, record_defaults=[[0.], [0.]])
inputs = tf.pack([input1])
output = tf.pack([output])

batch_size=4 # [4]
inputs_batch, output_batch = tf.train.shuffle_batch([inputs, output], batch_size, capacity=40, min_after_dequeue=batch_size)

input_ph = tf.placeholder("float", [None,1])
output_ph = tf.placeholder("float",[None,1])

## build the NN graph
hiddens = slim.stack(input_ph, slim.fully_connected, [1,7,7,7], 
  activation_fn=tf.nn.sigmoid, scope="hidden")
prediction = slim.fully_connected(hiddens, 1, activation_fn=tf.nn.sigmoid, scope="output")
loss = tf.contrib.losses.mean_squared_error(prediction, output_ph)

#train_op = slim.learning.create_train_op(loss, tf.train.AdamOptimizer(0.01))
train_op = slim.learning.create_train_op(loss, tf.train.AdamOptimizer(0.001))

def feed_dict(inputs, output):
    return {input_ph: inputs.eval(), output_ph: output.eval()}

init_op = tf.initialize_all_variables()

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  try:
    sess.run(init_op)
    for i in range(30000): #[10000]
      _, t_loss = sess.run([train_op, loss], feed_dict=feed_dict(inputs_batch, output_batch))
      if (i+1) % 100 == 0:
        print("%d,%f" % (i+1, t_loss))
#        print("%d,%f,#step, loss" % (i+1, t_loss))
  finally:
    coord.request_stop()

  coord.join(threads)
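
A likely cause of the loss sticking around 0.4 (an editor's hypothesis, not verified in this article): inputs.eval() and output.eval() in feed_dict() each run the graph separately, so each call dequeues its own batch from shuffle_batch and the fed inputs no longer correspond to the fed labels. Fetching both tensors in a single sess.run keeps the pair aligned. A minimal sketch of that variant:

def feed_dict(sess, inputs, output):
    # a single sess.run dequeues one batch, so inputs and labels stay paired
    in_vals, out_vals = sess.run([inputs, output])
    return {input_ph: in_vals, output_ph: out_vals}

# in the training loop:
#   _, t_loss = sess.run([train_op, loss],
#                        feed_dict=feed_dict(sess, inputs_batch, output_batch))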

Results

python linreg2.py > log.learn_original
python linreg2_feeddict.py > log.learn_placeholder

matplotlib code on Jupyter

Reference: http://qiita.com/ynakayama/items/8d3b1f7356da5bcbe9bc

%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

#data = np.loadtxt('log.learn_non_qmc', delimiter=',')
data1 = np.loadtxt('log.learn_original', delimiter=',')
#data = np.loadtxt('log.learn_original.batch1', delimiter=',')
#data = np.loadtxt('log.learn_no_reshape', delimiter=',')
data2 = np.loadtxt('log.learn_placeholder', delimiter=',')

input1 = data1[:,0]
output1 = data1[:,1]
input2 = data2[:,0]
output2 = data2[:,1]

fig = plt.figure()
ax1 = fig.add_subplot(2,1,1)
ax2 = fig.add_subplot(2,1,2)

#ax.plot(input1, output1, color='black', linestyle='dotted', label='rate=0.001')
ax1.plot(input1, output1, color='black', linestyle='solid', label='original')
ax2.plot(input2, output2, color='red', linestyle='solid', label='placeholder')
#ax.scatter(input1, output1)

ax1.set_title('loss')
ax1.set_xlabel('step')
ax1.set_ylabel('loss')
ax1.grid(True)
ax1.legend()

#ax2.set_title('loss')
ax2.set_xlabel('step')
ax2.set_ylabel('loss')
ax2.grid(True)
ax2.legend()

plt.show()

(Figure: loss vs. step. Top: original in black, bottom: placeholder in red. The placeholder version hovers around 0.4.)

Diff of the two scripts

$ diff linreg2.py linreg2_feeddict.py 
20a21,23
> input_ph = tf.placeholder("float", [None,1])
> output_ph = tf.placeholder("float",[None,1])
> 
22c25
< hiddens = slim.stack(inputs_batch, slim.fully_connected, [1,7,7,7], 
---
> hiddens = slim.stack(input_ph, slim.fully_connected, [1,7,7,7], 
24a28
> loss = tf.contrib.losses.mean_squared_error(prediction, output_ph)
26d29
< loss = tf.contrib.losses.mean_squared_error(prediction, output_batch)
29a33,35
> def feed_dict(inputs, output):
>     return {input_ph: inputs.eval(), output_ph: output.eval()}
> 
39c45
<       _, t_loss = sess.run([train_op, loss])
---
>       _, t_loss = sess.run([train_op, loss], feed_dict=feed_dict(inputs_batch, output_batch))