1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

stylegan2のdataset_toolでデータセットを作るとデフォルトでシャッフルされる話

Posted at

要約

stylegan2で複数の解像度の画像のデータセット(dataset-r02.tfrecordsみたいなもの)を作るときに

python dataset_tool.py create_from_images ./my/dataset ./my/pic

とやると、デフォルトで順番がシャッフルされる。

シャッフルさせないためには、

python dataset_tool.py create_from_images ./my/dataset ./my/pic --shuffle 0

としなければいけない。

はじめに

昔GANを自分で作って実行したりしていたが、どうもいい画像が作れず、放置していました。
しかし最近styleGAN2なるものを発見していろいろ調べていました。
その中でも

この記事の内容はgooglecolabも公開されているうえ、ちゃんと動くという素晴らしい記事でした。

しかしこちらで、元の画像と生成された画像の順番が違うという問題点があり、その解決法を見つけました。

内容

まず、使用しているstyleganのコードはこちらになります。

そして、複数の解像度の画像のデータセット(dataset-r02.tfrecordsみたいなもの)を作るときに実行する

python dataset_tool.py create_from_images ./my/dataset ./my/pic

について、dataset_tool.pyをみてみると、最後のほうに引数についてこのような記述がありました。

dataset_tool.py
p = add_command(    'create_from_images', 'Create dataset from a directory full of images.',
                                            'create_from_images datasets/mydataset myimagedir')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'image_dir',        help='Directory containing the images')
    p.add_argument(     '--shuffle',        help='Randomize image order (default: 1)', type=int, default=1)

ここのの--shuffleという引数はデフォルトで1になっていました。

さらにこの--shuffleが使用されているところを探すと、

dataset_tool.py
def create_from_images(tfrecord_dir, image_dir, shuffle):
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if len(image_filenames) == 0:
        error('No input images found')

    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')

    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if channels == 1:
                img = img[np.newaxis, :, :] # HW => CHW
            else:
                img = img.transpose([2, 0, 1]) # HWC => CHW
            tfr.add_image(img)

この関数(今回使用している)にたどり着き、さらにこう使われていました。

.py
order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
.py
 def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order.
        order = np.arange(self.expected_images)
        np.random.RandomState(123).shuffle(order)
        return order

調べたのですが、pythonではどうもTrueFalse10にも対応しているようで、デフォルトの1シャッフルするに相当していたみたいです。

よってシャッフルされないためには、

python dataset_tool.py create_from_images ./my/dataset ./my/pic --shuffle 0

とすることでできます。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?