動作環境
GeForce GTX 1070 (8GB)
ASRock Z170M Pro4S [Intel Z170chipset]
Ubuntu 14.04 LTS desktop amd64
TensorFlow v0.11
cuDNN v5.1 for Linux
CUDA v8.0
Python 2.7.6
IPython 5.1.0 -- An enhanced Interactive Python.
sine curveを近似する。
データ生成部
他の関数を近似する予定なので、データは別のコードで生成する。
prep_data.py
import numpy as np
import random
numdata=100
x_data = np.random.rand(numdata)
y_data = np.sin(2*np.pi*x_data) + 0.3 * np.random.rand()
for xs, ys in zip(x_data, y_data):
print '%.5f, %.5f' % (xs, ys)
csvファイル
$python prep_data.py > input.csv
input.csv(例)
0.74597, -0.91122
0.33339, 0.95432
0.03281, 0.29314
0.49378, 0.12754
0.59515, -0.47443
0.19094, 1.02040
0.04446, 0.36420
0.02983, 0.27479
...
学習コード
参考 http://qiita.com/learn_tensorflow/items/3e46b2512a1bab73f5b2
上記のtrain.pyをベースとした。
以下の変更をしている。
- input.csvの形式を変更
- hiddensの定義を[2,2]から[1,7,7,7]に変更
- print書式を変更
linreg2.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorflow.contrib.slim as slim
# ファイル名の Queue を作成
filename_queue = tf.train.string_input_producer(["input.csv"])
# CSV を parse
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
input1, output = tf.decode_csv(value, record_defaults=[[0.], [0.]])
inputs = tf.pack([input1])
output = tf.pack([output])
# 4サンプル毎の Batch 化
inputs_batch, output_batch = tf.train.shuffle_batch([inputs, output], 4, capacity=40, min_after_dequeue=4)
## NN のグラフ生成
hiddens = slim.stack(inputs_batch, slim.fully_connected, [1,7,7,7],
activation_fn=tf.nn.sigmoid, scope="hidden")
prediction = slim.fully_connected(hiddens, 1, activation_fn=tf.nn.sigmoid, scope="output")
loss = tf.contrib.losses.mean_squared_error(prediction, output_batch)
train_op = slim.learning.create_train_op(loss, tf.train.AdamOptimizer(0.01))
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
sess.run(init_op)
for i in range(10000):
_, t_loss = sess.run([train_op, loss])
if (i+1) % 100 == 0:
print("%d,%f" % (i+1, t_loss)) #step, loss
finally:
coord.request_stop()
coord.join(threads)
$ python linreg2.py
...
8200 0.197571 #step, loss
8300 0.197723 #step, loss
8400 0.369575 #step, loss
8500 0.014792 #step, loss
8600 0.115452 #step, loss
8700 0.013261 #step, loss
8800 0.078891 #step, loss
8900 0.350682 #step, loss
9000 0.349131 #step, loss
9100 0.167577 #step, loss
9200 0.141080 #step, loss
9300 0.211734 #step, loss
9400 0.413739 #step, loss
9500 0.133878 #step, loss
9600 0.293956 #step, loss
9700 0.449871 #step, loss
9800 0.355519 #step, loss
9900 0.020344 #step, loss
10000 0.198887 #step, loss
loss for step
$python linreg2.py > log.learn
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
data = np.loadtxt('log.learn', delimiter=',')
input1 = data[:,0]
output = data[:,1]
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
x = np.linspace(-6,6,1000)
ax.plot(input1, output, color='black', linestyle='solid')
ax.set_title('loss')
ax.set_xlabel('step')
ax.set_ylabel('loss')
ax.grid(True)
fig.show()
収束する気配がない。
ネットの定義など検討課題は多い。