Edited at

tfds.loadを使ってtf.Tensorを取得する方法


背景

TensorFlow Datasets(tfds)のloadはGet started with TensorFlow 2.0 for expertsでも使われているように、今後世の中のデータセットを利用するときによく使うメソッドだと思います。2019年5月18日時点で69種類のデータセットが取得可能です(対してtf.keras.datasetsは7種類)。

しかし、tfds.loadはデフォルトではtf.data.Datasetを返すのですが、これでは困る場合があります。例えば、tf.keras.preprocessing.image.ImageDataGeneratorのfitやflowの入力はNumpyの4次元配列でなければいけません。それを実現するためには以下のようなコードを書く必要があり、あまり直感的ではありません(よりよい実現方法をご存知の方がいれば是非教えて下さい!)。そのため、そもそもtfds.loadにtf.data.Dataset以外の型を返させたいなぁというのがきっかけです。

dataset_train = dataset_train.batch(info.splits["train"].num_examples)

for train_x, train_y in dataset_train:
pass

print(train_x.shape, train_y.shape)
# (50000, 32, 32, 3) (50000,)


実現方法

https://www.tensorflow.org/datasets/api_docs/python/tfds/load をよく見ると書いてありますが、batch_sizeに-1を指定すると実現できます。以下サンプルコードです。

import tensorflow_datasets as tfds

dataset, info = tfds.load("cifar10", as_supervised = True, with_info = True, batch_size = -1)

dataset_train, dataset_test = dataset["train"], dataset["test"]
print(dataset_train[0].shape, dataset_train[1].shape)
print(dataset_test[0].shape, dataset_test[1].shape)

# (50000, 32, 32, 3) (50000,)
# (10000, 32, 32, 3) (10000,)