
TensorFlow > Learning a sine curve > v0.1: still using someone else's code almost as-is

Posted at 2016-11-12
Environment
GeForce GTX 1070 (8GB)
ASRock Z170M Pro4S [Intel Z170chipset]
Ubuntu 14.04 LTS desktop amd64
TensorFlow v0.11
cuDNN v5.1 for Linux
CUDA v8.0
Python 2.7.6
IPython 5.1.0 -- An enhanced Interactive Python.

Goal: approximate a sine curve.

Data generation

Since I plan to approximate other functions later, the data is generated by a separate script.

prep_data.py
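import numpy as np

numdata = 100
x_data = np.random.rand(numdata)
# np.random.rand() with no size returns ONE float, so every point gets the
# same offset here; np.random.rand(numdata) would give per-point noise
y_data = np.sin(2*np.pi*x_data) + 0.3 * np.random.rand()

for xs, ys in zip(x_data, y_data):
    print('%.5f, %.5f' % (xs, ys))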
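If independent noise per point is what was intended (an assumption; the script above draws a single offset), a minimal sketch passes the sample count to the noise draw as well. It keeps the original's uniform noise on [0, 0.3), so the noise is not zero-mean. The function name `make_data` and the seed are illustrative, not from the original.

```python
import numpy as np


def make_data(numdata=100, noise=0.3, seed=None):
    """Sample x in [0, 1) and y = sin(2*pi*x) plus independent noise per point."""
    rng = np.random.RandomState(seed)
    x = rng.rand(numdata)
    # rand(numdata) draws one noise value per sample; rand() would draw a single scalar
    y = np.sin(2 * np.pi * x) + noise * rng.rand(numdata)
    return x, y


if __name__ == "__main__":
    x_data, y_data = make_data(seed=0)
    # Same "x, y" line format as prep_data.py, so the output can be redirected to input.csv
    for xs, ys in zip(x_data, y_data):
        print("%.5f, %.5f" % (xs, ys))
```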

Creating the CSV file

$ python prep_data.py > input.csv

input.csv (example)
0.74597, -0.91122
0.33339, 0.95432
0.03281, 0.29314
0.49378, 0.12754
0.59515, -0.47443
0.19094, 1.02040
0.04446, 0.36420
0.02983, 0.27479
...
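Subtracting sin(2πx) from the example rows is a quick way to see what noise the generator actually added. A small check over the eight rows shown above (copied verbatim) finds residuals that all lie within about 0.0001 of each other, near 0.088, which is consistent with a single scalar offset from `0.3 * np.random.rand()` rather than per-point noise.

```python
import numpy as np

# The eight example rows of input.csv shown above: (x, y)
rows = np.array([
    [0.74597, -0.91122],
    [0.33339, 0.95432],
    [0.03281, 0.29314],
    [0.49378, 0.12754],
    [0.59515, -0.47443],
    [0.19094, 1.02040],
    [0.04446, 0.36420],
    [0.02983, 0.27479],
])
x, y = rows[:, 0], rows[:, 1]

# Residual = what was added on top of the clean sine curve
residual = y - np.sin(2 * np.pi * x)
print(np.round(residual, 4))
print("spread: %.5f" % (residual.max() - residual.min()))
```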

Training code

Reference: http://qiita.com/learn_tensorflow/items/3e46b2512a1bab73f5b2

This is based on train.py from the article above, with the following changes:

  • Changed the format of input.csv
  • Changed the hiddens definition from [2,2] to [1,7,7,7]
  • Changed the print format
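`slim.stack(inputs_batch, slim.fully_connected, [1,7,7,7], activation_fn=tf.nn.sigmoid, ...)` calls `fully_connected` once per list entry with the same keyword arguments, so the new hiddens setting means four sigmoid layers of widths 1, 7, 7 and 7 (the first squeezes everything through a single unit), followed by the 1-unit output layer. The NumPy sketch below only mimics the forward-pass shapes with random, untrained weights; it is an illustration, not slim's implementation, and the weight initialization is an arbitrary assumption.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def dense(x, num_outputs, rng, activation=sigmoid):
    """One fully connected layer with random (untrained) weights."""
    w = rng.randn(x.shape[1], num_outputs)
    b = np.zeros(num_outputs)
    return activation(x.dot(w) + b)


def forward(x, widths=(1, 7, 7, 7), seed=0):
    """Apply dense layers one after another, like slim.stack, then a 1-unit output."""
    rng = np.random.RandomState(seed)
    h = x
    shapes = []
    for width in widths:
        h = dense(h, width, rng)
        shapes.append(h.shape)
    out = dense(h, 1, rng)  # output layer, also sigmoid as in linreg2.py
    return out, shapes


batch = np.random.RandomState(1).rand(4, 1)  # a batch of 4 scalar inputs
prediction, hidden_shapes = forward(batch)
print(hidden_shapes)     # [(4, 1), (4, 7), (4, 7), (4, 7)]
print(prediction.shape)  # (4, 1)
```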
linreg2.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow as tf
import tensorflow.contrib.slim as slim

# Create a queue of input file names
filename_queue = tf.train.string_input_producer(["input.csv"])

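# Parse each CSV line into two floats: x and y
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
input1, output = tf.decode_csv(value, record_defaults=[[0.], [0.]])
# Make each a 1-element vector (tf.pack was renamed tf.stack in TF 1.0)
inputs = tf.pack([input1])
output = tf.pack([output])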

# Shuffle samples and group them into batches of 4
inputs_batch, output_batch = tf.train.shuffle_batch([inputs, output], 4, capacity=40, min_after_dequeue=4)

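## Build the NN graph: hidden layers of width 1, 7, 7, 7, all sigmoid
hiddens = slim.stack(inputs_batch, slim.fully_connected, [1,7,7,7],
  activation_fn=tf.nn.sigmoid, scope="hidden")
# The output layer is also sigmoid, so predictions lie in (0, 1) and
# cannot reach the negative half of the sine curve
prediction = slim.fully_connected(hiddens, 1, activation_fn=tf.nn.sigmoid, scope="output")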

loss = tf.contrib.losses.mean_squared_error(prediction, output_batch)
train_op = slim.learning.create_train_op(loss, tf.train.AdamOptimizer(0.01))

init_op = tf.initialize_all_variables()  # tf.global_variables_initializer() from TF 0.12 on

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  try:
    sess.run(init_op)
    for i in range(10000):
      _, t_loss = sess.run([train_op, loss])
      if (i+1) % 100 == 0:
        print("%d,%f" % (i+1, t_loss)) #step, loss
  finally:
    coord.request_stop()

  coord.join(threads)
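One property of this graph is worth checking numerically: the output layer also uses `activation_fn=tf.nn.sigmoid`, so every prediction lies in (0, 1), while the targets range from about -0.9 to above 1. Whatever the hidden layers learn, the MSE cannot fall below the error of the best output clipped to [0, 1]. The sketch below estimates that floor, assuming x uniform on [0, 1) and a constant offset of 0.088 (the offset visible in the example rows of input.csv; treat it as an assumption). It comes out at roughly 0.2, the same range the batch losses in the log below fluctuate around; this is offered as one possible explanation for the lack of convergence, not a confirmed diagnosis.

```python
import numpy as np


def sigmoid_mse_floor(offset=0.088, n=100000, seed=0):
    """Lower bound on MSE for any model whose outputs are confined to (0, 1).

    Targets are y = sin(2*pi*x) + offset with x uniform on [0, 1) (assumed).
    The best a (0, 1)-bounded output can do for each target is to approach
    clip(y, 0, 1), so the remaining squared error cannot be trained away.
    """
    rng = np.random.RandomState(seed)
    x = rng.rand(n)
    y = np.sin(2 * np.pi * x) + offset
    best = np.clip(y, 0.0, 1.0)
    return np.mean((y - best) ** 2)


print("MSE floor with sigmoid output: %.3f" % sigmoid_mse_floor())
print("MSE floor without offset:      %.3f" % sigmoid_mse_floor(offset=0.0))
```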
$ python linreg2.py
...
8200 0.197571 #step, loss
8300 0.197723 #step, loss
8400 0.369575 #step, loss
8500 0.014792 #step, loss
8600 0.115452 #step, loss
8700 0.013261 #step, loss
8800 0.078891 #step, loss
8900 0.350682 #step, loss
9000 0.349131 #step, loss
9100 0.167577 #step, loss
9200 0.141080 #step, loss
9300 0.211734 #step, loss
9400 0.413739 #step, loss
9500 0.133878 #step, loss
9600 0.293956 #step, loss
9700 0.449871 #step, loss
9800 0.355519 #step, loss
9900 0.020344 #step, loss
10000 0.198887 #step, loss
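Each logged value is the loss of the single 4-sample batch fetched in that step's `sess.run`, so large step-to-step swings are expected even when the model is not changing. The NumPy sketch below illustrates this under simplifying assumptions: a hypothetical 100-point data set without the offset, a fixed predictor that always outputs 0.5, and batches drawn uniformly at random, which only approximates `tf.train.shuffle_batch`.

```python
import numpy as np

rng = np.random.RandomState(0)

# Hypothetical data set of 100 points, like input.csv (offset omitted)
x = rng.rand(100)
y = np.sin(2 * np.pi * x)

# A fixed, untrained predictor: always output 0.5
pred = np.full_like(y, 0.5)

# MSE over the whole data set
full_mse = np.mean((y - pred) ** 2)

# MSE of 1000 random batches of 4 distinct samples, one per "step"
losses = []
for _ in range(1000):
    idx = rng.choice(100, size=4, replace=False)
    losses.append(np.mean((y[idx] - pred[idx]) ** 2))
batch_losses = np.array(losses)

print("MSE over all 100 points: %.3f" % full_mse)
print("batch-of-4 loss: min %.3f, max %.3f, mean %.3f"
      % (batch_losses.min(), batch_losses.max(), batch_losses.mean()))
```

The spread between the minimum and maximum batch loss shows how much of the jitter in the log can come from batch sampling alone.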

Loss vs. step

$ python linreg2.py > log.learn

%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

# log.learn holds one "step,loss" line every 100 steps
data = np.loadtxt('log.learn', delimiter=',')
step = data[:,0]
loss = data[:,1]

fig = plt.figure()
ax = fig.add_subplot(1,1,1)

ax.plot(step, loss, color='black', linestyle='solid')

ax.set_title('loss')
ax.set_xlabel('step')
ax.set_ylabel('loss')
ax.grid(True)
plt.show()
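Because the logged values are per-batch losses, a trend is easier to judge after smoothing. A minimal sketch, assuming log.learn holds the `step,loss` lines written by linreg2.py: compute a moving average with `np.convolve` and draw it over the raw curve. The window of 10 log entries (1,000 training steps) is an arbitrary choice.

```python
import numpy as np


def moving_average(values, window=10):
    """Mean of each run of `window` consecutive values ('valid' part only)."""
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")


def plot_smoothed(path="log.learn", window=10):
    """Plot the raw per-batch loss and its moving average from 'step,loss' lines."""
    import matplotlib.pyplot as plt  # imported here so moving_average works without matplotlib

    data = np.loadtxt(path, delimiter=",")
    step, loss = data[:, 0], data[:, 1]
    smooth = moving_average(loss, window)

    fig, ax = plt.subplots()
    ax.plot(step, loss, color="lightgray", label="per-batch loss")
    # 'valid' convolution drops window-1 points; align each mean with its last step
    ax.plot(step[window - 1:], smooth, color="black", label="moving average")
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    ax.grid(True)
    ax.legend()
    plt.show()
```

In the notebook, `plot_smoothed('log.learn')` can be called after running linreg2.py.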

[Figure: training loss vs. step]

There is no sign of convergence.
Many issues remain to be examined, such as the network definition.
