
ADDA | TensorFlow > Storing IntField-Y in TFRecords and reading it back as a test v0.1 > 3-dimensional input node

Posted at 2017-07-09
Environment
GeForce GTX 1070 (8GB)
ASRock Z170M Pro4S [Intel Z170chipset]
Ubuntu 16.04 LTS desktop amd64
TensorFlow v1.1.0
cuDNN v5.1 for Linux
CUDA v8.0
Python 3.5.2
IPython 6.0.0 -- An enhanced Interactive Python.
gcc (Ubuntu 5.4.0-6ubuntu1~16.04.4) 5.4.0 20160609
GNU bash, version 4.3.48(1)-release (x86_64-pc-linux-gnu)

Related: ADDA > convertToInputcsv_170422.py > Converting to a file for TensorFlow (input.csv) > v0.1
Related: TensorFlow / ADDA > Learning data for the initial values of the linear equations > Training code: v0.4 (learning all of Exr, Exi, Eyr, Eyi, Ezr, Ezi)

Overview

This article is related to ADDA (light scattering simulator based on the discrete dipole approximation).

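The key quantities in an ADDA computation are the electric-field values in the X, Y, and Z directions. From experience, the computation is slow when random initial values are used and becomes faster when the initial values are close to the final solution.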
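The effect of the initial value can be illustrated with any iterative linear solver. The sketch below uses plain Jacobi iteration on a small, hypothetical 3x3 diagonally dominant system (A, B, EXACT are made up for illustration). ADDA uses its own iterative solvers; Jacobi is swapped in only because it fits in a few lines. The sketch counts the iterations needed from a zero initial guess and from a guess close to the exact solution.

```python
# Hypothetical illustration only: Jacobi iteration, not ADDA's actual solver.


def jacobi(a, b, x0, tol=1e-10, max_iter=1000):
    """Solve a x = b by Jacobi iteration; return (x, iterations used)."""
    n = len(b)
    x = list(x0)
    for it in range(1, max_iter + 1):
        # every new component is computed from the previous iterate only
        x_new = [
            (b[i] - sum(a[i][j] * x[j] for j in range(n) if j != i)) / a[i][i]
            for i in range(n)
        ]
        # stop once the update is smaller than tol
        if max(abs(x_new[i] - x[i]) for i in range(n)) < tol:
            return x_new, it
        x = x_new
    raise RuntimeError("Jacobi iteration did not converge")


# hypothetical diagonally dominant system with exact solution (5/28, 2/7, 19/28)
A = [[4.0, 1.0, 0.0],
     [1.0, 4.0, 1.0],
     [0.0, 1.0, 4.0]]
B = [1.0, 2.0, 3.0]
EXACT = [5 / 28, 2 / 7, 19 / 28]

_, iters_far = jacobi(A, B, [0.0, 0.0, 0.0])            # uninformed guess
_, iters_near = jacobi(A, B, [v + 1e-3 for v in EXACT])  # guess near the solution
print(iters_far, iters_near)  # the near guess needs fewer iterations
```

Both runs shrink the error by the same factor per iteration; the near guess simply starts with a much smaller error, so it reaches the tolerance sooner. That is the effect the learned initial values are meant to exploit.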

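The idea is to train a deep-learning model on final solutions computed on a supercomputer and then use the trained model on ordinary PCs. This should speed up computation on ordinary PCs and make more efficient use of the community's computing resources as a whole.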

I am trying to train a TensorFlow model on the electric-field values in the X, Y, and Z directions.

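Training on 3-dimensional input data (X, Y, Z) was carried out in the following article:
http://qiita.com/7of9/items/6c5959c786851bce3e76
I am now considering whether this can be extended to 5 dimensions.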

TFRecords looks like a good fit for this.
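For reference, a TFRecords file is a simple container: each record is stored as an 8-byte little-endian length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The pure-Python sketch below frames and un-frames arbitrary byte strings in that layout, to show what tf.python_io.TFRecordWriter writes around each serialized tf.train.Example. The helper names (crc32c, masked_crc, frame_record, read_records) are hypothetical, and this is not a replacement for the TensorFlow reader.

```python
import struct

# CRC-32C (Castagnoli), reflected polynomial 0x82F63B78, table-driven
_CRC32C_TABLE = []
for _i in range(256):
    _crc = _i
    for _ in range(8):
        _crc = (_crc >> 1) ^ 0x82F63B78 if _crc & 1 else _crc >> 1
    _CRC32C_TABLE.append(_crc)


def crc32c(data):
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    # TFRecord stores a rotated CRC plus a constant, not the raw CRC
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def frame_record(payload):
    """Wrap one payload (e.g. a serialized Example) in TFRecord framing."""
    header = struct.pack('<Q', len(payload))
    return (header + struct.pack('<I', masked_crc(header))
            + payload + struct.pack('<I', masked_crc(payload)))


def read_records(buf):
    """Yield payloads from TFRecord-framed bytes, checking both CRCs."""
    offset = 0
    while offset < len(buf):
        header = buf[offset:offset + 8]
        (length,) = struct.unpack('<Q', header)
        (length_crc,) = struct.unpack('<I', buf[offset + 8:offset + 12])
        if length_crc != masked_crc(header):
            raise ValueError('corrupted length field at offset %d' % offset)
        start = offset + 12
        payload = buf[start:start + length]
        (payload_crc,) = struct.unpack('<I', buf[start + length:start + length + 4])
        if payload_crc != masked_crc(payload):
            raise ValueError('corrupted payload at offset %d' % offset)
        yield payload
        offset = start + length + 4
```

Every record therefore costs 16 bytes of framing on top of its payload, and a damaged record is detected by the CRC check instead of being silently misread.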

code

Target file

A file in the following format:

$ head IntField-Y
x y z |E|^2 Ex.r Ex.i Ey.r Ey.i Ez.r Ez.i
-0.2362099733 -1.653469813 -5.905249332 0.3808724401 -0.01643908869 0.005065047872 0.5364061341 0.2932420721 0.03475441278 -0.07514132662
0.2362099733 -1.653469813 -5.905249332 0.3808724401 0.01643908869 -0.005065047872 0.5364061341 0.2932420721 0.03475441278 -0.07514132662
-1.181049866 -1.181049866 -5.905249332 0.4643931372 -0.06216652936 0.02425352236 0.6162560583 0.2767004228 0.02440613753 -0.05486267928
-0.7086299199 -1.181049866 -5.905249332 0.3684482195 -0.07402580779 0.006390743272 0.5140943697 0.3061227037 0.02715301914 -0.06470039038
-0.2362099733 -1.181049866 -5.905249332 0.4059533367 -0.03451368538 -0.002394051725 0.5213408941 0.3571248095 0.03088293921 -0.06684457553
0.2362099733 -1.181049866 -5.905249332 0.4059533367 0.03451368538 0.002394051725 0.5213408941 0.3571248095 0.03088293921 -0.06684457553
0.7086299199 -1.181049866 -5.905249332 0.3684482195 0.07402580779 -0.006390743272 0.5140943697 0.3061227037 0.02715301914 -0.06470039038
1.181049866 -1.181049866 -5.905249332 0.4643931372 0.06216652936 -0.02425352236 0.6162560583 0.2767004228 0.02440613753 -0.05486267928
-1.181049866 -0.7086299199 -5.905249332 0.5691917417 -0.06618972606 0.005431362716 0.6710482671 0.3357264532 0.01676364612 -0.03849823498
...
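The first line of IntField-Y is a header naming the ten columns, which is why the conversion script below starts at row 1. As a quick sanity check on the format, the sketch parses the header and rows with plain Python and confirms that, for the sample rows shown above, the |E|^2 column equals the sum of the squares of the six field components, so it carries no extra information for training. The names parse_intfield, intensity_matches, and FIELD_COLUMNS are hypothetical.

```python
import math

# header and first two data rows of IntField-Y, copied from the sample above
SAMPLE = """\
x y z |E|^2 Ex.r Ex.i Ey.r Ey.i Ez.r Ez.i
-0.2362099733 -1.653469813 -5.905249332 0.3808724401 -0.01643908869 0.005065047872 0.5364061341 0.2932420721 0.03475441278 -0.07514132662
0.2362099733 -1.653469813 -5.905249332 0.3808724401 0.01643908869 -0.005065047872 0.5364061341 0.2932420721 0.03475441278 -0.07514132662
"""

FIELD_COLUMNS = ('Ex.r', 'Ex.i', 'Ey.r', 'Ey.i', 'Ez.r', 'Ez.i')


def parse_intfield(text):
    """Return one {column name: value} dict per data row."""
    lines = [line for line in text.splitlines() if line.strip()]
    header = lines[0].split()
    return [dict(zip(header, map(float, line.split()))) for line in lines[1:]]


def intensity_matches(row, rel_tol=1e-6):
    """Check |E|^2 against the sum of squared real and imaginary parts."""
    total = sum(row[name] ** 2 for name in FIELD_COLUMNS)
    return math.isclose(row['|E|^2'], total, rel_tol=rel_tol)


rows = parse_intfield(SAMPLE)
print(len(rows), all(intensity_matches(r) for r in rows))  # prints: 2 True
```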

Saving to TFRecords v0.1

toTFRecords_InitFieldY_170709.py
import numpy as np
import tensorflow as tf

"""
v0.1 Jul. 09, 2017
   - save as [TFRecords] format
      + add convert_to_raw()
      + add [OUT_FILE]
      + add _int64_feature()
      + add _bytes_feature()
=== branched to [IntFieldY_to_TFRecords_170709.py] ===
v0.1 Apr. 22, 2017 (convertToInputcsv_170422.py)
   - read 'IntField' then output for TensorFlow
"""

# on
#   Ubuntu 16.04 LTS
#   TensorFlow v1.1
#   Python 3.5.2

# codingrule: PEP8


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def convert_to_raw(orgval):
    # store one value as the raw bytes of a float32
    wrk_org = np.array(orgval, dtype=np.float32)
    wrk_raw = wrk_org.tobytes()  # tostring() is a deprecated alias of tobytes()
    return wrk_raw


OUT_FILE = 'IntField-Y_170709.tfrecords'
# the header line becomes a row of NaN, so row 0 is skipped below
data = np.genfromtxt('IntField-Y', delimiter=' ')

xpos, ypos, zpos = data[1:, 0], data[1:, 1], data[1:, 2]  # dipole location
E2 = data[1:, 3]  # |E|^2 (not stored)
Exr, Exi = data[1:, 4], data[1:, 5]  # real and imaginary part of Ex
Eyr, Eyi = data[1:, 6], data[1:, 7]  # real and imaginary part of Ey
Ezr, Ezi = data[1:, 8], data[1:, 9]  # real and imaginary part of Ez

# "records" instead of "list" to avoid shadowing the built-in list()
records = zip(xpos, ypos, zpos, Exr, Exi, Eyr, Eyi, Ezr, Ezi)

with tf.python_io.TFRecordWriter(OUT_FILE) as tf_writer:
    for tpl in records:
        # per-record names, distinct from the arrays xpos/ypos/zpos above
        px, py, pz = tpl[0:3]
        exr, exi, eyr, eyi, ezr, ezi = tpl[3:9]
        # for debug
        print("%s, %s, %s, %s, %s, %s, %s, %s, %s" % tpl)
        #
        example = tf.train.Example(features=tf.train.Features(feature={
            'xpos_raw': _bytes_feature(convert_to_raw(px)),
            'ypos_raw': _bytes_feature(convert_to_raw(py)),
            'zpos_raw': _bytes_feature(convert_to_raw(pz)),
            'exr_raw': _bytes_feature(convert_to_raw(exr)),
            'exi_raw': _bytes_feature(convert_to_raw(exi)),
            'eyr_raw': _bytes_feature(convert_to_raw(eyr)),
            'eyi_raw': _bytes_feature(convert_to_raw(eyi)),
            'ezr_raw': _bytes_feature(convert_to_raw(ezr)),
            'ezi_raw': _bytes_feature(convert_to_raw(ezi))
            }))
        tf_writer.write(example.SerializeToString())
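convert_to_raw() stores every value as the 4 raw bytes of a float32, and the test-read script undoes this with numpy. Because the source values are float64 text, the round trip is not exact. The stdlib sketch below reproduces the encoding with struct; it assumes little-endian byte order ('<f'), which matches numpy's native float32 layout on x86 machines such as the one listed above. The helper names to_raw_float32 and from_raw_float32 are hypothetical.

```python
import struct


def to_raw_float32(value):
    """4-byte little-endian float32; like convert_to_raw() on an x86 host."""
    return struct.pack('<f', value)


def from_raw_float32(raw):
    """Inverse of to_raw_float32; like decoding with dtype=np.float32 on x86."""
    (value,) = struct.unpack('<f', raw)
    return value


original = -5.905249332  # z coordinate of the first row of IntField-Y
restored = from_raw_float32(to_raw_float32(original))
# float32 keeps only about 7 significant digits
print(len(to_raw_float32(original)), restored, abs(restored - original))
```

This is why the read-back values differ from the text file after about the seventh significant digit; storing float64 (8 bytes per value) would avoid that at twice the size.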

Test read v0.1

Code to check that the file saved by toTFRecords_InitFieldY_170709.py is correct.

fromTFRecords_InitFieldY_170709.py
import numpy as np
import tensorflow as tf

"""
v0.1 Jul. 09, 2017
  - read position and Ex, Ey, Ez
     + add get_feature_float32()
"""

# on
#   Ubuntu 16.04 LTS
#   TensorFlow v1.1
#   Python 3.5.2

# codingrule: PEP8


def get_feature_float32(example, feature_name):
    wrk_raw = (example.features.feature[feature_name]
               .bytes_list
               .value[0])
    wrk_1d = np.frombuffer(wrk_raw, dtype=np.float32)  # fromstring() is deprecated for binary data
    wrk_org = wrk_1d.reshape([1, -1])
    return wrk_org

INP_FILE = 'IntField-Y_170709.tfrecords'

record_iterator = tf.python_io.tf_record_iterator(path=INP_FILE)
for record in record_iterator:
    example = tf.train.Example()
    example.ParseFromString(record)

    xpos_org = get_feature_float32(example, 'xpos_raw')
    ypos_org = get_feature_float32(example, 'ypos_raw')
    zpos_org = get_feature_float32(example, 'zpos_raw')
    exr_org = get_feature_float32(example, 'exr_raw')
    exi_org = get_feature_float32(example, 'exi_raw')
    eyr_org = get_feature_float32(example, 'eyr_raw')
    eyi_org = get_feature_float32(example, 'eyi_raw')
    ezr_org = get_feature_float32(example, 'ezr_raw')
    ezi_org = get_feature_float32(example, 'ezi_raw')

    list_pos = *xpos_org, *ypos_org, *zpos_org
    list_e = *exr_org, *exi_org, *eyr_org, *eyi_org, *ezr_org, *ezi_org

    print(*list_pos, *list_e)
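For training with a 3-dimensional input node, each decoded record splits into an input vector (xpos, ypos, zpos) and a 6-element target (exr, exi, eyr, eyi, ezr, ezi). A minimal sketch of that split plus simple mini-batching follows. The function names, the batch size, and the assumption that the network predicts all six components (as in the related v0.4 training code) are illustrative, not part of the scripts above.

```python
def split_record(values):
    """Split one 9-value record into (input xyz, target field components)."""
    if len(values) != 9:
        raise ValueError('expected 9 values, got %d' % len(values))
    return tuple(values[:3]), tuple(values[3:])


def batches(records, batch_size):
    """Yield (inputs, targets) lists of at most batch_size records each."""
    inputs, targets = [], []
    for values in records:
        xyz, field = split_record(values)
        inputs.append(xyz)
        targets.append(field)
        if len(inputs) == batch_size:
            yield inputs, targets
            inputs, targets = [], []
    if inputs:  # last, possibly smaller batch
        yield inputs, targets


# two records in the order written by toTFRecords_InitFieldY_170709.py
records = [
    (-0.2362099733, -1.653469813, -5.905249332, -0.01643908869, 0.005065047872,
     0.5364061341, 0.2932420721, 0.03475441278, -0.07514132662),
    (0.2362099733, -1.653469813, -5.905249332, 0.01643908869, -0.005065047872,
     0.5364061341, 0.2932420721, 0.03475441278, -0.07514132662),
]
for xs, ys in batches(records, batch_size=1):
    print(xs, ys)
```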

Run

$ python3 toTFRecords_InitFieldY_170709.py > out.170709_1100.org 
$ head out.170709_1100.org 
-0.2362099733, -1.653469813, -5.905249332, -0.01643908869, 0.005065047872, 0.5364061341, 0.2932420721, 0.03475441278, -0.07514132662
0.2362099733, -1.653469813, -5.905249332, 0.01643908869, -0.005065047872, 0.5364061341, 0.2932420721, 0.03475441278, -0.07514132662
-1.181049866, -1.181049866, -5.905249332, -0.06216652936, 0.02425352236, 0.6162560583, 0.2767004228, 0.02440613753, -0.05486267928
-0.7086299199, -1.181049866, -5.905249332, -0.07402580779, 0.006390743272, 0.5140943697, 0.3061227037, 0.02715301914, -0.06470039038
-0.2362099733, -1.181049866, -5.905249332, -0.03451368538, -0.002394051725, 0.5213408941, 0.3571248095, 0.03088293921, -0.06684457553
0.2362099733, -1.181049866, -5.905249332, 0.03451368538, 0.002394051725, 0.5213408941, 0.3571248095, 0.03088293921, -0.06684457553
0.7086299199, -1.181049866, -5.905249332, 0.07402580779, -0.006390743272, 0.5140943697, 0.3061227037, 0.02715301914, -0.06470039038
1.181049866, -1.181049866, -5.905249332, 0.06216652936, -0.02425352236, 0.6162560583, 0.2767004228, 0.02440613753, -0.05486267928
-1.181049866, -0.7086299199, -5.905249332, -0.06618972606, 0.005431362716, 0.6710482671, 0.3357264532, 0.01676364612, -0.03849823498
-0.7086299199, -0.7086299199, -5.905249332, -0.04588795234, 0.004803621663, 0.5436149295, 0.3728649763, 0.02019377811, -0.04343718622
$ python3 fromTFRecords_InitFieldY_170709.py  | head
[-0.23620997] [-1.6534698] [-5.90524912] [-0.01643909] [ 0.00506505] [ 0.53640616] [ 0.29324207] [ 0.03475441] [-0.07514133]
[ 0.23620997] [-1.6534698] [-5.90524912] [ 0.01643909] [-0.00506505] [ 0.53640616] [ 0.29324207] [ 0.03475441] [-0.07514133]
[-1.18104982] [-1.18104982] [-5.90524912] [-0.06216653] [ 0.02425352] [ 0.61625606] [ 0.27670044] [ 0.02440614] [-0.05486268]
[-0.70862991] [-1.18104982] [-5.90524912] [-0.07402581] [ 0.00639074] [ 0.51409435] [ 0.30612269] [ 0.02715302] [-0.06470039]
[-0.23620997] [-1.18104982] [-5.90524912] [-0.03451369] [-0.00239405] [ 0.52134091] [ 0.35712481] [ 0.03088294] [-0.06684458]
[ 0.23620997] [-1.18104982] [-5.90524912] [ 0.03451369] [ 0.00239405] [ 0.52134091] [ 0.35712481] [ 0.03088294] [-0.06684458]
[ 0.70862991] [-1.18104982] [-5.90524912] [ 0.07402581] [-0.00639074] [ 0.51409435] [ 0.30612269] [ 0.02715302] [-0.06470039]
[ 1.18104982] [-1.18104982] [-5.90524912] [ 0.06216653] [-0.02425352] [ 0.61625606] [ 0.27670044] [ 0.02440614] [-0.05486268]
[-1.18104982] [-0.70862991] [-5.90524912] [-0.06618973] [ 0.00543136] [ 0.67104828] [ 0.33572644] [ 0.01676365] [-0.03849823]
[-0.70862991] [-0.70862991] [-5.90524912] [-0.04588795] [ 0.00480362] [ 0.54361492] [ 0.37286496] [ 0.02019378] [-0.04343719]
Traceback (most recent call last):
  File "fromTFRecords_InitFieldY_170709.py", line 46, in <module>
    print(*list_pos, *list_e)
BrokenPipeError: [Errno 32] Broken pipe

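It seems to work. The values read back agree with the originals to float32 precision (for example, -5.905249332 is read back as -5.90524912). The BrokenPipeError at the end is not a problem with the reader: head exits after 10 lines and closes the pipe while the script is still printing.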
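Instead of comparing the two outputs by eye, the lines could be checked programmatically. The sketch below parses one line of the writer's debug output (comma-separated) and one line of the reader's output (bracketed values) and compares them with a tolerance that covers float32 rounding and numpy's 8-digit printing. The tolerance of 1e-6 and the helper names are assumptions.

```python
import math
import re


def parse_writer_line(line):
    """Parse a line like '-0.23, -1.65, ...' printed by the writer."""
    return [float(tok) for tok in line.split(',')]


def parse_reader_line(line):
    """Parse a line like '[-0.23] [ 0.01] ...' printed by the reader."""
    return [float(tok) for tok in re.findall(r'\[([^\]]*)\]', line)]


def lines_match(writer_line, reader_line, tol=1e-6):
    """True if both lines hold the same number of values, all within tol."""
    a = parse_writer_line(writer_line)
    b = parse_reader_line(reader_line)
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=tol, abs_tol=tol) for x, y in zip(a, b))


# first line of out.170709_1100.org and of the reader output above
WRITER = ("-0.2362099733, -1.653469813, -5.905249332, -0.01643908869, "
          "0.005065047872, 0.5364061341, 0.2932420721, 0.03475441278, "
          "-0.07514132662")
READER = ("[-0.23620997] [-1.6534698] [-5.90524912] [-0.01643909] "
          "[ 0.00506505] [ 0.53640616] [ 0.29324207] [ 0.03475441] "
          "[-0.07514133]")
print(lines_match(WRITER, READER))  # prints: True
```

Pairing the two output files line by line (for example with zip over two open files) would extend the check to every record rather than the first ten.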
