LoginSignup
103
74

More than 3 years have passed since last update.

【挑戦者求ム】ぼくの考える最速のData LoadingとData Augmentation(Kaggle notebook)

Posted at

はじめに

まぁタイトルの通りなのですが、Kaggle notebook上で行う最速のData LoadingとData Augmentationを考えてみたので紹介します。より速い方法を知っている方は教えてください!
今回の題材は以下のように設定します。

  • データ
  • 実行環境
    • GPUをenableにしたKaggle notebookで行います。
    • 2 CPU cores
    • 13 GB RAM
    • Tesla P100
  • 条件
    • trainデータ(画像とラベル)をすべてTensorにしてGPUにLoadするのにかかる時間を計測する
    • バッチサイズは64
    • 前処理 & Data Augmentationとして以下の処理をかける。(異なるライブラリ間でできるだけ動作を揃えられるような処理だけ選びました。)
      • RandomResizedCrop
      • HorizontalFlip
      • VerticalFlip
      • MotionBlur
      • Rotate
      • Normalize

また、この記事で用いたコードはこちらのnotebookから試すことができます。(あまりデバッグしていないので不備があればご指摘ください。。。)
https://www.kaggle.com/hirune924/the-fastest-data-loading-data-augmentation?scriptVersionId=41763394

また、結果が知りたくてたまらない人のために先に結果を貼っておきます。
ダウンロード.png

OpenCV + Albumentations

今ではかなり基本的な組み合わせかと思います。torchvisonのTransformよりこちらを使ってる人も多いのではないでしょうか?自分もKaggleコンペの初手ではこの組み合わせを使うことが多いです。

cv2_alb.py
import time
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import cv2
import albumentations as A


class DogDataset(Dataset):
    def __init__(self, transform=None):
        self.img_list = pd.read_csv('../input/dog-breed-identification/labels.csv')
        self.transform = transform

        breeds=list(self.img_list['breed'].unique())
        self.breed2idx = {b: i for i, b in enumerate(breeds)}

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_row = self.img_list.iloc[idx]
        image = cv2.imread('../input/dog-breed-identification/train/' + img_row['id'] + '.jpg')
        label = self.breed2idx[img_row['breed']]

        if self.transform is not None:
            image = self.transform(image=image)
        image = torch.from_numpy(image['image'].transpose(2, 0, 1))
        return image, label

transform = A.Compose([A.RandomResizedCrop(height=224, width=224, p=1),
                       A.HorizontalFlip(p=0.5),
                       A.VerticalFlip(p=0.5),
                       A.MotionBlur(blur_limit=3, p=1),
                       A.Rotate(limit=45, p=1),
                       A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0, always_apply=True, p=1.0)])

data_loader = DataLoader(DogDataset(transform=transform), batch_size=64, shuffle=True, num_workers=2)

これを以下のように読み込んで時間を計測します。

cv2_alb_time.py
%%timeit -r 2 -n 5
opencv_alb_times = []
start_time = time.time()
for image, label in data_loader:
    image = image.cuda()
    label = label.cuda()
    pass
opencv_alb_time = time.time() - start_time
opencv_alb_times.append(opencv_alb_time)
print(str(opencv_alb_time) + ' sec')

結果は以下のようになりました。

98.37442970275879 sec
70.52895092964172 sec
66.72178149223328 sec
61.30395317077637 sec
68.30901885032654 sec
69.6796133518219 sec
71.02722263336182 sec
70.88462662696838 sec
70.54376363754272 sec
65.67756700515747 sec
1min 11s ± 1.74 s per loop (mean ± std. dev. of 2 runs, 5 loops each)

jpeg4py + Albumentations

この手の記事を書くと必ずjpeg4pyを使ってみてよと言う人がいるのでこちらもベースラインとして計測しておきます。まぁ画像の形式がjpegならこれを使わない手はないですね。
まずはインストールから

install_jpeg4py.sh
!apt-get install libturbojpeg
!pip install jpeg4py

続いてコードはこちらです。データの読み込み以外はほとんど同じですね。

jpeg4py_alb.py
import time
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import cv2
import albumentations as A
import jpeg4py as jpeg


class DogDataset(Dataset):
    def __init__(self, transform=None):
        self.img_list = pd.read_csv('../input/dog-breed-identification/labels.csv')
        self.transform = transform

        breeds=list(self.img_list['breed'].unique())
        self.breed2idx = {b: i for i, b in enumerate(breeds)}

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_row = self.img_list.iloc[idx]
        image = jpeg.JPEG('../input/dog-breed-identification/train/' + img_row['id'] + '.jpg').decode()
        label = self.breed2idx[img_row['breed']]

        if self.transform is not None:
            image = self.transform(image=image)
        image = torch.from_numpy(image['image'].transpose(2, 0, 1))
        return image, label

transform = A.Compose([A.RandomResizedCrop(height=224, width=224, p=1),
                       A.HorizontalFlip(p=0.5),
                       A.VerticalFlip(p=0.5),
                       A.MotionBlur(blur_limit=3, p=1),
                       A.Rotate(limit=45, p=1),
                       A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0, always_apply=True, p=1.0)])

data_loader = DataLoader(DogDataset(transform=transform), batch_size=64, shuffle=True, num_workers=2)

読み込んで時間を計測するコードは先ほどのものとほぼ同じなので省略します。結果は以下のようになりました。やっぱりjpeg4pyは早いですね。

43.14848828315735 sec
42.78340029716492 sec
41.33797478675842 sec
43.24748754501343 sec
41.11549472808838 sec
41.17329430580139 sec
40.58435940742493 sec
41.16935634613037 sec
40.92542815208435 sec
39.6163330078125 sec
41.5 s ± 816 ms per loop (mean ± std. dev. of 2 runs, 5 loops each)

jpeg4py + Kornia

Data AugmentationにKorniaを使用します。これによりData Augmentationの処理をGPU上で行うことが可能になります。形の揃った画像でテンソルのバッチを作れるように最初のRandomResizedCropだけAlbumentationsを使用します。Korniaはバッチごと処理を行うことが可能なのでDataLoaderで読み込まれたバッチに対してDataAugmentationを実行します。

コードはこちらです。

jpeg4py_kornia.py
import time
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import cv2
import jpeg4py as jpeg

import albumentations as A
import kornia.augmentation as K
import torch.nn as nn


class DogDataset(Dataset):
    def __init__(self, transform=None):
        self.img_list = pd.read_csv('../input/dog-breed-identification/labels.csv')
        self.transform = transform

        breeds=list(self.img_list['breed'].unique())
        self.breed2idx = {b: i for i, b in enumerate(breeds)}

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_row = self.img_list.iloc[idx]
        image = jpeg.JPEG('../input/dog-breed-identification/train/' + img_row['id'] + '.jpg').decode()
        label = self.breed2idx[img_row['breed']]

        if self.transform is not None:
            image = self.transform(image=image)
        image = torch.from_numpy(image['image'].transpose(2, 0, 1).astype(np.float32))
        return image, label

alb_transform = A.Compose([A.RandomResizedCrop(height=224, width=224, p=1)])

mean_std = torch.Tensor([0.5, 0.5, 0.5])*255
kornia_transform = nn.Sequential(
    K.RandomHorizontalFlip(),
    K.RandomVerticalFlip(),
    K.RandomMotionBlur(3, 35., 0.5),
    K.RandomRotation(degrees=45.0),
    K.Normalize(mean=mean_std,std=mean_std)
)

data_loader = DataLoader(DogDataset(transform=alb_transform), batch_size=64, shuffle=True, num_workers=2)

読み込みは以下のようになります。DataLoaderから読み込まれた後にバッチ単位で変換をかけているのが分かるかと思います。

jpeg4py_kornia_time.py
%%timeit -r 2 -n 5
jpeg4py_kornia_times = []
start_time = time.time()
for image, label in data_loader:
    image = kornia_transform(image.cuda())
    label = label.cuda()
    pass
jpeg4py_kornia_time = time.time() - start_time
jpeg4py_kornia_times.append(jpeg4py_kornia_time)
print(str(jpeg4py_kornia_time) + ' sec')

結果は以下のようになりました。かなり速くなってきましたね。

28.150899171829224 sec
24.104888916015625 sec
25.490058183670044 sec
24.111201763153076 sec
22.999730587005615 sec
25.16165590286255 sec
26.496272325515747 sec
27.150801420211792 sec
28.757362365722656 sec
29.331339836120605 sec
26.2 s ± 1.2 s per loop (mean ± std. dev. of 2 runs, 5 loops each)

DALI + Kornia

これが私が考える現在最速の組み合わせです。なんならこれを言いたいがためにこの記事を書きました。DALIを使うと画像を読み込む段階からGPUを利用し、読み込まれた時にはすでにGPU上に画像が乗っているということができます。これによりAlbumentationsなどは使いにくくなりDALIに実装されているAugmentationの種類が少ないため使いにくかったのですがKorniaによるGPU上でのAugmentationが実用レベルに達してきたためこの組み合わせが実現しました。
まずNVIDIA DALIをインストールします。

install_dali.sh
!pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100

そしてコードはこちらになります。今までのものとは少し毛色が違いますね。DALIでpipelineを定義して、それをビルドし、PyTorchのTensorを返すイテレータを作る感じです。RandomResizedCropはDALIで行い、それ以降はKorniaで行います。

dali_kornia.py
import time
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

import kornia.augmentation as K
import torch.nn as nn

from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import nvidia.dali.ops as ops
import nvidia.dali.types as types

class DALIPipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id):
        super(DALIPipeline, self).__init__(batch_size, num_threads, device_id)
        self.img_list = pd.read_csv('../input/dog-breed-identification/labels.csv')

        breeds=list(self.img_list['breed'].unique())
        self.breed2idx = {b: i for i, b in enumerate(breeds)}

        self.img_list['label'] = self.img_list['breed'].map(self.breed2idx)
        self.img_list['data'] = '../input/dog-breed-identification/train/' + self.img_list['id'] + '.jpg'

        self.img_list[['data', 'label']].to_csv('dali.txt', header=False, index=False, sep=' ')

        self.input = ops.FileReader(file_root='.', file_list='dali.txt')
        self.decode = ops.ImageDecoder(device = "mixed", output_type = types.DALIImageType.RGB)
        #self.decode = ops.ImageDecoderRandomCrop(device = "mixed", output_type = types.DALIImageType.RGB)
        self.resize = ops.RandomResizedCrop(device = "gpu", size=(224, 224))
        self.transpose = ops.Transpose(device='gpu', perm = [2, 0, 1])
        self.cast = ops.Cast(device='gpu', dtype=types.DALIDataType.FLOAT)

    def define_graph(self):
        images, labels = self.input(name="Reader")
        images = self.decode(images)
        images = self.resize(images)
        images = self.cast(images)
        output = self.transpose(images)
        return (output, labels)

def DALIDataLoader(batch_size):
    num_gpus = 1
    pipes = [DALIPipeline(batch_size=batch_size, num_threads=2, device_id=device_id) for device_id in range(num_gpus)]

    pipes[0].build()
    dali_iter = DALIGenericIterator(pipelines=pipes, output_map=['data', 'label'], 
                                    size=pipes[0].epoch_size("Reader"), reader_name=None, 
                                    auto_reset=True, fill_last_batch=True, dynamic_shape=False, 
                                    last_batch_padded=True)
    return dali_iter

data_loader = DALIDataLoader(batch_size=64)

mean_std = torch.Tensor([0.5, 0.5, 0.5])*255
kornia_transform = nn.Sequential(
    K.RandomHorizontalFlip(),
    K.RandomVerticalFlip(),
    K.RandomMotionBlur(3, 35., 0.5),
    K.RandomRotation(degrees=45.0),
    K.Normalize(mean=mean_std,std=mean_std)
)

読み込み部分は以下のようになります。

dali_kornia_time.py
%%timeit -r 2 -n 5
dali_kornia_times = []
start_time = time.time()
for feed in data_loader:
    # image is already on GPU
    image = kornia_transform(feed[0]['data'])
    label = feed[0]['label'].cuda()
    pass
dali_kornia_time = time.time() - start_time
dali_kornia_times.append(dali_kornia_time)
print(str(dali_kornia_time) + ' sec')

結果は次のようになります。爆速ですね!速すぎですね!もはやスピード違反ですね!

8.865531921386719 sec
7.996037721633911 sec
8.494542598724365 sec
8.241464853286743 sec
8.093241214752197 sec
8.12808108329773 sec
7.846079587936401 sec
7.849750280380249 sec
7.848227024078369 sec
7.633721828460693 sec
8.1 s ± 238 ms per loop (mean ± std. dev. of 2 runs, 5 loops each)

集計結果

簡単に棒グラフにしてみました。
ダウンロード.png

さいごに

今回はNVIDIA DALIとKorniaを用いた爆速のData LoadingとData Augmentationの方法を紹介しました。KorniaのAugmentationもAlbumentationsの充実度と比べると見劣りしますが、issueのロードマップにはAlbumentationsを意識した記述もあるため今後に期待です!
https://github.com/kornia/kornia/issues/434

103
74
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
103
74