TensorFlow > Learning a sine curve > Writing the weights and biases to a numpy file

Posted at 2016-12-05
Environment
GeForce GTX 1070 (8GB)
ASRock Z170M Pro4S [Intel Z170chipset]
Ubuntu 14.04 LTS desktop amd64
TensorFlow v0.11
cuDNN v5.1 for Linux
CUDA v8.0
Python 2.7.6
IPython 5.1.0 -- An enhanced Interactive Python.

Related: http://qiita.com/7of9/items/b364d897b95476a30754

Using the weights and biases obtained from training on the sine curve, I am trying to reproduce the network myself and compute its output.

I will probably end up using np.save() for this.
http://qiita.com/7of9/items/c730990479687ec2e959

Generating input.csv
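input.csv is the training data for the sine curve. As a rough sketch of what it might look like (the sample count, x range, and scaling below are my assumptions, not taken from the original setup), each row holds an input value and a target value; since the output layer uses a sigmoid, the sine value has to be mapped into (0, 1).

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Sketch: write (x, target) pairs to input.csv.
# Assumptions (not from the original article): 100 samples,
# x uniformly spaced in [0, 1], target = sin(2*pi*x) rescaled to (0, 1)
# so that it matches the sigmoid output of the network.

import numpy as np

num_samples = 100  # assumed
xs = np.linspace(0.0, 1.0, num_samples)
ys = (np.sin(2.0 * np.pi * xs) + 1.0) / 2.0  # rescale [-1, 1] -> [0, 1]

with open('input.csv', 'w') as fd:
    for x, y in zip(xs, ys):
        fd.write('%f,%f\n' % (x, y))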

Export code

linreg2_reprod.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

# Create a queue of file names
filename_queue = tf.train.string_input_producer(["input.csv"])

# Parse the CSV
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
input1, output = tf.decode_csv(value, record_defaults=[[0.], [0.]])
inputs = tf.pack([input1])
output = tf.pack([output])

batch_size=4 # [4]
inputs_batch, output_batch = tf.train.shuffle_batch([inputs, output], batch_size, capacity=40, min_after_dequeue=batch_size)

input_ph = tf.placeholder("float", [None,1])
output_ph = tf.placeholder("float",[None,1])

## Build the NN graph
hiddens = slim.stack(input_ph, slim.fully_connected, [1,7,7,7], 
  activation_fn=tf.nn.sigmoid, scope="hidden")
prediction = slim.fully_connected(hiddens, 1, activation_fn=tf.nn.sigmoid, scope="output")
loss = tf.contrib.losses.mean_squared_error(prediction, output_ph)

train_op = slim.learning.create_train_op(loss, tf.train.AdamOptimizer(0.001))

init_op = tf.initialize_all_variables()

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  try:
    sess.run(init_op)
    for i in range(30000): #[10000]
      inpbt, outbt = sess.run([inputs_batch, output_batch])
      _, t_loss = sess.run([train_op, loss], feed_dict={input_ph:inpbt, output_ph: outbt})

      if (i+1) % 100 == 0:
        print("%d,%f" % (i+1, t_loss))

    # output to npy: save every model variable (weights and biases) in one file
    model_variables = slim.get_model_variables()
    res = sess.run(model_variables)
    np.save('model_variables.npy', res)

  finally:
    coord.request_stop()

  coord.join(threads)

Running the script produced a file named model_variables.npy.
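To see which saved array corresponds to which layer, printing the shape of each entry is enough. This is a small sketch of my own; the expected order follows the [1,7,7,7] hidden stack plus the output layer defined above.

import numpy as np

# Sketch: list the shape of each array saved by linreg2_reprod.py.
# Expected order: (weights, bias) of hidden1..hidden4, then of the output layer,
# i.e. shapes (1,1),(1,), (1,7),(7,), (7,7),(7,), (7,7),(7,), (7,1),(1,).
model_var = np.load('model_variables.npy')
for i, arr in enumerate(model_var):
    print("%d: %s" % (i, arr.shape))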

Read code

read_model_var.py
import numpy as np

model_var = np.load('model_variables.npy')
print (model_var)
Run:
$ python read_model_var.py
[array([[-5.22224426]], dtype=float32) array([ 1.78673065], dtype=float32)
 array([[-4.58573723, -4.27450418, -3.75889063,  4.51949883, -4.02780342,
        -4.10122681,  4.0842309 ]], dtype=float32)
 array([ 0.97008353,  0.70625514,  0.27048966, -0.83405548,  0.57475132,
        0.64893931, -0.45576799], dtype=float32)
 array([[ 1.84565568,  1.49943328,  1.65000439,  2.44512415,  2.09082389,
         2.37278032,  1.65048397],
       [ 1.4404825 ,  0.94583297,  1.89746273,  2.51769876,  1.5209198 ,
         1.94879484,  1.19026875],
       [ 1.49121964,  1.18247831,  1.51108956,  1.18311214,  0.88615912,
         1.1377635 ,  1.33165669],
       [-2.37849569,  0.55434752, -2.70425415, -2.46672988, -2.60600066,
        -2.62578273, -2.78632236],
       [ 0.82718104,  1.42030048,  1.06626236,  1.53540218,  1.55356538,
         1.84659779,  1.25057125],
       [ 1.30038536,  1.45610416,  1.77369738,  2.22379041,  2.24454832,
         2.2828269 ,  1.8470453 ],
       [-2.0469408 ,  1.2918936 , -1.87940514, -2.31857991, -2.2989893 ,
        -2.25665474, -1.73469198]], dtype=float32)
 array([-1.1732285 ,  0.97712201, -1.06408942, -0.75335336, -0.74013597,
       -0.65020418, -1.12149072], dtype=float32)
 array([[ 0.9464646 , -0.11631355, -0.16895044,  0.36979192,  1.05458641,
         0.76118785, -0.4746716 ],
       [-1.47142816, -0.67341083,  0.86623627, -1.63780856, -1.4503684 ,
        -0.71064085,  1.02097607],
       [ 1.58198178,  1.17400146, -1.40148842,  0.5268032 ,  1.45251906,
         0.28084767, -0.8032999 ],
       [ 1.61421323,  1.7214433 , -1.26047742,  1.9050349 ,  1.21235812,
         1.24344146, -2.05446649],
       [ 0.47106522,  1.23160982, -1.13437021,  1.25667596,  0.68895262,
         1.86171079, -0.76870179],
       [ 1.17536426,  1.74713993, -1.9202348 ,  1.397084  ,  1.61537564,
         2.0741148 , -1.37128556],
       [ 0.85221511,  0.6925481 , -0.62384129,  1.12779391,  0.73884082,
         0.08157811, -1.01910996]], dtype=float32)
 array([-1.03282344, -1.26972938,  1.24813604, -1.1338433 , -1.21751046,
       -0.94728774,  1.32416117], dtype=float32)
 array([[-2.49866319],
       [-1.76430416],
       [ 1.60942447],
       [-2.18569684],
       [-2.33169866],
       [-1.27172542],
       [ 2.39537191]], dtype=float32)
 array([-0.45804733], dtype=float32)]

It looks like the values were read back correctly.

With this, I can repeatedly try processing that uses the training results without having to retrain each time.
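As a sketch of that reuse (my own code, not part of the original scripts), the forward pass can be reproduced in numpy from the saved arrays. It assumes the order above (weights and bias of hidden1..hidden4, then the output layer) and applies a sigmoid at every layer, as in linreg2_reprod.py.

import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Sketch: forward pass through the saved network.
# Assumes model_variables.npy holds [W1, b1, W2, b2, ..., W_out, b_out]
# in that order, and that every layer (including the output) uses sigmoid,
# as in the slim network defined in linreg2_reprod.py.
model_var = np.load('model_variables.npy')


def predict(x):
    a = np.array([[x]], dtype=np.float32)  # shape (1, 1), like input_ph
    for idx in range(0, len(model_var), 2):
        W, b = model_var[idx], model_var[idx + 1]
        a = sigmoid(np.dot(a, W) + b)
    return a[0, 0]


print(predict(0.5))  # example input; compare with the TensorFlow prediction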
