1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

TFRecordでデータセットを組むまとめ

Last updated at Posted at 2022-09-01

前説

TFRecordのTensorFlow公式チュートリアルでスカラー値の保存しか詳細に解説されていなかったため、
多次元Tensor(元はndarray)の保存方法を備忘録として記す。

numpy行列をTFRecordに保存し、さらにそれを読み込みtf.data.Datasetにするまでの手順をまとめていく。
また、簡単にsharding(複数ファイル分割)を行う方法も記載している。

サンプルデータ

numpyを使って保存/読み込みをおこなうデータを用意。

import numpy as np
length = 3
tmp_array = np.empty((0, length, length))
for i in range(4):
  tmp_array = np.append(tmp_array, (np.identity(3)*(i+1))[np.newaxis, :, :], axis=0)
tmp_array

TFRecord 書き出し

tf.io.serialize_tensorを使ってTensor行列をバイト列に変換し、TFRecordとして書き出す。
shard_numで定義した分割数でTFRecordを自動的に分割して保存する。

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def to_example(data):
    tmp_array = data
    feature = {
        'tmp': _bytes_feature(tf.io.serialize_tensor(tmp_array)),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

tmp_dataset = tf.data.Dataset.from_tensor_slices(tmp_array)

record_dir = "./shard_{}.records".format("{:03d}")

shard_num = 2
for i in range(shard_num):
    tfrecords_shard_path = os.path.join(record_dir.format(i))
    
    shard_data = tmp_dataset.shard(shard_num, i)

    with tf.io.TFRecordWriter(tfrecords_shard_path) as writer:
      for data in shard_data:
        tf_example = to_example(data)
        writer.write(tf_example.SerializeToString())

TFRecord 読み込み

parse_exampleの中身が鍵。
TFRecordの中身をtf.io.parse_exampleによりバイト列(tf.string)へ復元し、
tf.io.parse_tensorによってバイト列からテンソルへ復元する。

record_dir = "./shard_{}.records".format("*")

shard_files = tf.io.matching_files(record_dir)

def parse_example(example_proto):
    feature_description = {
        'tmp': tf.io.FixedLenFeature([], tf.string, default_value=''),
    }
    parsed_elem = tf.io.parse_example(example_proto, feature_description)
    for key in feature_description.keys():
        parsed_elem[key] = tf.io.parse_tensor(parsed_elem[key], out_type=tf.float64)
    return list(parsed_elem.values())

shards = tf.data.Dataset.from_tensor_slices(shard_files)
shuffle_buffer_Size = 4
shards = shards.shuffle(shuffle_buffer_Size)
dataset = shards.interleave(tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.map(map_func=parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    
for i, datas in enumerate(dataset):
    tmp = datas
    print(tmp)

注意点

ndarrayとtensorのdtypeは別で定義したり、記憶するなりしておく必要がある。
記載のコードではベタ書きだがdictやlistにまとめておくのが吉。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?