4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

tfds.loadを使ってtf.Tensorを取得する方法

Last updated at Posted at 2019-05-18

背景

TensorFlow Datasets(tfds)のloadはGet started with TensorFlow 2.0 for expertsでも使われているように、今後世の中のデータセットを利用するときによく使うメソッドだと思います。2019年5月18日時点で69種類のデータセットが取得可能です(対してtf.keras.datasetsは7種類)。

しかし、tfds.loadはデフォルトではtf.data.Datasetを返すのですが、これでは困る場合があります。例えば、tf.keras.preprocessing.image.ImageDataGeneratorのfitやflowの入力はNumpyの4次元配列でなければいけません。それを実現するためには以下のようなコードを書く必要があり、あまり直感的ではありません(よりよい実現方法をご存知の方がいれば是非教えて下さい!)。そのため、そもそもtfds.loadにtf.data.Dataset以外の型を返させたいなぁというのがきっかけです。

dataset_train = dataset_train.batch(info.splits["train"].num_examples)
for train_x, train_y in dataset_train:
    pass

print(train_x.shape, train_y.shape)
# (50000, 32, 32, 3) (50000,)

実現方法

https://www.tensorflow.org/datasets/api_docs/python/tfds/load をよく見ると書いてありますが、batch_sizeに-1を指定すると実現できます。以下サンプルコードです。

import tensorflow_datasets as tfds

dataset, info = tfds.load("cifar10", as_supervised = True, with_info = True, batch_size = -1)

dataset_train, dataset_test = dataset["train"], dataset["test"]
print(dataset_train[0].shape, dataset_train[1].shape)
print(dataset_test[0].shape, dataset_test[1].shape)

# (50000, 32, 32, 3) (50000,)
# (10000, 32, 32, 3) (10000,)
4
1
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?