LoginSignup
32
28

More than 3 years have passed since last update.

tf.dataの使い方メモ

Last updated at Posted at 2020-04-17

大きいファイルを扱う為にtf.dataを勉強した内容をまとめます

以下のサイトを参考にしました
https://qiita.com/Suguru_Toyohara/items/820b0dad955ecd91c7f3
https://qiita.com/wasnot/items/9b64550237a3c5267bfd
https://qiita.com/everylittle/items/a7c31b08d2f76c886a92

tf.dataとは

tensorflowのデータ供給に関するライブラリです。
使うと以下のようなメリットがあるそうです。

  1. GPUの待ち時間を減らして学習速度を最大化できる
  2. メモリに乗り切らないデータを逐次読み込みできる
  3. data augumentationなどの前処理を高速化できる
  4. 前処理パイプラインをまとめられる

1. numpy.arrayを変換して使う

変換する

np.arrayからtf.dataオブジェクトに変換すれば使えます

python
import numpy as np
import tensorflow as tf

arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr)

for item in dataset:
    print(item)
output
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)

repeat

引数回繰り返して出力します

python
arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(3)

for item in dataset:
    print(item)
output
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)

batch

引数ずつbatchにして出力します

python
arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).batch(2)

for item in dataset:
    print(item)
output
tf.Tensor(
[[0 1 2 3 4]
 [5 6 7 8 9]], shape=(2, 5), dtype=int32)

tf.Tensor(
[[10 11 12 13 14]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)

tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)

shuffle

引数でいくつまで遠くのデータと入れ替えるかを指定します。引数が1だと入れ替えがなくなりますし、小さい値だと充分shuffleされないので、データサイズと同じ値を入れるのが良いと思います。

shuffleサイズについてはこちらが詳しいです
https://qiita.com/exy81/items/d1388f6f02a11c8f1d7e

python
arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).shuffle(5)

for item in dataset:
    print(item)
output
tf.Tensor(
[[0 1 2 3 4]
 [5 6 7 8 9]], shape=(2, 5), dtype=int32)

tf.Tensor(
[[10 11 12 13 14]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)

tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)

組み合わせ

上記を組み合わせて使えます。順に実行されるのでbatchを切ってからshuffleするという無意味なことをしないよう注意しましょう。

python
arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(2).shuffle(5).batch(4)

for item in dataset:
    print(item)
    print()
output
tf.Tensor(
[[15 16 17 18 19]
 [ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)

tf.Tensor(
[[20 21 22 23 24]
 [ 0  1  2  3  4]
 [20 21 22 23 24]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)

tf.Tensor(
[[15 16 17 18 19]
 [ 5  6  7  8  9]], shape=(2, 5), dtype=int32)

argumentationする

dataset.map()で関数を適用する事ができます。

適用する関数はtensorflow関数で構成されるのが望ましいですが、普通に書いた関数を@tf.functionとtf.py_functionで変換して使う事もできるようです。

python
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage

def rotate(image):
    return ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)

@tf.function
def rotate_tf(image):
    rotated = tf.py_function(rotate,[image],[tf.int32])
    return rotated[0]

[train_x, train_y], [test_x, test_y] =  tf.keras.datasets.mnist.load_data()
train_x = train_x.reshape(-1,28,28,1)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.map(rotate_tf).batch(16)

first_batch = next(iter(dataset))
images = first_batch.numpy().reshape((-1,28,28))

plt.figure(figsize=(4, 4))
for i, image in enumerate(sample_images):
    plt.subplot(4, 4,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(image)
    plt.grid(False)
plt.show()

image.png

x, yをまとめてdatasetにする

複数のデータをまとめて1つのdatasetにする事もできます

python
def make_model():
    tf.keras.backend.clear_session()

    inputs = tf.keras.layers.Input(shape=(28, 28))
    network = tf.keras.layers.Flatten()(inputs)
    network = tf.keras.layers.Dense(100, activation='relu')(network)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(network)

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy', 
                  metrics=['accuracy'])
    model.summary()
    return model

[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(x_train.shape[0]).batch(64)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(x_test.shape[0]).batch(64)

model = make_model()
hist = model.fit(train_data, validation_data=test_data,
                 epochs=10, verbose=False)

plt.figure(figsize=(4,4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()

2. シリアライズして使う

データの読み込みを効率的にするにはデータをシリアライズして連続的に読み込める100-200MBのファイルのセットとして保存すると良いそうです。TFRecordを使えば手軽に実現できます。

保存

tf.io.TFRecordWriter()でTFRecordファイルを書き出します

python
[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

def make_example(image, label):
    return tf.train.Example(features=tf.train.Features(feature={
        'x' : tf.train.Feature(float_list=tf.train.FloatList(value=image)),
        'y' : tf.train.Feature(int64_list=tf.train.Int64List(value=label))
    }))

def write_tfrecord(images, labels, filename):
    writer = tf.io.TFRecordWriter(filename)
    for image, label in zip(images, labels):
        ex = make_example(image.ravel().tolist(), [int(label)])
        writer.write(ex.SerializeToString())
    writer.close()

write_tfrecord(x_train, y_train, '../mnist_train.tfrecord')
write_tfrecord(x_test, y_test, '../mnist_test.tfrecord')

image.png
ファイルサイズはnpzよりやや大きくなるようです。

読み込み(1レコードずつ)

tf.data.TFRecordDataset()で読み込みます

読んだデータはシリアライズされてるのでparseしないといけません。以下の例ではtf.io.parse_single_example()でparseしています。書き込んだときと同じkeyで呼んでreshapeすればシリアライズ前のtf.dataと同様になります。

def parse_features(example):
    features = tf.io.parse_single_example(example, features={
        'x' : tf.io.FixedLenFeature([28, 28], tf.float32),
        'y' : tf.io.FixedLenFeature([1], tf.int64),
    })
    x = features['x']
    y = features['y']
    return x, y

train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord')
train_dataset = train_dataset.map(parse_features).shuffle(60000).batch(512)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.map(parse_features).shuffle(12000).batch(512)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()

読み込み(batch単位)

じつはtf.io.parse_single_example()で1レコードずつparseするより、batch単位でparseした方が速くなりますので、batch単位でparseする方がオススメです。

python
def dict2tuple(feat):
    return feat["x"], feat["y"]

train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord').batch(512).apply(
                    tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.batch(512).apply(
                    tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()

処理時間

mnistのデータを同じモデルで学習した場合の処理時間を測ってみました。やはり1レコードずつparseするとかなり遅くなってしまうようです。batch単位で処理すればオンメモリのtf.dataと同等の速度が得られるので、かなり高速だと言えそうです。

また、ここではnumpy.arrayそのままが一番速いという結果になってるんですが実務でやってるとtf.dataの方が明らかに速いんで、オンメモリだとnumpy.arrayが速いってわけではないと思います。ご自身の環境でもいろいろ試して頂ければ幸いです。

32
28
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
32
28