1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

webdatasetの使い方上級編2:shard読み込みをDataParallel (DP)で

Last updated at Posted at 2023-01-02

これはwebdatasetの使い方の続編です.

この記事では,data parallel (DP)の学習ループでwebdatasetを使う方法を説明します.
distributed data parallel (DDP)の方法は別記事で説明しています.この記事でDPの説明を見ててから,DDPの記事を見てください.

shardの作成

以下の記事で,multiprocessingを用いて並列ワーカーでshardを作成します.これを使います.

これで作成されるshardの中のデータサンプルは以下のものです.

  • 'jpg': jpgファイル(__key__.jpg
  • 'json': jsonファイル(__key__.json
    • 'write worker id': このサンプルをshardに書き込んだワーカー番号(今回は使いません.デバッグ用)
    • 'count': このshard中のサンプルの番号(0〜).以下ではこの番号を使って,どの順番で読み込まれているのかを検証します.
    • 'category': サンプルのカテゴリ名(使いません.デバッグ用)
    • 'label': カテゴリ番号.学習に用います.

コード全体

上記のように,shardからデータを読み込んだあと,shard中のサンプル番号を表示することで,どの順番に読み込まれているのかを検証します.そのためにlockオブジェクトを共有して,並列ワーカーでの表示が混乱しないようにします.

lockを使う場合の全体のコードは長いので折りたたんでおきます.lockを使わないcleanなコードは最後に掲載します.
全体コードのuse_jpeg_shards_singe_or_DP.py
from multiprocessing.managers import SyncManager
from pathlib import Path
from torch import optim, nn
from torchvision import transforms, models
from tqdm import tqdm
import argparse
import json
import torch
import webdataset as wds

import warnings
warnings.simplefilter('ignore', UserWarning)


class AverageMeter(object):
    """
    Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L363-L380
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.value = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, value, bs=1):
        if isinstance(value, torch.Tensor):
            value = value.item()
        self.value = value
        self.sum += value * bs
        self.count += bs
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k
    https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L411
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res if len(res) > 1 else res[0]


def info_from_json(shard_path):
    json_file = Path(shard_path).glob('*.json')
    json_file = str(next(json_file))  # get the first json file
    with open(json_file, 'r') as f:
        info_dic = json.load(f)

    return info_dic['dataset size'], info_dic['num_classes']


def get_transform():
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Lambda(lambda x: x / 255.),  # already tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


class MyModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, im, batch_dic, lock):
        bs = im.shape[0]
        gpu_id = im.get_device()
        gpu_id = torch.tensor([gpu_id] * bs, device=im.device)

        with lock:
            print('|-GPU------', gpu_id, '-----------------')
            print('| worker id', batch_dic['read worker id'])
            print('| shard    ', batch_dic['url'])
            print('| count    ', batch_dic['count'])
            # print('| label    ', batch_dic['label'])

        return self.model(im), gpu_id


def add_worker_id(sample):
    info = torch.utils.data.get_worker_info()
    sample['read worker id'] = info.id
    return sample


def make_dataset(
    shards_url,
    batch_size,
    shuffle_buffer_size=-1,
    transform=None,
):

    dataset = wds.WebDataset(shards_url)
    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.decode('torchrgb')  # jpg --> tensor(uint8, CHW)
    dataset = dataset.to_tuple(
        'jpg',
        'json',
        '__url__',
    )
    dataset = dataset.map_tuple(
        lambda x: transform(x) if transform is not None else x,
        add_worker_id,
        lambda x: int(x.split('.')[0].split('-')[-1]),  # 'test-00.tar' --> 0
    )
    dataset = dataset.batched(
        batch_size,
        partial=False)

    return dataset


def my_collate_fn(batch):
    ret = (
        batch[0],  # 'jpg', already BCHW because of dataset.batched()
        torch.utils.data.default_collate(batch[1]),  # 'json'
        torch.utils.data.default_collate(batch[2]),  # '__url__'
    )
    return ret


def main(args):

    assert torch.cuda.is_available(), 'cpu is not supported'
    if isinstance(args.gpu, int):
        device = torch.device('cuda:' + str(args.gpu))
    elif isinstance(args.gpu, list):
        device = torch.device('cuda:' + str(args.gpu[0]))  # the 1st device

    shards_path = [
        str(path) for path in Path(args.shard_path).glob('*.tar')
        if not path.is_dir()
    ]

    transform = get_transform()

    dataset = make_dataset(
        shards_url=shards_path,
        batch_size=args.batch_size,
        shuffle_buffer_size=args.shuffle,
        transform=transform)
    sample_loader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=my_collate_fn)

    dataset_size, num_classes = info_from_json(args.shard_path)
    num_batches = dataset_size // args.batch_size + 1

    sample_loader.length = num_batches
    sample_loader = sample_loader.with_length(num_batches)

    model = MyModel(num_classes=num_classes)
    if isinstance(args.gpu, list):
        model = torch.nn.DataParallel(model, device_ids=args.gpu)
    model.to(device)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr, betas=args.betas)

    train_loss = AverageMeter()
    train_top1 = AverageMeter()

    with tqdm(range(args.n_epochs)) as pbar_epoch, \
            SyncManager() as manager:

        lock = manager.Lock()

        for epoch in pbar_epoch:
            pbar_epoch.set_description("[Train] epoch: %d" % epoch)

            train_loss.reset()
            train_top1.reset()

            with tqdm(enumerate(sample_loader),
                      total=sample_loader.length,
                      leave=True,
                      smoothing=0,
                      ) as pbar_batch:

                for i, batch in pbar_batch:

                    im, batch_dic, urls = batch
                    im = im.to(device)
                    label = batch_dic['label'].to(device)

                    batch_dic['url'] = urls
                    gpu_id = im.get_device()

                    optimizer.zero_grad()

                    print('==========================')
                    print(f'loop {i} on GPU {gpu_id}:')
                    print('worker id', batch_dic['read worker id'])
                    print('shard    ', batch_dic['url'])
                    print('count    ', batch_dic['count'])
                    print('label    ', batch_dic['label'])

                    output, gpu_id = model(im, batch_dic, lock)
                    print('proc GPU ', gpu_id)

                    loss = criterion(output, label)
                    loss.backward()
                    optimizer.step()

                    bs = im.size(0)
                    train_loss.update(loss, bs)
                    train_top1.update(accuracy(output, label), bs)

                    pbar_batch.set_postfix_str(
                        ' loss={:6.04f}/{:6.04f}'
                        ' top1={:6.04f}/{:6.04f}'
                        ''.format(
                            train_loss.value, train_loss.avg,
                            train_top1.value, train_top1.avg,
                        ))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-s', '--shard_path', action='store',
                        default='./test_shards/',
                        help='Path to the dir to store shard *.tar files.')
    parser.add_argument('--shuffle', type=int, default=-1,
                        help='shuffle buffer size. negative means no shuffle. '
                        'default -1')

    parser.add_argument('-b', '--batch_size', type=int, default=3,
                        help='batch size. default 3')
    parser.add_argument('-w', '--num_workers', type=int, default=2,
                        help='number of dataloader workders. default 2')
    parser.add_argument('-g', '--gpu', nargs='+', type=int, default=0,
                        help='GPU ids to be used. '
                        'int ("0", "1") or list of int ("1 2", "0 1 2"). '
                        'default "0"')

    parser.add_argument('--n_epochs', type=int, default=10,
                        help='number of epochs. default to 10')
    parser.add_argument('-lr', type=float, default=0.0001,
                        help='learning rate. default to 0.0001')
    parser.add_argument('--betas', nargs='+', type=float, default=[0.9, 0.999],
                        help='betas of Adam. default to (0.9, 0.999).'
                        'specify like --betas 0.9 0.999')

    args = parser.parse_args()
    print(args)
    main(args)

まずはmainブロック

argparseの引数処理です.一般的な引数を設定してます.

  • -s:shardがあるディレクトリ.'test-0000.tar'というようにファイル名が連番になっていることを仮定します.
  • -g: 使用するGPU番号.
    • -g 0ならGPU0だけを使用.data parallelではない.
    • -g 0 1 2ならGPU0, 1, 2の3つを使用するdata parallel
mainブロック
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-s', '--shard_path', action='store',
                        default='./test_shards/',
                        help='Path to the dir to store shard *.tar files.')
    parser.add_argument('--shuffle', type=int, default=-1,
                        help='shuffle buffer size. negative means no shuffle. '
                        'default -1')

    parser.add_argument('-b', '--batch_size', type=int, default=3,
                        help='batch size. default 3')
    parser.add_argument('-w', '--num_workers', type=int, default=2,
                        help='number of dataloader workders. default 2')
    parser.add_argument('-g', '--gpu', nargs='+', type=int, default=0,
                        help='GPU ids to be used. '
                        'int ("0", "1") or list of int ("1 2", "0 1 2"). '
                        'default "0"')

    parser.add_argument('--n_epochs', type=int, default=10,
                        help='number of epochs. default to 10')
    parser.add_argument('-lr', type=float, default=0.0001,
                        help='learning rate. default to 0.0001')
    parser.add_argument('--betas', nargs='+', type=float, default=[0.9, 0.999],
                        help='betas of Adam. default to (0.9, 0.999).'
                        'specify like --betas 0.9 0.999')

    args = parser.parse_args()
    print(args)
    main(args)

学習の準備

次はmain()で一般的な学習ループとその準備を行います.

deviceの設定
    assert torch.cuda.is_available(), 'cpu is not supported'
    if isinstance(args.gpu, int):
        device = torch.device('cuda:' + str(args.gpu))
    elif isinstance(args.gpu, list):
        device = torch.device('cuda:' + str(args.gpu[0]))  # the 1st device
  • args.gpuがintならsingle GPUで実行
  • args.gpuがintのListならdata parallelで実行
shardファイルパスのリスト
     shards_path = [
        str(path) for path in Path(args.shard_path).glob('*.tar')
        if not path.is_dir()
    ]

読み込むshardファイル一覧を取得します.

datasetオブジェクト作成
    transform = get_transform()

    dataset = make_dataset(
        shards_url=shards_path,
        batch_size=args.batch_size,
        shuffle_buffer_size=args.shuffle,
        transform=transform)

transformは一般的なものを設定しているだけです.が,jpegデコードの時点ですでにtorch.tensorにしてあるので,ToTensorは使わずに255で割るだけにしています.

transformの作成
def get_transform():
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Lambda(lambda x: x / 255.),  # already tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

以下datasetの説明のために,一旦main()を離れてmake_dataset()を先に説明します.

webdatasetパイプラインの作成

datasetオブジェクトの作成の詳細は以下の記事を参照してください.


def make_dataset(
    shards_url,
    batch_size,
    shuffle_buffer_size=-1,
    transform=None,
):

make_dataset()の定義.

    dataset = wds.WebDataset(shards_url)

shardsファイルリストからwebdatasetオブジェクトを作成.

    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)

シャッフルするならここでshardファイルリストをシャッフル.今回はshard読み込み順を確認するので,シャッフルしません.

    dataset = dataset.decode('torchrgb')  # jpg --> tensor(uint8, CHW)

自動デコード方法の指定.ここではjpgファイルをtorch.tensorにする設定にします.

    dataset = dataset.to_tuple(
        'jpg',
        'json',
        '__url__',
    )

jpgとjsonの2つに加えて,読み込んだshardファイル名も__url___で取得できるので,ここでその3つをタプルにします.

    dataset = dataset.map_tuple(
        lambda x: transform(x) if transform is not None else x,
        add_worker_id,
        lambda x: int(x.split('.')[0].split('-')[-1]),  # 'test-00.tar' --> 0
    )

タプルの要素それぞれに適用する関数の指定.

  • jpegをデコードしたtensorにはtransformを適用.
  • json情報をデコードした辞書には関数add_workder_idを適用(後述).
  • shardファイル名は連番の番号を抜き出してintに変換.

このタプルがcollate関数に送られます.

    dataset = dataset.batched(
        batch_size,
        partial=False)

    return dataset

最後にdatasetオブジェクトの時点でバッチサイズを設定します.

workder idの追加

json情報をデコードした辞書には関数add_workder_idを適用しました.

def add_worker_id(sample):
    info = torch.utils.data.get_worker_info()
    sample['read worker id'] = info.id
    return sample

この関数では,どの並列ワーカーがこのサンプルを処理したのかを確認するために,torch.utils.data.get_worker_info()でワーカーIDを取得して,辞書に追加しています.

collate関数

jpegファイルはtorch.tensorにデコードしているため,webdatasetのdataset.batched()でバッチを作成した時点ですでにBCHW形式になっています.

そうすると,collate関数に送られるタプルの要素の

  • 0番目:すでにBCHW
  • 1番目:まだ
  • 2番目:まだ
    ということになるので,このままではデフォルトのtorch.utils.data.default_collate()が扱えません.

そこで以下の自作collate関数でタプルの要素別にtorch.utils.data.default_collate()を適用します.

def my_collate_fn(batch):
    ret = (
        batch[0],  # 'jpg', already BCHW because of dataset.batched()
        torch.utils.data.default_collate(batch[1]),  # 'json'
        torch.utils.data.default_collate(batch[2]),  # '__url__'
    )
    return ret

data loaderの準備

ではdatasetオブジェクトが作成できたので,再びmain()に戻って続きを説明します.

   sample_loader = wds.WebLoader(
       dataset,
       batch_size=None,
       shuffle=False,
       num_workers=args.num_workers,
       pin_memory=True,
       collate_fn=my_collate_fn)

作成したdatasetオブジェクトを使って,webdatasetのWebLoaderに渡します.

  • batch_size=None: バッチサイズは設定しない.dataset.batched()で設定済みなので.
  • shuffle=False: dataset.shuffle()で設定済み.
  • あとはtorch.DataLoaderにわたす引数と同じ.(pin_memoryなど)

wds.WebLoaderについては以下の記事を参照.

    dataset_size, num_classes = info_from_json(args.shard_path)
    num_batches = dataset_size // args.batch_size + 1

    sample_loader.length = num_batches
    sample_loader = sample_loader.with_length(num_batches)
  • shardにはデータセットサイズの情報がないので,shard作成時に別途jsonを作成しておき,それを読み込みます(手動指定でもよいですが).
  • num_batchesは1エポック分のバッチ数です.ここでは「データ数/バッチサイズ+1」にしてあります.
  • sample_loader.lengthには手動でバッチ数を指定します.この数だけバッチを取り出すと,1エポックが終わります.
    • 指定しないと,loaderのforループが終了せず(エポック終了になない),無限にshardから繰り返しデータを読み込むことになります.
  • sample_loader.with_length()の設定はオプションです.len(sample_loader)で取得できる値を設定しているだけのようです.

モデルの準備

    model = MyModel(num_classes=num_classes)
    if isinstance(args.gpu, list):
        model = torch.nn.DataParallel(model, device_ids=args.gpu)
    model.to(device)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr, betas=args.betas)

    train_loss = AverageMeter()
    train_top1 = AverageMeter()

このあたりは一般的なモデルの準備です.

学習ループ

では学習ループの設定です.進捗表示のtqdmと,共有lockオブジェクトを生成します.

    with tqdm(range(args.n_epochs)) as pbar_epoch, \
            SyncManager() as manager:

        lock = manager.Lock()

        for epoch in pbar_epoch:
            pbar_epoch.set_description("[Train] epoch: %d" % epoch)

            train_loss.reset()
            train_top1.reset()

            with tqdm(enumerate(sample_loader),
                      total=sample_loader.length,
                      leave=True,
                      smoothing=0,
                      ) as pbar_batch:

各ワーカーで情報をprint()で表示するときには,このlockを使って同時に表示することのないようにします.

                for i, batch in pbar_batch:

                    im, batch_dic, urls = batch
                    im = im.to(device)
                    label = batch_dic['label'].to(device)

                    batch_dic['url'] = urls
                    gpu_id = im.get_device()

                    optimizer.zero_grad()

ここまでは普通のバッチ取得です.

                    print('==========================')
                    print(f'loop {i} on GPU {gpu_id}:')
                    print('worker id', batch_dic['read worker id'])
                    print('shard    ', batch_dic['url'])
                    print('count    ', batch_dic['count'])
                    print('label    ', batch_dic['label'])

                    output, gpu_id = model(im, batch_dic, lock)
                    print('proc GPU ', gpu_id)
  • サンプルの情報を表示します.ここはメインプロセスしか動かないので,print時もlockする必要はありません.
  • modelのforwardを実行します.
  • バッチ内の各サンプルがどのGPUで処理されたのかを表示します.

設定を変えてこの表示がどう変わるのかを,後で比較します.

モデルの中身

では一旦main()を離れて,モデルのforwardを説明します.

class MyModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, im, batch_dic, lock):
        bs = im.shape[0]
        gpu_id = im.get_device()
        gpu_id = torch.tensor([gpu_id] * bs, device=im.device)

        with lock:
            print('|-GPU------', gpu_id, '-----------------')
            print('| worker id', batch_dic['read worker id'])
            print('| shard    ', batch_dic['url'])
            print('| count    ', batch_dic['count'])
            # print('| label    ', batch_dic['label'])

        return self.model(im), gpu_id
  • このモデルは事前学習済みResNet18をfine tuneするだけのものです.
  • forward()の内部でlockをかけて,必要な情報をprint()で表示します.
  • gpu_idは,渡されたデータimのデバイス番号です.これでこのサンプルがどのGPUで処理されたのかが把握できます.
    • gpu_idをリストにして更にtensorにしてるのは,単に他の情報と同様にtensorバッチとして表示したいという,理由だけです.

損失計算

ではmain()に戻って,最後の損失計算と表示です.

                    loss = criterion(output, label)
                    loss.backward()
                    optimizer.step()

                    bs = im.size(0)
                    train_loss.update(loss, bs)
                    train_top1.update(accuracy(output, label), bs)

                    pbar_batch.set_postfix_str(
                        ' loss={:6.04f}/{:6.04f}'
                        ' top1={:6.04f}/{:6.04f}'
                        ''.format(
                            train_loss.value, train_loss.avg,
                            train_top1.value, train_top1.avg,
                        ))

では実行してみる

single GPUの場合

まずは1 GPUで実行して,どのshardからどの順番にサンプルが取得されて学習に使用されているのかを把握します.

  • GPU番号:0
  • ワーカー数:2
  • バッチサイズ: 7
0ループ目
$ python use_jpeg_shards_singe_or_DP.py -s ./shards_cats_dogs/ -g 0 -b 7 -w 2 
[Train] epoch: 0:   0%|           | 0/10 [00:00<?, ?it/s
==========================        | 0/3572 [00:00<?, ?it/s]
loop 0 on GPU 0:
worker id tensor([0, 0, 0, 0, 0, 0, 0])
shard     tensor([18, 18, 18, 18, 18, 18, 18])
count     tensor([0, 1, 2, 3, 4, 5, 6])
label     tensor([1, 0, 0, 1, 1, 0, 0])
|-GPU------ tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0') -----------------
| worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
| shard     tensor([18, 18, 18, 18, 18, 18, 18], device='cuda:0')
| count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:0')
proc GPU  tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')

まず0ループ目.メインループの情報を見ると,

  • worker idは0なので,0番目のワーカーがshardを読み込んだことがわかります.
  • shardは18.つまりtest-0018.tarという名前のshardから読み込まれています.
  • countは0から6まで連続しているので,つまりtest-0018.tarの先頭のサンプルから順番に読み込まれていることがわかります.

次にモデル内部での表示を見ると

  • GPUがすべて0.これは指定したとおりです.
  • worker id,shard,countはメインループの情報がそのまま渡されているだけです.

モデルから返ってきた情報(proc GPU)をみると,

  • GPUがすべて0.これは指定したとおりです.
1ループ目
==========================                  | 1/3572 [04:49<287:24:43, 289.75s/it,  loss=0.6981/0.6981 top1=57.1429/57.1429]
loop 1 on GPU 0:
worker id tensor([1, 1, 1, 1, 1, 1, 1])
shard     tensor([19, 19, 19, 19, 19, 19, 19])
count     tensor([0, 1, 2, 3, 4, 5, 6])
label     tensor([1, 0, 0, 1, 1, 0, 0])
|-GPU------ tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0') -----------------
| worker id tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:0')
| shard     tensor([19, 19, 19, 19, 19, 19, 19], device='cuda:0')
| count     tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:0')
proc GPU  tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')

次は1ループ目.

  • worker idが1になっています.ワーカー数が2なので,worker0とworker1が交互にバッチを取得していることがわかります.
  • shardは19,countは0から6なので,test-00019.tarというshardの先頭から連続してサンプルを取得しています. 
2ループ目
=========================                  | 2/3572 [06:44<200:31:25, 202.21s/it,  loss=0.6932/0.6956 top1=42.8571/50.0000]
loop 2 on GPU 0:
worker id tensor([0, 0, 0, 0, 0, 0, 0])
shard     tensor([18, 18, 18, 18, 18, 18, 18])
count     tensor([ 7,  8,  9, 10, 11, 12, 13])
label     tensor([1, 1, 0, 0, 1, 1, 0])
|-GPU------ tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0') -----------------
| worker id tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')
| shard     tensor([18, 18, 18, 18, 18, 18, 18], device='cuda:0')
| count     tensor([ 7,  8,  9, 10, 11, 12, 13], device='cuda:0')
proc GPU  tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')

2ループ目.

  • worker idは再び0に戻って,
  • worker0が読み込んでいた同じshard 18から,
  • 引き続きサンプルを読み込んでいることがわかります.(countが7からになっている)
3ループ目
=========================                 | 3/3572 [06:53<136:29:48, 137.68s/it,  loss=0.5293/0.6402 top1=100.0000/66.6667]
loop 3 on GPU 0:
worker id tensor([1, 1, 1, 1, 1, 1, 1])
shard     tensor([19, 19, 19, 19, 19, 19, 19])
count     tensor([ 7,  8,  9, 10, 11, 12, 13])
label     tensor([1, 1, 0, 0, 1, 1, 0])
|-GPU------ tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0') -----------------
| worker id tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:0')
| shard     tensor([19, 19, 19, 19, 19, 19, 19], device='cuda:0')
| count     tensor([ 7,  8,  9, 10, 11, 12, 13], device='cuda:0')
proc GPU  tensor([0, 0, 0, 0, 0, 0, 0], device='cuda:0')

3ループ目.

  • worker idは再び1に戻って,
  • worker1が読み込んでいた同じshard 19から,
  • 引き続きサンプルを読み込んでいることがわかります.(countが7からになっている)

後はこれの繰り返しです.
つまり1 GPU,複数ワーカーの場合,各ワーカーそれぞれ別のshardからサンプルを読み込んで,交代でバッチを生成していることがわかります.

3 GPUでdata parallelの場合

まずは3 GPUで実行してみます.

  • GPU番号:0, 1, 2
  • ワーカー数:2
  • バッチサイズ: 14
0ループ目
$ python use_jpeg_shards_singe_or_DP.py -s ./shards_cats_dogs/ -g 0 1 2 -b 14 -w 2 
Namespace(batch_size=14, betas=[0.9, 0.999], gpu=[0, 1, 2], lr=0.0001, n_epochs=10, num_workers=2, shard_path='./shards_cats_dogs/', shuffle=-1)
[Train] epoch: 0:   0%|        | 0/10 [00:00<?, ?it/s
==========================     | 0/1786 [00:00<?, ?it/s]
loop 0 on GPU 0:
worker id tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
shard     tensor([18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18])
count     tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])
label     tensor([1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0])
|-GPU------ tensor([1, 1, 1, 1, 1], device='cuda:1') -----------------
| worker id tensor([0, 0, 0, 0, 0], device='cuda:1')
| shard     tensor([18, 18, 18, 18, 18], device='cuda:1')
| count     tensor([5, 6, 7, 8, 9], device='cuda:1')
|-GPU------ tensor([2, 2, 2, 2], device='cuda:2') -----------------
| worker id tensor([0, 0, 0, 0], device='cuda:2')
| shard     tensor([18, 18, 18, 18], device='cuda:2')
| count     tensor([10, 11, 12, 13], device='cuda:2')
|-GPU------ tensor([0, 0, 0, 0, 0], device='cuda:0') -----------------
| worker id tensor([0, 0, 0, 0, 0], device='cuda:0')
| shard     tensor([18, 18, 18, 18, 18], device='cuda:0')
| count     tensor([0, 1, 2, 3, 4], device='cuda:0')
proc GPU  tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2], device='cuda:0')

まず0ループ目.メインループの情報を見ると,

  • worker idは0なので,0番目のワーカーがshardを読み込んだことがわかります.
  • shardは18.つまりtest-0018.tarという名前のshardから読み込まれています.
  • countは0から13まで連続しているので,つまりtest-0018.tarの先頭のサンプルから順番に読み込まれていることがわかります.

ここまではsingle GPUと同じです.

次にモデル内部での表示を見ると,GPU 3つ分が表示されます.(並列処理してますが,lockを使っているので表示はGPU毎にまとまっています)

  • GPU 1
    • GPUはすべて1.当然ですね.
    • worker id,shard,countはメインループの情報がそのまま渡されているだけです.
    • ただしcountが5〜9になっています.
  • GPU 2
    • GPUはすべて2.
    • worker id,shard,countはメインループの情報がそのまま渡されているだけです.
    • ただしcountが10〜13になっています.
  • GPU 0
    • GPUはすべて0.
    • worker id,shard,countはメインループの情報がそのまま渡されているだけです.
    • ただしcountが0〜4になっています.

モデルから返ってきた情報(proc GPU)をみると,

  • バッチサイズ17のうち
    • 最初の5サンプルはGPU 0で処理
    • 次の5サンプルはGPU 1で処理
    • 最後の4サンプルはGPU 2で処理

されたことが分かります.

つまりdata parallelでは,

  • サンプルバッチの生成は1つのワーカーが行い,
  • バッチを分割して,
  • それぞれを各GPU上のモデルに渡し,
  • 返ってきた情報をまた1つのバッチにしている

ということが分かります.

ちなみにワーカー数=コア数(num_workers=os.cpu_count())程度にしておけばよいことになります.

1ループ目
loop 1 on GPU 0:
worker id tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
shard     tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19])
count     tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])
label     tensor([1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0])
|-GPU------ tensor([0, 0, 0, 0, 0], device='cuda:0') -----------------
| worker id tensor([1, 1, 1, 1, 1], device='cuda:0')
| shard     tensor([19, 19, 19, 19, 19], device='cuda:0')
| count     tensor([0, 1, 2, 3, 4], device='cuda:0')
|-GPU------ tensor([1, 1, 1, 1, 1], device='cuda:1') -----------------
| worker id tensor([1, 1, 1, 1, 1], device='cuda:1')
| shard     tensor([19, 19, 19, 19, 19], device='cuda:1')
| count     tensor([5, 6, 7, 8, 9], device='cuda:1')
|-GPU------ tensor([2, 2, 2, 2], device='cuda:2') -----------------
| worker id tensor([1, 1, 1, 1], device='cuda:2')
| shard     tensor([19, 19, 19, 19], device='cuda:2')
| count     tensor([10, 11, 12, 13], device='cuda:2')
proc GPU  tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2], device='cuda:0')

以降のループも同様です.

  • ワーカーが交代でバッチを生成するのはsingle GPUの場合と同じです.
  • バッチを分割して各GPUに送り,結果をまた1つのバッチに統合するのも,0ループ目と同様です.

lockなしクリーンなコード全体

lockを使わない普通の場合の全体のコードのサンプルです.折りたたんでおきます.
全体コードのuse_jpeg_shards_singe_or_DP_clean.py

from pathlib import Path
from torch import optim, nn
from torchvision import transforms, models
from tqdm import tqdm
import argparse
import json
import torch
import webdataset as wds

import warnings
warnings.simplefilter('ignore', UserWarning)


class AverageMeter(object):
    """
    Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L363-L380
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.value = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, value, bs=1):
        if isinstance(value, torch.Tensor):
            value = value.item()
        self.value = value
        self.sum += value * bs
        self.count += bs
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k
    https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L411
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res if len(res) > 1 else res[0]


def info_from_json(shard_path):
    json_file = Path(shard_path).glob('*.json')
    json_file = str(next(json_file))  # get the first json file
    with open(json_file, 'r') as f:
        info_dic = json.load(f)

    return info_dic['dataset size'], info_dic['num_classes']


def get_transform():
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Lambda(lambda x: x / 255.),  # already tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


class MyModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, im):
        return self.model(im)


def make_dataset(
    shards_url,
    batch_size,
    shuffle_buffer_size=-1,
    transform=None,
):

    dataset = wds.WebDataset(shards_url)
    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.decode('torchrgb')  # jpg --> tensor(uint8, CHW)
    dataset = dataset.to_tuple(
        'jpg',
        'json',
    )
    dataset = dataset.map_tuple(
        lambda x: transform(x) if transform is not None else x,
        lambda x: x['label']
    )
    dataset = dataset.batched(
        batch_size,
        partial=False)

    return dataset


def my_collate_fn(batch):
    ret = (
        batch[0],  # 'jpg', already BCHW because of dataset.batched()
        torch.utils.data.default_collate(batch[1]),  # label
    )
    return ret


def main(args):

    assert torch.cuda.is_available(), 'cpu is not supported'
    if isinstance(args.gpu, int):
        device = torch.device('cuda:' + str(args.gpu))
    elif isinstance(args.gpu, list):
        device = torch.device('cuda:' + str(args.gpu[0]))  # the 1st device

    shards_path = [
        str(path) for path in Path(args.shard_path).glob('*.tar')
        if not path.is_dir()
    ]

    transform = get_transform()

    dataset = make_dataset(
        shards_url=shards_path,
        batch_size=args.batch_size,
        shuffle_buffer_size=args.shuffle,
        transform=transform)
    sample_loader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=my_collate_fn)

    dataset_size, num_classes = info_from_json(args.shard_path)
    num_batches = dataset_size // args.batch_size + 1

    sample_loader.length = num_batches
    sample_loader = sample_loader.with_length(num_batches)

    model = MyModel(num_classes=num_classes)
    if isinstance(args.gpu, list):
        model = torch.nn.DataParallel(model, device_ids=args.gpu)
    model.to(device)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr, betas=args.betas)

    train_loss = AverageMeter()
    train_top1 = AverageMeter()

    with tqdm(range(args.n_epochs)) as pbar_epoch:

        for epoch in pbar_epoch:
            pbar_epoch.set_description("[Train] epoch: %d" % epoch)

            train_loss.reset()
            train_top1.reset()

            with tqdm(enumerate(sample_loader),
                      total=sample_loader.length,
                      leave=True,
                      smoothing=0,
                      ) as pbar_batch:

                for i, batch in pbar_batch:

                    im, label = batch
                    im = im.to(device)
                    label = label.to(device)

                    optimizer.zero_grad()

                    output = model(im)

                    loss = criterion(output, label)
                    loss.backward()
                    optimizer.step()

                    bs = im.size(0)
                    train_loss.update(loss, bs)
                    train_top1.update(accuracy(output, label), bs)

                    pbar_batch.set_postfix_str(
                        ' loss={:6.04f}/{:6.04f}'
                        ' top1={:6.04f}/{:6.04f}'
                        ''.format(
                            train_loss.value, train_loss.avg,
                            train_top1.value, train_top1.avg,
                        ))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-s', '--shard_path', action='store',
                        default='./test_shards/',
                        help='Path to the dir to store shard *.tar files.')
    parser.add_argument('--shuffle', type=int, default=-1,
                        help='shuffle buffer size. negative means no shuffle. '
                        'default -1')

    parser.add_argument('-b', '--batch_size', type=int, default=3,
                        help='batch size. default 3')
    parser.add_argument('-w', '--num_workers', type=int, default=2,
                        help='number of dataloader workders. default 2')
    parser.add_argument('-g', '--gpu', nargs='+', type=int, default=0,
                        help='GPU ids to be used. '
                        'int ("0", "1") or list of int ("1 2", "0 1 2"). '
                        'default "0"')

    parser.add_argument('--n_epochs', type=int, default=10,
                        help='number of epochs. default to 10')
    parser.add_argument('-lr', type=float, default=0.0001,
                        help='learning rate. default to 0.0001')
    parser.add_argument('--betas', nargs='+', type=float, default=[0.9, 0.999],
                        help='betas of Adam. default to (0.9, 0.999).'
                        'specify like --betas 0.9 0.999')

    args = parser.parse_args()
    print(args)
    main(args)
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?