大きいファイルを扱う為にtf.dataを勉強した内容をまとめます
以下のサイトを参考にしました
https://qiita.com/Suguru_Toyohara/items/820b0dad955ecd91c7f3
https://qiita.com/wasnot/items/9b64550237a3c5267bfd
https://qiita.com/everylittle/items/a7c31b08d2f76c886a92
tf.dataとは
tensorflowのデータ供給に関するライブラリです。
使うと以下のようなメリットがあるそうです。
- GPUの待ち時間を減らして学習速度を最大化できる
- メモリに乗り切らないデータを逐次読み込みできる
- data augumentationなどの前処理を高速化できる
- 前処理パイプラインをまとめられる
1. numpy.arrayを変換して使う
変換する
np.arrayからtf.dataオブジェクトに変換すれば使えます
import numpy as np
import tensorflow as tf
arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr)
for item in dataset:
print(item)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
repeat
引数回繰り返して出力します
arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(3)
for item in dataset:
print(item)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
batch
引数ずつbatchにして出力します
arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).batch(2)
for item in dataset:
print(item)
tf.Tensor(
[[0 1 2 3 4]
[5 6 7 8 9]], shape=(2, 5), dtype=int32)
tf.Tensor(
[[10 11 12 13 14]
[15 16 17 18 19]], shape=(2, 5), dtype=int32)
tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)
shuffle
引数でいくつまで遠くのデータと入れ替えるかを指定します。引数が1だと入れ替えがなくなりますし、小さい値だと充分shuffleされないので、データサイズと同じ値を入れるのが良いと思います。
shuffleサイズについてはこちらが詳しいです
https://qiita.com/exy81/items/d1388f6f02a11c8f1d7e
arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).shuffle(5)
for item in dataset:
print(item)
tf.Tensor(
[[0 1 2 3 4]
[5 6 7 8 9]], shape=(2, 5), dtype=int32)
tf.Tensor(
[[10 11 12 13 14]
[15 16 17 18 19]], shape=(2, 5), dtype=int32)
tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)
組み合わせ
上記を組み合わせて使えます。順に実行されるのでbatchを切ってからshuffleするという無意味なことをしないよう注意しましょう。
arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(2).shuffle(5).batch(4)
for item in dataset:
print(item)
print()
tf.Tensor(
[[15 16 17 18 19]
[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]], shape=(4, 5), dtype=int32)
tf.Tensor(
[[20 21 22 23 24]
[ 0 1 2 3 4]
[20 21 22 23 24]
[10 11 12 13 14]], shape=(4, 5), dtype=int32)
tf.Tensor(
[[15 16 17 18 19]
[ 5 6 7 8 9]], shape=(2, 5), dtype=int32)
argumentationする
dataset.map()で関数を適用する事ができます。
適用する関数はtensorflow関数で構成されるのが望ましいですが、普通に書いた関数を@tf.functionとtf.py_functionで変換して使う事もできるようです。
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
def rotate(image):
return ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
@tf.function
def rotate_tf(image):
rotated = tf.py_function(rotate,[image],[tf.int32])
return rotated[0]
[train_x, train_y], [test_x, test_y] = tf.keras.datasets.mnist.load_data()
train_x = train_x.reshape(-1,28,28,1)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.map(rotate_tf).batch(16)
first_batch = next(iter(dataset))
images = first_batch.numpy().reshape((-1,28,28))
plt.figure(figsize=(4, 4))
for i, image in enumerate(sample_images):
plt.subplot(4, 4,i+1)
plt.xticks([])
plt.yticks([])
plt.imshow(image)
plt.grid(False)
plt.show()
x, yをまとめてdatasetにする
複数のデータをまとめて1つのdatasetにする事もできます
def make_model():
tf.keras.backend.clear_session()
inputs = tf.keras.layers.Input(shape=(28, 28))
network = tf.keras.layers.Flatten()(inputs)
network = tf.keras.layers.Dense(100, activation='relu')(network)
outputs = tf.keras.layers.Dense(10, activation='softmax')(network)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.summary()
return model
[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(x_train.shape[0]).batch(64)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(x_test.shape[0]).batch(64)
model = make_model()
hist = model.fit(train_data, validation_data=test_data,
epochs=10, verbose=False)
plt.figure(figsize=(4,4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()
2. シリアライズして使う
データの読み込みを効率的にするにはデータをシリアライズして連続的に読み込める100-200MBのファイルのセットとして保存すると良いそうです。TFRecordを使えば手軽に実現できます。
保存
tf.io.TFRecordWriter()でTFRecordファイルを書き出します
[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()
def make_example(image, label):
return tf.train.Example(features=tf.train.Features(feature={
'x' : tf.train.Feature(float_list=tf.train.FloatList(value=image)),
'y' : tf.train.Feature(int64_list=tf.train.Int64List(value=label))
}))
def write_tfrecord(images, labels, filename):
writer = tf.io.TFRecordWriter(filename)
for image, label in zip(images, labels):
ex = make_example(image.ravel().tolist(), [int(label)])
writer.write(ex.SerializeToString())
writer.close()
write_tfrecord(x_train, y_train, '../mnist_train.tfrecord')
write_tfrecord(x_test, y_test, '../mnist_test.tfrecord')
読み込み(1レコードずつ)
tf.data.TFRecordDataset()で読み込みます
読んだデータはシリアライズされてるのでparseしないといけません。以下の例ではtf.io.parse_single_example()でparseしています。書き込んだときと同じkeyで呼んでreshapeすればシリアライズ前のtf.dataと同様になります。
def parse_features(example):
features = tf.io.parse_single_example(example, features={
'x' : tf.io.FixedLenFeature([28, 28], tf.float32),
'y' : tf.io.FixedLenFeature([1], tf.int64),
})
x = features['x']
y = features['y']
return x, y
train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord')
train_dataset = train_dataset.map(parse_features).shuffle(60000).batch(512)
test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.map(parse_features).shuffle(12000).batch(512)
model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
epochs=10, verbose=False)
plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()
読み込み(batch単位)
じつはtf.io.parse_single_example()で1レコードずつparseするより、batch単位でparseした方が速くなりますので、batch単位でparseする方がオススメです。
def dict2tuple(feat):
return feat["x"], feat["y"]
train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord').batch(512).apply(
tf.data.experimental.parse_example_dataset({
"x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
"y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)
test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.batch(512).apply(
tf.data.experimental.parse_example_dataset({
"x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
"y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)
model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
epochs=10, verbose=False)
plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()
処理時間
mnistのデータを同じモデルで学習した場合の処理時間を測ってみました。やはり1レコードずつparseするとかなり遅くなってしまうようです。batch単位で処理すればオンメモリのtf.dataと同等の速度が得られるので、かなり高速だと言えそうです。
また、ここではnumpy.arrayそのままが一番速いという結果になってるんですが実務でやってるとtf.dataの方が明らかに速いんで、オンメモリだとnumpy.arrayが速いってわけではないと思います。ご自身の環境でもいろいろ試して頂ければ幸いです。