前説
TFRecordのTensorFlow公式チュートリアルでスカラー値の保存しか詳細に解説されていなかったため、
多次元Tensor(元はndarray)の保存方法を備忘録として記す。
numpy行列をTFRecordに保存し、さらにそれを読み込みtf.data.Datasetにするまでの手順をまとめていく。
また、簡単にsharding(複数ファイル分割)を行う方法も記載している。
サンプルデータ
numpyを使って保存/読み込みをおこなうデータを用意。
import numpy as np
length = 3
tmp_array = np.empty((0, length, length))
for i in range(4):
tmp_array = np.append(tmp_array, (np.identity(3)*(i+1))[np.newaxis, :, :], axis=0)
tmp_array
TFRecord 書き出し
tf.io.serialize_tensor
を使ってTensor行列をバイト列に変換し、TFRecordとして書き出す。
shard_numで定義した分割数でTFRecordを自動的に分割して保存する。
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def to_example(data):
tmp_array = data
feature = {
'tmp': _bytes_feature(tf.io.serialize_tensor(tmp_array)),
}
return tf.train.Example(features=tf.train.Features(feature=feature))
tmp_dataset = tf.data.Dataset.from_tensor_slices(tmp_array)
record_dir = "./shard_{}.records".format("{:03d}")
shard_num = 2
for i in range(shard_num):
tfrecords_shard_path = os.path.join(record_dir.format(i))
shard_data = tmp_dataset.shard(shard_num, i)
with tf.io.TFRecordWriter(tfrecords_shard_path) as writer:
for data in shard_data:
tf_example = to_example(data)
writer.write(tf_example.SerializeToString())
TFRecord 読み込み
parse_example
の中身が鍵。
TFRecordの中身をtf.io.parse_example
によりバイト列(tf.string)へ復元し、
tf.io.parse_tensor
によってバイト列からテンソルへ復元する。
record_dir = "./shard_{}.records".format("*")
shard_files = tf.io.matching_files(record_dir)
def parse_example(example_proto):
feature_description = {
'tmp': tf.io.FixedLenFeature([], tf.string, default_value=''),
}
parsed_elem = tf.io.parse_example(example_proto, feature_description)
for key in feature_description.keys():
parsed_elem[key] = tf.io.parse_tensor(parsed_elem[key], out_type=tf.float64)
return list(parsed_elem.values())
shards = tf.data.Dataset.from_tensor_slices(shard_files)
shuffle_buffer_Size = 4
shards = shards.shuffle(shuffle_buffer_Size)
dataset = shards.interleave(tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.map(map_func=parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
for i, datas in enumerate(dataset):
tmp = datas
print(tmp)
注意点
ndarrayとtensorのdtypeは別で定義したり、記憶するなりしておく必要がある。
記載のコードではベタ書きだがdictやlistにまとめておくのが吉。