
Python3 + TensorFlow v1.1 support > learning a sine curve

Posted at 2017-07-08
Environment
GeForce GTX 1070 (8GB)
ASRock Z170M Pro4S [Intel Z170chipset]
Ubuntu 16.04 LTS desktop amd64
TensorFlow v1.1.0
cuDNN v5.1 for Linux
CUDA v8.0
Python 3.5.2
IPython 6.0.0 -- An enhanced Interactive Python.
gcc (Ubuntu 5.4.0-6ubuntu1~16.04.4) 5.4.0 20160609
GNU bash, version 4.3.48(1)-release (x86_64-pc-linux-gnu)

The earlier article "TensorFlow > learning a sine curve > graphing the prediction with TensorFlow code > it was not a sine curve > it became a sine curve (error: 0.01 or less)" was run on Ubuntu 14.04 LTS + TensorFlow v0.11 + Python 2.

I modified that code to run on Ubuntu 16.04 LTS + TensorFlow v1.1 + Python 3.5.2.
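Much of a v0.11 to v1.x port consists of renamed or deprecated calls. As a rough aid, the sketch below flags a few well-known ones in a source string. `find_deprecated_calls` and `TF1_REPLACEMENTS` are hypothetical names, not part of TensorFlow, and the table is deliberately incomplete. Note that `tf.losses.mean_squared_error` takes `(labels, predictions)` while the contrib version took `(predictions, labels)`, so a plain text substitution is not safe there.

```python
import re

# A few renames/deprecations between TensorFlow v0.x and v1.x (not exhaustive).
TF1_REPLACEMENTS = {
    'tf.pack': 'tf.stack',
    'tf.unpack': 'tf.unstack',
    'tf.initialize_all_variables': 'tf.global_variables_initializer',
    # argument order differs: check call sites by hand
    'tf.contrib.losses.mean_squared_error': 'tf.losses.mean_squared_error',
}


def find_deprecated_calls(source):
    """Return (line_number, old_name, suggested_name) for each flagged call."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for old, new in TF1_REPLACEMENTS.items():
            # the lookbehind stops e.g. 'mytf.pack(' from matching;
            # requiring '(' limits hits to call sites
            pattern = r'(?<![\w.])' + re.escape(old) + r'\s*\('
            if re.search(pattern, line):
                hits.append((lineno, old, new))
    return hits
```

For example, `find_deprecated_calls(open('old_script.py').read())` (with a hypothetical file name) lists the lines to revisit.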

prep_data.py
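```python
import numpy as np

"""
v0.2, Jul. 08, 2017
   - modify for Python3
"""

# codingrule: PEP8

numdata = 100
x_data = np.random.rand(numdata)
# uniform noise in [0, 0.3), drawn independently for each sample
# (np.random.rand() without an argument returns one scalar for all points)
y_data = np.sin(2*np.pi*x_data) + 0.3 * np.random.rand(numdata)

for xs, ys in zip(x_data, y_data):
    print('%.5f, %.5f' % (xs, ys))
```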
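The training script below reads `input.csv`, while prep_data.py prints to stdout; presumably the output was redirected (e.g. `python3 prep_data.py > input.csv`), though the article does not show that step. A minimal alternative sketch writes the file directly in the same `x, y` format. `write_training_csv` and its `seed` parameter are hypothetical additions, and the noise is drawn independently per sample.

```python
import numpy as np


def write_training_csv(path, numdata=100, seed=None):
    """Write numdata rows of 'x, y' (5 decimals) where y = sin(2*pi*x) + noise."""
    rng = np.random.RandomState(seed)
    x_data = rng.rand(numdata)                      # x uniform in [0, 1)
    noise = 0.3 * rng.rand(numdata)                 # per-sample noise in [0, 0.3)
    y_data = np.sin(2 * np.pi * x_data) + noise
    # same layout as prep_data.py's '%.5f, %.5f' lines
    np.savetxt(path, np.column_stack([x_data, y_data]),
               fmt='%.5f', delimiter=', ')
    return x_data, y_data
```

Calling `write_training_csv('input.csv', seed=0)` produces a reproducible training file.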

The following code does not follow PEP8 yet. I plan to make the TFRecord version PEP8-compliant.

sigmoid_onlyHidden.py
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

"""
v1.1 Jul. 8, 2017
   - modify for TensorFlow v1.1 (pre. v0.11)
   - modify for Python 3
"""

filename_queue = tf.train.string_input_producer(["input.csv"])

# parse CSV: each line is "x, y"
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
input1, output = tf.decode_csv(value, record_defaults=[[0.], [0.]])
inputs = tf.stack([input1])
output = tf.stack([output])

batch_size = 4  # [4]
inputs_batch, output_batch = tf.train.shuffle_batch(
    [inputs, output], batch_size, capacity=40, min_after_dequeue=batch_size)

input_ph = tf.placeholder(tf.float32, [None, 1])
output_ph = tf.placeholder(tf.float32, [None, 1])

## network: three sigmoid hidden layers of 7 units, linear output
hiddens = slim.stack(input_ph, slim.fully_connected, [7, 7, 7],
                     activation_fn=tf.nn.sigmoid, scope="hidden")
# prediction = slim.fully_connected(hiddens, 1, activation_fn=tf.nn.sigmoid, scope="output")
prediction = slim.fully_connected(hiddens, 1, activation_fn=None, scope="output")
# tf.losses takes (labels, predictions); tf.contrib.losses took (predictions, labels)
loss = tf.losses.mean_squared_error(labels=output_ph, predictions=prediction)

train_op = slim.learning.create_train_op(loss, tf.train.AdamOptimizer(0.001))

# replaces the deprecated tf.initialize_all_variables()
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  try:
    sess.run(init_op)
    for i in range(30000): #[10000]
      inpbt, outbt = sess.run([inputs_batch, output_batch])
      _, t_loss = sess.run([train_op, loss], feed_dict={input_ph: inpbt, output_ph: outbt})

      if (i+1) % 100 == 0:
        print("%d,%f" % (i+1, t_loss))

    # output to npy
    model_variables = slim.get_model_variables()
    res = sess.run(model_variables)
    np.save('model_variables.npy', res)

    # output trained curve; done before request_stop() so the input queues are still filled
    print('output') # used to separate from above lines (grep -A 200 output [outfile])
    for loop in range(10):
      inpbt, outbt = sess.run([inputs_batch, output_batch])
      pred = sess.run(prediction, feed_dict={input_ph: inpbt})
      for din, dout in zip(inpbt, pred):
        print('%.5f,%.5f' % (din[0], dout[0]))

  finally:
    coord.request_stop()
    coord.join(threads)
```
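The script saves the trained weights with `np.save`. As a hedged sketch, the same 1-7-7-7-1 network can be evaluated in plain NumPy from those arrays. This assumes `slim.get_model_variables()` returns each layer's weights followed by its biases in creation order (hidden_1, hidden_2, hidden_3, output); `forward` is a hypothetical helper, not part of slim.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def forward(model_variables, x):
    """Evaluate the network on inputs x; returns an array of shape (N, 1).

    Assumes model_variables = [W1, b1, W2, b2, W3, b3, W_out, b_out].
    """
    h = np.asarray(x, dtype=np.float64).reshape(-1, 1)
    *hidden, w_out, b_out = model_variables
    # hidden layers: affine transform followed by sigmoid
    for w, b in zip(hidden[0::2], hidden[1::2]):
        h = sigmoid(h @ w + b)
    # output layer is linear (activation_fn=None in the script)
    return h @ w_out + b_out
```

Usage: `params = list(np.load('model_variables.npy', allow_pickle=True))`, then `forward(params, np.linspace(0, 1, 50))`. Because the saved arrays have different shapes, NumPy of that era stores the list as an object array, and NumPy 1.16.3 and later refuse to load object arrays unless `allow_pickle=True` is passed.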

[Screenshot from 2017-07-08 09-08-32]

It appears to be working correctly.
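To check this beyond the screenshot, the script's stdout can be split at the `output` marker (the code comment suggests `grep -A 200 output [outfile]`). The sketch below is a hypothetical helper, `parse_log`, that returns the loss history and the `x,prediction` pairs. It assumes stdout holds only the script's own `print` lines (TensorFlow's log messages go to stderr).

```python
def parse_log(text, marker='output'):
    """Split the training script's stdout into (loss history, prediction pairs)."""
    losses, preds = [], []
    seen_marker = False
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line == marker:
            seen_marker = True
            continue
        a, b = line.split(',')
        if seen_marker:
            preds.append((float(a), float(b)))   # "x,prediction"
        else:
            losses.append((int(a), float(b)))    # "step,loss"
    # shuffle_batch returns x in random order; sort by x for plotting
    preds.sort()
    return losses, preds
```

For example, with stdout redirected to a hypothetical `out.txt`, `losses, preds = parse_log(open('out.txt').read())` gives the final loss as `losses[-1][1]` and sorted points ready to plot.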

Next, I will build a TFRecords version based on this code and use it to learn how TFRecords work.
