TensorFlow / ADDA > reproduce_170429.py > v0.1 > Output the equivalent of a prediction for given input data, based on the trained network

Posted at 2017-04-29
Operating environment
GeForce GTX 1070 (8GB)
ASRock Z170M Pro4S [Intel Z170chipset]
Ubuntu 14.04 LTS desktop amd64
TensorFlow v0.11
cuDNN v5.1 for Linux
CUDA v8.0
Python 2.7.6
IPython 5.1.0 -- An enhanced Interactive Python.
gcc (Ubuntu 4.8.4-2ubuntu1~14.04.3) 4.8.4
GNU bash, version 4.3.8(1)-release (x86_64-pc-linux-gnu)

Overview

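Using the network trained in
http://qiita.com/7of9/items/09262a2ab01d037d169b
this script takes input data and outputs the equivalent of a prediction.

With later visualization in mind, the output is written in the same format as the input data file (input.csv).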

code v0.1

reproduce_170429.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import math
import sys

'''
v0.1 Apr. 29, 2017
    - output with [Exr] replaced by the learned value
    - branched from [check_resultmap_170329.ipynb]
'''

# Turn debug printing ON/OFF in one place
def output_debugPrint(msg):
    #print(msg)
    pass  # no operation

def calc_sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def calc_conv(src, weight, bias, applyActFnc):
    wgt = weight.shape
#   print wgt # debug
    #conv = list(range(bias.size))
    conv = [0.0] * bias.size

    # weight
    for idx1 in range(wgt[0]):
        for idx2 in range(wgt[1]):
            conv[idx2] = conv[idx2] + src[idx1] * weight[idx1, idx2]
    # bias
    for idx2 in range(wgt[1]):
        conv[idx2] = conv[idx2] + bias[idx2]
    # activation function
    if applyActFnc:
        for idx2 in range(wgt[1]):
            conv[idx2] = calc_sigmoid(conv[idx2])
    return conv  # return list

model_var = np.load('model_variables_170429.npy')
learn_data = np.loadtxt('input.csv', delimiter=',')

prd_data = np.array([])

for idx, items in enumerate(learn_data):
    if len(items) < 1:  # avoid empty line
        break
    #if idx >= 5:  # for debug
    #    break
    print(idx)
    # input layer (3 node)
    inlist = (items[0], items[1], items[2])
    # hidden layer 1
    outdata = calc_conv(inlist, model_var[0], model_var[1], applyActFnc=True)
    # hidden layer 2
    outdata = calc_conv(outdata, model_var[2], model_var[3], applyActFnc=True)
    # hidden layer 3
    outdata = calc_conv(outdata, model_var[4], model_var[5], applyActFnc=True)
    # output layer
    outdata = calc_conv(outdata, model_var[6], model_var[7], applyActFnc=False)
    
    wrk = [
        items[0], items[1], items[2],  # x, y, z
        outdata[0], items[4],  # Exr (network output), Exi
        items[5], items[6],  # Eyr, Eyi
        items[7], items[8],  # Ezr, Ezi
    ]
    #print(wrk)
    prd_data = np.append(prd_data, wrk)

print('idx=',idx)

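NUM_ITEMS = 9
# Integer division: with `from __future__ import division`, `/` yields a
# float, which is not a valid dimension for reshape().
nrows = len(prd_data) // NUM_ITEMS
prd_data = np.array(prd_data).reshape(nrows, NUM_ITEMS)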

np.savetxt("input_Exr_replaced_170429.csv", prd_data, delimiter=",",
           fmt="%.9f")
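
Despite its name, calc_conv computes a fully connected layer: for each output node j it sums src[i] * weight[i, j] over i, adds bias[j], and optionally applies the sigmoid. The sketch below writes the same arithmetic as a single matrix product. It is only an illustration: dense_layer is a hypothetical name, the layer size (3 inputs, 7 outputs) is an assumption, and random numbers stand in for the trained variables.

```python
import numpy as np


def sigmoid(x):
    # Element-wise logistic function, same formula as calc_sigmoid.
    return 1.0 / (1.0 + np.exp(-x))


def dense_layer(src, weight, bias, apply_act_fnc):
    # Hypothetical helper: same arithmetic as calc_conv, as one matrix product.
    # Assumes weight has shape (n_in, n_out) and bias has shape (n_out,).
    out = np.dot(src, weight) + bias
    return sigmoid(out) if apply_act_fnc else out


# Random stand-ins for one layer's variables (NOT the trained values).
rng = np.random.RandomState(0)
weight = rng.randn(3, 7)  # assumed sizes: 3 inputs, 7 hidden nodes
bias = rng.randn(7)
src = np.array([-0.2, -1.4, -5.2])  # one (x, y, z) sample

hidden = dense_layer(src, weight, bias, apply_act_fnc=True)
print(hidden.shape)
```

Unlike calc_conv, this returns a numpy array rather than a list, so the result can be passed straight to the next layer.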

Results

$ head input_Exr_replaced_170429.csv 
-0.209439510,-1.466076572,-5.235987756,0.068316654,0.113688322,0.519136168,0.356408494,-0.109750412,0.005652638
0.209439510,-1.466076572,-5.235987756,0.088643790,0.075268157,0.806765822,0.313605899,-0.087339272,0.002273401
-1.047197551,-1.047197551,-5.235987756,0.025312938,0.105477041,-0.151919209,0.284581139,-0.040284662,0.014306838
-0.628318531,-1.047197551,-5.235987756,0.040992646,0.090998304,0.132621814,0.389916050,-0.049129976,0.030650874
-0.209439510,-1.047197551,-5.235987756,0.050157650,0.044499152,0.422640054,0.416238663,-0.054321067,0.030099672
0.209439510,-1.047197551,-5.235987756,0.058198942,0.081592287,0.696098762,0.228374544,-0.029275715,0.000805544
0.628318531,-1.047197551,-5.235987756,0.065595694,-0.026843276,0.797857025,0.040918210,-0.023928849,-0.034062430
1.047197551,-1.047197551,-5.235987756,0.065374163,-0.055698835,0.728159227,-0.100947639,-0.034184970,-0.041232617
-1.047197551,-0.628318531,-5.235987756,0.016307764,0.083149606,-0.275568958,0.256085316,-0.040395470,0.006757636
-0.628318531,-0.628318531,-5.235987756,0.027492900,0.041832652,0.087782092,0.394477943,-0.037243314,0.020279789

Slow

The script does not make good use of numpy, so it is very slow.

Processing 9327 rows takes about 3 minutes.

real	2m57.543s
user	2m56.590s
sys	0m0.485s

I need to learn more about how to use numpy.
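
One possible direction, sketched under assumptions rather than measured: pass all rows through the network at once as an (n_rows, 3) array, so each layer becomes one matrix product instead of nested Python loops, and overwrite column 3 (Exr) in a copy of the input instead of calling np.append per row. The names predict_exr and replace_exr are hypothetical, model_var is assumed to be ordered [W1, b1, W2, b2, W3, b3, W4, b4] as in the script, and the hidden layer size (7) and the random data are placeholders for the real files.

```python
import numpy as np


def sigmoid(x):
    # Element-wise logistic function.
    return 1.0 / (1.0 + np.exp(-x))


def predict_exr(xyz, model_var):
    # xyz: array of shape (n_rows, 3) holding x, y, z for every row.
    # model_var: assumed order [W1, b1, W2, b2, W3, b3, W4, b4].
    out = xyz
    for layer in range(3):  # three hidden layers with sigmoid
        w = model_var[2 * layer]
        b = model_var[2 * layer + 1]
        out = sigmoid(np.dot(out, w) + b)
    out = np.dot(out, model_var[6]) + model_var[7]  # output layer, no activation
    return out[:, 0]


def replace_exr(learn_data, model_var):
    # Copy the input so the other eight columns are kept unchanged.
    prd_data = np.atleast_2d(np.array(learn_data, dtype=float))
    prd_data[:, 3] = predict_exr(prd_data[:, 0:3], model_var)
    return prd_data


# Demo with random stand-ins (NOT the trained variables or the real input.csv).
rng = np.random.RandomState(0)
sizes = [3, 7, 7, 7, 1]  # assumed layer sizes
model_var = []
for n_in, n_out in zip(sizes[:-1], sizes[1:]):
    model_var.append(rng.randn(n_in, n_out))
    model_var.append(rng.randn(n_out))
learn_data = rng.randn(5, 9)  # 5 rows in the 9-column input.csv layout

prd_data = replace_exr(learn_data, model_var)
print(prd_data.shape)
```

With the real files, model_var and learn_data would come from the np.load and np.loadtxt calls in the script, and prd_data could be written with the same np.savetxt call.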
