
How I Got Stuck on Data Augmentation in PyTorch


I got badly stuck on data augmentation in PyTorch.

I was trying to train a model with torchvision.datasets.ImageFolder, using the following transform pipeline:

    import torchvision
    from torchvision import transforms

    # image_size, mean, std, train_image_dir, val_image_dir and myloader are not shown in the post
    data_transform = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(
                image_size, scale=(0.5, 1.0)
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=[-15, 15]),
            transforms.RandomErasing(0.5),  # still receives a PIL Image here (before ToTensor)
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
        'val': transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    }

    train_dataset = torchvision.datasets.ImageFolder(root=train_image_dir, transform=data_transform['train'], loader=myloader)
    val_dataset = torchvision.datasets.ImageFolder(root=val_image_dir, transform=data_transform['val'], loader=myloader)
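myloader is not shown in the post. As a hypothetical stand-in, the sketch below mirrors what torchvision's default ImageFolder loader does: open the file with PIL and convert it to RGB. Whatever the loader returns (a PIL Image here) is what the first transform in the Compose receives, so every transform placed before ToTensor sees a PIL Image rather than a Tensor.

```python
import os
import tempfile

from PIL import Image


def myloader(path: str) -> Image.Image:
    """Hypothetical loader: open an image file with PIL and force 3-channel RGB.

    This mirrors torchvision's default PIL loader; the author's actual
    myloader is not shown in the post.
    """
    with open(path, "rb") as f:
        img = Image.open(f)
        # convert() loads the pixel data while the file is still open
        return img.convert("RGB")


# Usage: write a small grayscale PNG to a temp dir and load it back.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample.png")
    Image.new("L", (32, 24), color=128).save(path)
    loaded = myloader(path)

# The result is a PIL Image, not a Tensor: this is the input type the transforms see first.
print(type(loaded), loaded.mode, loaded.size)
```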

But as soon as I tried to start training, I got the following error:

    AttributeError: shape. Did you mean: 'save'?

The cause turned out to be where transforms.RandomErasing was placed in the pipeline.
The official PyTorch documentation says:

Randomly selects a rectangle region in an torch Tensor image and erases its pixels. This transform does not support PIL Image.

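In other words, RandomErasing cannot be applied to a PIL Image. Here the images reach the pipeline as PIL Images (the error's suggestion of 'save', a PIL Image method, points the same way), and they stay PIL Images until ToTensor converts them.
So RandomErasing has to come after ToTensor. Placed before it, RandomErasing tries to read the image's shape attribute, which a PIL Image does not have, hence AttributeError: shape.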

Moving it after ToTensor, as below, made everything work!

    import torchvision
    from torchvision import transforms

    # image_size, mean, std, train_image_dir, val_image_dir and myloader are not shown in the post
    data_transform = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(
                image_size, scale=(0.5, 1.0)
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=[-15, 15]),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.RandomErasing(0.5),  # moved after ToTensor, so it now receives a Tensor
        ]),
        'val': transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    }

    train_dataset = torchvision.datasets.ImageFolder(root=train_image_dir, transform=data_transform['train'], loader=myloader)
    val_dataset = torchvision.datasets.ImageFolder(root=val_image_dir, transform=data_transform['val'], loader=myloader)

In the end, though, the training itself still didn't go well, whether because of the dataset or the model...
