2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

[TensorFlow] TFRecordから任意の型の可変長データをバッチで読み込んで高速化

Posted at

検証環境

Google Colab (CPU)
Python 3.10.11
TensorFlow 2.12.0

前置き

メモリに入り切らないような大量のデータは、TFRecord形式のファイルに書き込んでおくと効率よくデータの読み込みと学習ができます。
このとき、レコード単位ではなくバッチ単位でデータをデコードするように書けば、読み込みをより効率化できます。CPUパワーを消費しまくっていてGPUがあまり動いていないという場合、このデータのデコード処理がボトルネックになっているかもしれません。

バッチの各レコードがすべて同じ長さであれば話は簡単なのですが、可変長の特徴量を扱いたい場合も考えられます。自然言語処理なら各レコードが単語ID列、画像処理なら各レコードがピクセル列、といった具合です。
このような場合にバッチ単位で効率よくデコードする方法を考えてみます。

TFRecordの作り方

以下のコードで、異なる長さのint32列を1万件作成し、data.tfrecord というファイルに書き込んでいるものとします。

import numpy as np
import tensorflow as tf

def make_example(x):
    return tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(bytes_list=tf.train.BytesList(value=[x.tobytes()]))
    }))

filename_test = "data.tfrecord"

rnd = np.random.default_rng(seed=1234)
lengths = rnd.integers(low=100, high=200, size=(10000,)) # 長さ100-199のデータを1万件
with tf.io.TFRecordWriter(filename_test) as writer:
  for l in lengths:
    # 先ほど決めた長さのデータをランダムに生成
    ex = make_example(rnd.integers(low=0, high=100000, size=(l,), dtype=np.int32))
    writer.write(ex.SerializeToString())

データ型が int64 または float32 の場合

tf.io.parse_example() メソッドが使えます。

def parse_batch_example(example): 
    features = tf.io.parse_example(example, features={ 
        "x": tf.io.RaggedFeature(tf.float32)
    }) 
    x = features["x"] 
    return x

ds = tf.data.TFRecordDataset(["data.tfrecord"]) \
          .batch(512) \
          .map(parse_batch_example)

しかしこの方法、読み込めるデータ型に制限があります。
tf.io.RaggedFeature()リファレンスを見ると

Fields
dtype Data type of the RaggedTensor. Must be one of: tf.dtypes.int64, tf.dtypes.float32, tf.dtypes.string.

とあります。データ型が int64, float32, string に制限されているのです。
例えば画像データであれば各ピクセルは tf.uint8 かもしれませんし、単語列であれば単語IDに64ビットも使うのはもったいないので tf.int32 で管理したい、ということがあると思います。

このように、他のデータ型の場合でもバッチ単位でデータを効率よく読み込み、バッチごとに RaggedTensor を得ることを目指します。どうすればよいでしょうか。

他のデータ型の場合

int64, float32以外のデータを扱う場合、TFRecordには配列の値をバイト列にエンコードしたものを書き込み、読み込む時には string(バイト列)として読み込んでデコードします。

ダメな方法

tf.io.decode_raw() を使ってバイナリデータを任意の型のTensorに変換することができますが、バッチ単位でデコードする場合はすべてのレコードが同じ長さでないといけません。

def parse_batch_example(example): 
    features = tf.io.parse_example(example, features={ 
        "x": tf.io.FixedLenFeature([], dtype=tf.string)
    }) 
    x = tf.io.decode_raw(features["x"], tf.int32) # すべて同じ長さでないと失敗する!
    return x

動作する方法

一旦レコードを tf.string でバイト列のバッチとして読み込んだ後、バッチ内のレコードをすべて結合すれば、 tf.io.decode_raw() でバッチ内のレコードをまとめてデコードすることが可能です。
各レコードの長さ情報が失われるので、元のレコードの長さに合わせてレコードを再分割する必要がありますが、tf.RaggedTensor.from_row_lengths() を使えば与えられた長さで行を分割した RaggedTensor を作成することができます。

def parse_batch_example(example): 
    features = tf.io.parse_example(example, features={ 
        "x": tf.io.FixedLenFeature([], dtype=tf.string)
    })
    x_bin = features["x"]
    x_lengths = tf.strings.length(x_bin) # 各レコードの長さ(バイト単位)が得られる
    x_bin_flatten = tf.strings.reduce_join(features["x"]) # 各レコードを連結したバイト列が得られる
    x_decoded_flatten = tf.io.decode_raw(x_bin_flatten, tf.int32) # 各レコードを連結したint32の列が得られる
    x = tf.RaggedTensor.from_row_lengths(x_decoded_flatten, x_lengths // tf.int32.size) # 元のレコードの長さに合わせて分割
    return x

ds = tf.data.TFRecordDataset(["data.tfrecord"]) \
          .batch(512) \
          .map(parse_batch_example)
print(next(iter(ds)))

これで、可変長のレコードをバッチ単位でデコードできました。

パフォーマンスの確認

比較のため、バッチ単位でなくレコード単位でデコードし、後で Dataset.ragged_batch() でバッチ化する場合も試します。データセットの内容は全く同じになります。

def parse_single_example(example): 
    features = tf.io.parse_example(example, features={ 
        "x": tf.io.FixedLenFeature([], dtype=tf.string)
    })
    x_decoded = tf.io.decode_raw(features["x"], tf.int32)
    return x_decoded

ds2 = tf.data.TFRecordDataset(["data.tfrecord"]) \
          .map(parse_single_example) \
          .ragged_batch(512)
print(ds2)
ipython
# バッチ単位でデコードする場合
%timeit [1 for _ in iter(ds)]
72.5 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# レコード単位でデコードする場合
%timeit [1 for _ in iter(ds2)]
587 ms ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

期待通り、バッチ単位で読み込む方法の圧勝となりました!

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?