LoginSignup
3
4

More than 1 year has passed since last update.

webdatasetの使い方上級編3:shard読み込みをlightningのDistributedDataParallel (DDP)で

Posted at

これはwebdatasetの使い方の続編です.

この記事では,pytorch lightningのdistributed data parallel (DDP)でwebdatasetを使う方法を説明します.data parallel (DP)の方法は別記事で説明していますので,まずはそちらを見てからまた戻ってきてください.

shardの作成

以下の記事で,multiprocessingを用いて並列ワーカーでshardを作成します.

DataParallel (DP)の説明

以下の記事で,shard読み込みをdata parallel (DP)で行う方法を説明しています.先にそちらを参照してくだし.この記事の説明は,そこからの差分です.

コード全体

上記のように,shardからデータを読み込んだあと,shard中のサンプル番号を表示することで,どの順番に読み込まれているのかを検証します.そのためにlockオブジェクトを共有して,並列ワーカーでの表示が混乱しないようにします.

lockを使う場合の全体のコードは長いので折りたたんでおきます.lockを使わないcleanなコードは最後に掲載します.
全体コードのuse_jpeg_shards_DDP.py

from multiprocessing.managers import SyncManager
from pathlib import Path
from torch import optim, nn
from torchvision import transforms, models
import argparse
import json
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import TQDMProgressBar
import torch
import webdataset as wds
import torch.distributed as dist
import multiprocessing
from tqdm import tqdm

import warnings
warnings.simplefilter('ignore', UserWarning)

# https://bugs.python.org/issue7503
# https://stackoverflow.com/questions/28318502/pythonusing-multiprocessing-manager-in-process-pool
multiprocessing.current_process().authkey = b'this is the key'


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k
    https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L411
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res if len(res) > 1 else res[0]


def info_from_json(shard_path):
    json_file = Path(shard_path).glob('*.json')
    json_file = str(next(json_file))  # get the first json file
    with open(json_file, 'r') as f:
        info_dic = json.load(f)

    return info_dic['dataset size'], info_dic['num_classes']


def get_transform():
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Lambda(lambda x: x / 255.),  # already tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


class MyModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, im, batch_dic, lock):
        bs = im.shape[0]
        gpu_id = im.get_device()
        gpu_id = torch.tensor([gpu_id] * bs, device=im.device)

        with lock:
            print('|-GPU------', gpu_id, '-----------------')
            print('| worker id', batch_dic['read worker id'])
            print('| shard    ', batch_dic['url'])
            print('| count    ', batch_dic['count'])
            # print('| label    ', batch_dic['label'])
            print('| rank     ', batch_dic['rank'])
            print('| worldsize', batch_dic['world size'])

        return self.model(im), gpu_id


class MyProgressBar(TQDMProgressBar):
    # https://github.com/Lightning-AI/lightning/blob/f576ed3bbda95a5045edacc49146a3f1cdcd892a/src/pytorch_lightning/callbacks/progress/base.py#L234
    def get_metrics(self, trainer, model):
        # don't show the version number
        items = super().get_metrics(trainer, model)
        items.pop('v_num', None)
        return items


class MyLightningModel(pl.LightningModule):
    def __init__(self, model, lock, args):
        super().__init__()
        self.model = model
        self.args = args
        self.criterion = nn.CrossEntropyLoss()
        self.lock_list = [lock]
        self.is_lock_set = False

    def forward(self, im):
        output = self.model(im)
        return output

    def set_lock(self):
        if self.is_lock_set:
            return

        self.rank = dist.get_rank()
        dist.broadcast_object_list(self.lock_list, src=0, device=None)
        self.lock = self.lock_list[0]  # shared lock for DDP
        tqdm.set_lock(self.lock)  # global lock of tqdm in lightning
        self.is_lock_set = True

    def training_step(self, batch, batch_idx):
        self.set_lock()

        im, batch_dic, urls = batch
        label = batch_dic['label']

        batch_dic['url'] = urls
        gpu_id = im.get_device()

        with self.lock:
            print('==========================')
            print(f'loop {batch_idx} on GPU {gpu_id} at rank {self.rank}:')
            print('worker id', batch_dic['read worker id'])
            print('shard    ', batch_dic['url'])
            print('count    ', batch_dic['count'])
            # print('label    ', batch_dic['label'])
            print('rank     ', batch_dic['rank'])
            print('worldsize', batch_dic['world size'])

            output, gpu_id = self.model(im, batch_dic, self.lock)

            print('proc GPU ', gpu_id)

        loss = self.criterion(output, label)

        top1 = accuracy(output, label)
        self.log('train top1', top1, prog_bar=True)

        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(),
                               lr=self.args.lr, betas=self.args.betas)
        return optimizer


def add_worker_id(sample):
    info = torch.utils.data.get_worker_info()
    sample['read worker id'] = info.id
    sample['rank'] = dist.get_rank()
    sample['world size'] = dist.get_world_size()
    return sample


def make_dataset(
    shards_url,
    batch_size,
    shuffle_buffer_size=-1,
    transform=None,
):

    dataset = wds.WebDataset(
        shards_url,
        nodesplitter=wds.split_by_node
    )
    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.decode('torchrgb')  # jpg --> tensor(uint8, CHW)
    dataset = dataset.to_tuple(
        'jpg',
        'json',
        '__url__',
    )
    dataset = dataset.map_tuple(
        lambda x: transform(x) if transform is not None else x,
        add_worker_id,
        lambda x: int(x.split('.')[0].split('-')[-1]),  # 'test-00.tar' --> 0
    )
    dataset = dataset.batched(
        batch_size,
        partial=False)

    return dataset


def my_collate_fn(batch):
    ret = (
        batch[0],  # 'jpg', already BCHW because of dataset.batched()
        torch.utils.data.default_collate(batch[1]),  # 'json'
        torch.utils.data.default_collate(batch[2]),  # '__url__'
    )
    return ret


def main(args):

    assert torch.cuda.is_available(), 'cpu is not supported'
    assert isinstance(args.gpu, list), 'single gpu is not supported'

    shards_path = [
        str(path) for path in Path(args.shard_path).glob('*.tar')
        if not path.is_dir()
    ]

    transform = get_transform()

    dataset = make_dataset(
        shards_url=shards_path,
        batch_size=args.batch_size,
        shuffle_buffer_size=args.shuffle,
        transform=transform)
    sample_loader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=my_collate_fn)

    dataset_size, num_classes = info_from_json(args.shard_path)
    num_batches = dataset_size // (args.batch_size * len(args.gpu))

    sample_loader.length = num_batches
    sample_loader = sample_loader.with_length(num_batches)

    sample_loader = sample_loader.repeat(nbatches=num_batches)
    sample_loader = sample_loader.slice(num_batches)

    model = MyModel(num_classes=num_classes)

    with SyncManager() as manager:

        lock = manager.RLock()
        model_lightning = MyLightningModel(model, lock, args)

        trainer = pl.Trainer(
            devices=args.gpu,
            accelerator='gpu',
            # strategy='ddp',
            strategy='ddp_find_unused_parameters_false',
            # strategy='ddp_spawn',
            max_epochs=args.n_epochs,
            callbacks=[
                MyProgressBar(),
            ])
        trainer.fit(
            model=model_lightning,
            train_dataloaders=sample_loader)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-s', '--shard_path', action='store',
                        default='./test_shards/',
                        help='Path to the dir to store shard *.tar files.')
    parser.add_argument('--shuffle', type=int, default=-1,
                        help='shuffle buffer size. negative means no shuffle. '
                        'default -1')

    parser.add_argument('-b', '--batch_size', type=int, default=3,
                        help='batch size. default 3')
    parser.add_argument('-w', '--num_workers', type=int, default=2,
                        help='number of dataloader workders. default 2')
    parser.add_argument('-g', '--gpu', nargs='+', type=int, default=0,
                        help='GPU ids to be used. '
                        'int ("0", "1") or list of int ("1 2", "0 1 2"). '
                        'default "0"')

    parser.add_argument('--n_epochs', type=int, default=10,
                        help='number of epochs. default to 10')
    parser.add_argument('-lr', type=float, default=0.0001,
                        help='learning rate. default to 0.0001')
    parser.add_argument('--betas', nargs='+', type=float, default=[0.9, 0.999],
                        help='betas of Adam. default to (0.9, 0.999).'
                        'specify like --betas 0.9 0.999')

    args = parser.parse_args()
    print(args)
    main(args)

main()での処理

ほぼDPの場合と同様です.

複数GPUのみ対応するところだけ異なる.
    assert torch.cuda.is_available(), 'cpu is not supported'
    assert isinstance(args.gpu, list), 'single gpu is not supported'

webdatasetパイプラインの作成

datasetオブジェクトの作成もほぼDPの場合と同様です.
違いは,wds.WebDataset()の作成オプションのみ.

DPの場合.
    dataset = wds.WebDataset(shards_url)
DDPの場合.
    dataset = wds.WebDataset(
        shards_url,
        nodesplitter=wds.split_by_node
    )

DDPの場合にはノード(GPU)毎にプロセスを起動するので,shardリストも分割して別々にノードに振り分けます.その指定がnodesplitterです.

wds.split_by_node()の中身は,単純に先頭から各GPUへとshardファイルを割り当てているだけです.

例えば3 GPUの場合でshardファイルが

  • shared0.tar, shard1.tar, ...

だとすると,

  • GPU0の担当:shard0.tar, shard3.tar, ...
  • GPU1の担当:shard1.tar, shard4.tar, ...
  • GPU2の担当:shard2.tar, shard5.tar, ...

になります.

ちなみに

workder idとrankの追加

json情報をデコードした辞書には関数add_workder_idを適用しました.

DPのときにはworker idを追加しただけでしたが,DDPではrankとworld sizeも追加しておきます(確認用です).

def add_worker_id(sample):
    info = torch.utils.data.get_worker_info()
    sample['read worker id'] = info.id
    sample['rank'] = dist.get_rank()
    sample['world size'] = dist.get_world_size()
    return sample

data loaderの準備

data loaderにはwds.WebLoaderオブジェクトを使います.これもDPの場合とほぼ同様ですが,オブジェクト生成後の設定が若干異なります.

DPの場合
    sample_loader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=my_collate_fn)

    dataset_size, num_classes = info_from_json(args.shard_path)
    num_batches = dataset_size // args.batch_size + 1

    sample_loader.length = num_batches
    sample_loader = sample_loader.with_length(num_batches)

DPの場合は

  • num_batches = dataset_size // args.batch_size + 1

にしていました.ワーカー数やGPU数に関係なく,指定batch sizeのbatchをnum_batches分取り出せばOKでした(バッチ毎にworkerが交替しますが).

DDPの場合
    sample_loader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=my_collate_fn)

    dataset_size, num_classes = info_from_json(args.shard_path)
    num_batches = dataset_size // (args.batch_size * len(args.gpu))

    sample_loader.length = num_batches
    sample_loader = sample_loader.with_length(num_batches)

    sample_loader = sample_loader.repeat(nbatches=num_batches)
    sample_loader = sample_loader.slice(num_batches)

DDPの場合の変更点は2つ.

1つ目はバッチ数です.

DDPのバッチ数設定
num_batches = dataset_size // (args.batch_size * len(args.gpu))

nodesplitter=wds.split_by_nodeのところでも説明しましたが,shardファイルは各GPUに割り振られるので,各GPUの担当するshardファイル数は「データセットサイズ / GPU数」になります.各GPUではそれをバッチ数で割った数の反復が行われます.そこで「バッチサイズ * GPU数」でデータセットサイズを割った値が,各GPUでの反復回数(1エポックのバッチ数)です.

2つ目はrepeatsliceの追加です.

DDPのrepeatとslice
    sample_loader = sample_loader.repeat(nbatches=num_batches)
    sample_loader = sample_loader.slice(num_batches)

上で「各GPUではそれ(データセットサイズ / GPU数)をバッチ数で割った数の反復が行われます.」と書きましたが,実際はちょっと違います.

例えばshardファイルが5つ,GPUが3つだとします.
このとき,

  • GPU0の担当:shard0.tar, shard3.tar
  • GPU1の担当:shard1.tar, shard4.tar
  • GPU2の担当:shard2.tar

になります.つまりGPU2の担当するshardファイル数が少ないため,このままではGPU0,1と2で同じ反復回数にはなりません.

またGPU0とGPU1の担当shardファイル数が同じだとしても,shard中のサンプル数が同じとは限りません.

DDPの場合にはすべてのGPUで同じ反復回数であることを想定しているため,このままではGPU0かGPU1のどちらかが先にサンプルを使い切った時点で処理がハングアップしてしまいます.

そこで,どのGPUにおいてもnum_batches回の反復を保証するためにrepeatsliceを使います.

  • repeat(nbatches=num_batches)で,指定したバッチ数になるまで,バッチ生成を反復します.
    • もしサンプルを使い切ったら,また先頭のshardに戻ってバッチ生成します.
  • slice(num_batches)で,指定したバッチ数で打ち切ります.
    • サンプルを使い切っていたらrepeatで先頭に戻ってバッチが生成されていますが,それも含めてトータルのバッチ数がnbum_batchesになることを保証できます.

このロジックなら動作しますが,デメリットはエポック単位で全サンプルを漏れなく重複なく使うかどうかを保証できないことです.

  • あるGPUが担当するshardファイルから生成されるバッチ数の合計がnum_batchesよりも多い場合,使われないサンプルがでてきます.
    • この問題をある程度回避するには,あるエポックで使われないサンプルは別のエポックで使われるようにします.そのためにはshardファイルリストの時点で毎回shuffleする,shuffleバッファを使う,をすればOK
  • あるGPUが担当するshardファイルから生成されるバッチ数の合計がnum_batchesよりも少ない場合,先頭のshardから再度同じサンプルがバッチ生成に使われてしまいます.
    • これもshardファイルリストをshuffleすれば,重複して使われるサンプルがエポック毎にランダムに変わりますので,ある程度回避できます.

ちなみに

webdatasetにはDDPのために上と同様の処理をするddp_equalize()というものがあったようなのですが,現在はなぜか使えないようです.

またREADMEのMultinode Trainingでは,.with_epoch()を使う方法が紹介されていますが,おそらくエポック単位でval_lossなどを評価したい場合には不向きです.

Lightningモジュールの使用

DDPの実装方法はいくつかありますが,ここではPyTorch Lightningddpを利用します.

  • 単一マシン(ノード)の複数DDPしか対応していません.
    • できるかもしれませんがやってません...
  • ddpにだけ対応です.
    • ddp_spawnは非対応です.
      • webedatasetオブジェクトはpickle化できるのですが,lambdaはpickle化できません.
      • labmdaをコードから排除してもうまく動作しませんでした...

まずはLightningモジュールの派生クラスを作成します.

LightningModule

以下のコードはlockに対応しているのでやや複雑になっていますが,lockなしcleanなバージョンはスッキリしてます(最後に掲載).

class MyLightningModel(pl.LightningModule):
    def __init__(self, model, lock, args):
        super().__init__()
        self.model = model
        self.args = args
        self.criterion = nn.CrossEntropyLoss()
        self.lock_list = [lock]
        self.is_lock_set = False

modelとargsは一般的な使用方法です.

lockを引数にとって,lock_listとis_lock_setフラグを初期化します.

    def forward(self, im):
        output = self.model(im)
        return output

forwardは単にモデルのforwardを呼んでいるだけです.

    def set_lock(self):
        if self.is_lock_set:
            return

        self.rank = dist.get_rank()
        dist.broadcast_object_list(self.lock_list, src=0, device=None)
        self.lock = self.lock_list[0]  # shared lock for DDP
        tqdm.set_lock(self.lock)  # global lock of tqdm in lightning
        self.is_lock_set = True

print表示が競合しないようにlockを使うためだけの関数で,training_stepから1回だけ呼び出します.(本当ならon_fit_start()あたりのcallbackを使ったほうがスッキリすると思いますが...)

  • dist.broadcast_object_list(self.lock_list, src=0, device=None): rank0から他のrankへ,lockオブジェクトを送信します.objectではなくlistしか送信できないようなので,initで[lock]としていました.tensorではないのでdeviceは不使用.
  • self.lock = self.lock_list[0]: rank0から送られてきたlistの要素をself.lockに設定します.これを後で利用します.
  • tqdm.set_lock(self.lock): おなじlockオブジェクトをtqdmにも設定しておきます.lightningのtqdm表示が,コード中のprint表示と競合しないためのものです.

なお以下のようにコード先頭でmultiprocessingのauthkeyを設定しています.

# https://bugs.python.org/issue7503
# https://stackoverflow.com/questions/28318502/pythonusing-multiprocessing-manager-in-process-pool
multiprocessing.current_process().authkey = b'this is the key'

これをしないとbroadcast_object_listでauthkeyがないというエラーになりますので注意.

    def training_step(self, batch, batch_idx):
        self.set_lock()

        im, batch_dic, urls = batch
        label = batch_dic['label']

        batch_dic['url'] = urls
        gpu_id = im.get_device()

        with self.lock:
            print('==========================')
            print(f'loop {batch_idx} on GPU {gpu_id}:')
            print('worker id', batch_dic['read worker id'])
            print('shard    ', batch_dic['url'])
            print('count    ', batch_dic['count'])
            # print('label    ', batch_dic['label'])
            print('rank     ', batch_dic['rank'])
            print('worldsize', batch_dic['world size'])

            output, gpu_id = self.model(im, batch_dic, self.lock)

            print('proc GPU ', gpu_id)

        loss = self.criterion(output, label)

        top1 = accuracy(output, label)
        self.log('train top1', top1, prog_bar=True)

        return loss

学習ループ本体のtraining_stepは,DPの場合の学習ループ本体とほぼ同じです.違いはAverageMeterの代わりにtrainer.logを使っていることだけです.

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(),
                               lr=self.args.lr, betas=self.args.betas)
        return optimizer

optimizerを設定してクラス定義は終了です.

Model

モデルはDPの場合とまったく同じです.

trainer.fit

ではmainの最後の処理です.

    model = MyModel(num_classes=num_classes)

    with SyncManager() as manager:

        lock = manager.RLock()

DPの場合と同様に,managerオブジェクトからlockを生成します.

        model_lightning = MyLightningModel(model, lock, args)

        trainer = pl.Trainer(
            devices=args.gpu,
            accelerator='gpu',
            strategy='ddp',
            # strategy='ddp_find_unused_parameters_false',
            max_epochs=args.n_epochs,
            callbacks=[
                MyProgressBar(),
            ])
        trainer.fit(
            model=model_lightning,
            train_dataloaders=sample_loader)

さきほど定義したMyLightningModel()のオブジェクトを生成し,trainerに設定します.

  • devices=args.gpu: 使用するGPU番号リストを設定します.

  • strategy='ddp': コードはDDPのみ対応です.spawnは対応していません.

    • strategy='ddp_find_unused_parameters_false': DDPでwarningが出るので,それを抑制するならこちら.
  • callbacks=[MyProgressBar()]): tqdmプログレスバーからv_numを取り除くだけのcallbackです

  • train_dataloaders=sample_loader: WebLoaderオブジェクトを指定します.

では実行してみる

3 GPUで実行してみます.

  • GPU番号:0, 1, 2
  • ワーカー数:2
  • バッチサイズ: 7
初期化
$ python use_jpeg_shards_lightning_DDP.py -s ./shards_cats_dogs/ -g 0 1 2 -b 7 -w 2 
Namespace(batch_size=7, betas=[0.9, 0.999], gpu=[0, 1, 2], lr=0.0001, n_epochs=10, num_workers=2, shard_path='./shards_cats_dogs/', shuffle=-1)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/3
Namespace(batch_size=7, betas=[0.9, 0.999], gpu=[0, 1, 2], lr=0.0001, n_epochs=10, num_workers=2, shard_path='./shards_cats_dogs/', shuffle=-1)
Namespace(batch_size=7, betas=[0.9, 0.999], gpu=[0, 1, 2], lr=0.0001, n_epochs=10, num_workers=2, shard_path='./shards_cats_dogs/', shuffle=-1)
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/3
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/3
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 3 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | MyModel          | 11.2 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)
Epoch 0:   0%|      | 0/1190 [00:00<?, ?it/s]

rankが0, 1, 2のworld sizeが設定されました.

では各GPU(rank)の0ループ目の出力を確認しましょう.

0ループ目
==========================
loop 0 on GPU 0 at rank 0:
worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
shard     tensor([18, 18, 18, 18, 18, 18, 18], device='cuda:0')
count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:0')
rank      tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:0')
|-GPU------ tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0') -----------------
| worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
| shard     tensor([18, 18, 18, 18, 18, 18, 18], device='cuda:0')
| count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:0')
| rank      tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
| worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:0')
proc GPU  tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
==========================
loop 0 on GPU 1 at rank 1:
worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:1')
shard     tensor([19, 19, 19, 19, 19, 19, 19], device='cuda:1')
count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:1')
rank      tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:1')
|-GPU------ tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1') -----------------
| worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:1')
| shard     tensor([19, 19, 19, 19, 19, 19, 19], device='cuda:1')
| count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:1')
| rank      tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
| worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:1')
==========================
loop 0 on GPU 2 at rank 2:
worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:2')
shard     tensor([15, 15, 15, 15, 15, 15, 15], device='cuda:2')
count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:2')
rank      tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2')
worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:2')
|-GPU------ tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2') -----------------
| worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:2')
| shard     tensor([15, 15, 15, 15, 15, 15, 15], device='cuda:2')
| count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:2')
| rank      tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2')
| worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:2')
proc GPU  tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2')

各GPUでは別々のshardが使われており,worker idはすべて0ということが分かります.
これは同じworkerが使われているわけではなく,各GPUで別々にworker番号が割り振られているだけです.

では次のループの出力を確認します.

1ループ目
=========================
loop 1 on GPU 1 at rank 1:
worker id tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
shard     tensor([10, 10, 10, 10, 10, 10, 10], device='cuda:1')
count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:1')
rank      tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:1')
|-GPU------ tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1') -----------------
| worker id tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
| shard     tensor([10, 10, 10, 10, 10, 10, 10], device='cuda:1')
| count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:1')
| rank      tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
| worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:1')
proc GPU  tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
==========================
loop 1 on GPU 0 at rank 0:
worker id tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:0')
shard     tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:0')
count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:0')
rank      tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:0')
|-GPU------ tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0') -----------------
| worker id tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:0')
| shard     tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:0')
| count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:0')
| rank      tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
| worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:0')
proc GPU  tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
==========================
loop 1 on GPU 2 at rank 2:
worker id tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:2')
shard     tensor([24, 24, 24, 24, 24, 24, 24], device='cuda:2')
count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:2')
rank      tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2')
worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:2')
|-GPU------ tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2') -----------------
| worker id tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:2')
| shard     tensor([24, 24, 24, 24, 24, 24, 24], device='cuda:2')
| count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:2')
| rank      tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2')
| worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:2')
proc GPU  tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2')

各GPUではworker idはすべて1,つまりGPUあたり2ワーカーが交替でバッチ生成をしている事がわかります.

  • GPU 0
    • worker 0: shard 18担当
    • worker 1: shard 10担当
  • GPU 1
    • worker 0: shard 19担当
    • worker 1: shard 3担当
  • GPU 2
    • worker 0: shard 15担当
    • worker 1: shard 24担当

2ループ目になると,各GPUではまたworker 0に交代して,同じshardから引き続きバッチを生成していることが確認できます.

2ループ目
==========================
loop 2 on GPU 1 at rank 1:
worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:1')
shard     tensor([19, 19, 19, 19, 19, 19, 19], device='cuda:1')
count     tensor([ 7,  8,  9, 10, 11, 12, 13], device='cuda:1')
rank      tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:1')
|-GPU------ tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1') -----------------
| worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:1')
| shard     tensor([19, 19, 19, 19, 19, 19, 19], device='cuda:1')
| count     tensor([ 7,  8,  9, 10, 11, 12, 13], device='cuda:1')
| rank      tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
| worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:1')
proc GPU  tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:1')
==========================
loop 2 on GPU 0 at rank 0:
worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
shard     tensor([18, 18, 18, 18, 18, 18, 18], device='cuda:0')
count     tensor([ 7,  8,  9, 10, 11, 12, 13], device='cuda:0')
rank      tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:0')
|-GPU------ tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0') -----------------
| worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
| shard     tensor([18, 18, 18, 18, 18, 18, 18], device='cuda:0')
| count     tensor([ 7,  8,  9, 10, 11, 12, 13], device='cuda:0')
| rank      tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
| worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:0')
proc GPU  tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
==========================
loop 2 on GPU 2 at rank 2:
worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:2')
shard     tensor([15, 15, 15, 15, 15, 15, 15], device='cuda:2')
count     tensor([ 7,  8,  9, 10, 11, 12, 13], device='cuda:2')
rank      tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2')
worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:2')
|-GPU------ tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2') -----------------
| worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:2')
| shard     tensor([15, 15, 15, 15, 15, 15, 15], device='cuda:2')
| count     tensor([ 7,  8,  9, 10, 11, 12, 13], device='cuda:2')
| rank      tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2')
| worldsize tensor([3, 3, 3, 3, 3, 3, 3], device='cuda:2')
proc GPU  tensor([2, 2, 2, 2, 2, 2, 2], device='cuda:2')

以降のループも同様です.

ちなみにnum_workersで指定した数のワーカーが各GPU(rank)に割り当てられるため,DPの場合と同様に

  • ワーカー数=コア数(num_workers=os.cpu_count())

としてしまうと,総ワーカー数がコア数*GPU数にもなってしまいます.そのため,

  • ワーカー数=コア数 / GPU数(num_workers=os.cpu_count() / len(gpu)

程度にしておけばよいことになります.

lockなしクリーンなコード全体

lockを使わない普通の場合の全体のコードのサンプルです.折りたたんでおきます.
全体コードのuse_jpeg_shards_DDP_clean.py

from pathlib import Path
from torch import optim, nn
from torchvision import transforms, models
import argparse
import json
import pytorch_lightning as pl
# from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks.progress import TQDMProgressBar
import torch
import webdataset as wds

import warnings
warnings.simplefilter('ignore', UserWarning)


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k
    https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L411
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res if len(res) > 1 else res[0]


def info_from_json(shard_path):
    json_file = Path(shard_path).glob('*.json')
    json_file = str(next(json_file))  # get the first json file
    with open(json_file, 'r') as f:
        info_dic = json.load(f)

    return info_dic['dataset size'], info_dic['num_classes']


def get_transform():
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Lambda(lambda x: x / 255.),  # already tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


class MyModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, im):
        return self.model(im)


class MyProgressBar(TQDMProgressBar):
    # https://github.com/Lightning-AI/lightning/blob/f576ed3bbda95a5045edacc49146a3f1cdcd892a/src/pytorch_lightning/callbacks/progress/base.py#L234
    def get_metrics(self, trainer, model):
        # don't show the version number
        items = super().get_metrics(trainer, model)
        items.pop('v_num', None)
        return items


class MyLightningModel(pl.LightningModule):
    def __init__(self, model, args):
        super().__init__()
        self.model = model
        self.args = args
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, im):
        output = self.model(im)
        return output

    def training_step(self, batch, batch_idx):

        im, label = batch

        output = self.model(im)

        loss = self.criterion(output, label)

        top1 = accuracy(output, label)
        self.log('train top1', top1, prog_bar=True)

        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(),
                               lr=self.args.lr, betas=self.args.betas)
        return optimizer


def make_dataset(
    shards_url,
    batch_size,
    shuffle_buffer_size=-1,
    transform=None,
):

    dataset = wds.WebDataset(
        shards_url,
        nodesplitter=wds.split_by_node
    )
    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.decode('torchrgb')  # jpg --> tensor(uint8, CHW)
    dataset = dataset.to_tuple(
        'jpg',
        'json',
    )
    dataset = dataset.map_tuple(
        lambda x: transform(x) if transform is not None else x,
        lambda x: x['label']
    )
    dataset = dataset.batched(
        batch_size,
        partial=False)

    return dataset


def my_collate_fn(batch):
    ret = (
        batch[0],  # 'jpg', already BCHW because of dataset.batched()
        torch.utils.data.default_collate(batch[1]),  # label
    )
    return ret


def main(args):

    assert torch.cuda.is_available(), 'cpu is not supported'
    assert isinstance(args.gpu, list), 'single gpu is not supported'

    shards_path = [
        str(path) for path in Path(args.shard_path).glob('*.tar')
        if not path.is_dir()
    ]

    transform = get_transform()

    dataset = make_dataset(
        shards_url=shards_path,
        batch_size=args.batch_size,
        shuffle_buffer_size=args.shuffle,
        transform=transform)
    sample_loader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=my_collate_fn)

    dataset_size, num_classes = info_from_json(args.shard_path)
    num_batches = dataset_size // (args.batch_size * len(args.gpu))

    sample_loader.length = num_batches
    sample_loader = sample_loader.with_length(num_batches)

    sample_loader = sample_loader.repeat(nbatches=num_batches)
    sample_loader = sample_loader.slice(num_batches)

    model = MyModel(num_classes=num_classes)

    model_lightning = MyLightningModel(model, args)

    trainer = pl.Trainer(
        devices=args.gpu,
        accelerator='gpu',
        strategy='ddp',  # 'ddp_find_unused_parameters_false',
        max_epochs=args.n_epochs,
        callbacks=[
            MyProgressBar(),
        ])
    trainer.fit(
        model=model_lightning,
        train_dataloaders=sample_loader)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-s', '--shard_path', action='store',
                        default='./test_shards/',
                        help='Path to the dir to store shard *.tar files.')
    parser.add_argument('--shuffle', type=int, default=-1,
                        help='shuffle buffer size. negative means no shuffle. '
                        'default -1')

    parser.add_argument('-b', '--batch_size', type=int, default=3,
                        help='batch size. default 3')
    parser.add_argument('-w', '--num_workers', type=int, default=2,
                        help='number of dataloader workders. default 2')
    parser.add_argument('-g', '--gpu', nargs='+', type=int, default=0,
                        help='GPU ids to be used. '
                        'int ("0", "1") or list of int ("1 2", "0 1 2"). '
                        'default "0"')

    parser.add_argument('--n_epochs', type=int, default=10,
                        help='number of epochs. default to 10')
    parser.add_argument('-lr', type=float, default=0.0001,
                        help='learning rate. default to 0.0001')
    parser.add_argument('--betas', nargs='+', type=float, default=[0.9, 0.999],
                        help='betas of Adam. default to (0.9, 0.999).'
                        'specify like --betas 0.9 0.999')

    args = parser.parse_args()
    print(args)
    main(args)
3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4