0
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

1台サーバ複数GPU環境でPytorchのDDPに挑戦する

Last updated at Posted at 2023-10-07

はじめに

この記事は1台サーバ複数GPU環境でPytorchのDDPに挑戦する記事です.
複数サーバ複数GPU環境ではないので,ご注意ください.

コード全体

import torch
import torch.distributed
import torch.multiprocessing
import torch.nn.parallel
import os
import datetime

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

def train(rank, num_of_gpu, train_dataset, val_dataset):
    EPOCHS = 30
    BATCH = 6

    assert BATCH % num_of_gpu == 0
    BATCH = BATCH // num_of_gpu

    torch.distributed.init_process_group("nccl", rank = rank, world_size = num_of_gpu)

    net = MyModel()
    net = net.to(rank)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids = [rank], output_device = rank, find_unused_parameters=False)
    
    optimizer = torch.optim.Adam(net.parameters(), lr = 0.0001, betas = (0.9, 0.95))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = EPOCHS, eta_min = 0.000001)
    
    criterion = torch.nn.L1Loss()

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas = num_of_gpu, rank = rank, shuffle = True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas = num_of_gpu, rank = rank, shuffle = False)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = BATCH, shuffle = False, pin_memory=True, num_workers = 2, sampler = train_sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 1, shuffle = False, pin_memory=True, num_workers = 2, sampler = val_sampler)

    for epoch in range(1, EPOCHS + 1):
        torch.distributed.barrier()

        if int(rank) == 0:
            print(f"epoch: {epoch}, datetime: {datetime.datetime.now()}")

        x.append(epoch)
        net.train()
        train_running_loss = 0.0
        train_running_size = 0

        train_sampler.set_epoch(epoch)

        for item in train_loader:
            rgb_images, orange_masks, label = item
            rgb_images = rgb_images.to(rank, non_blocking=True)
            orange_masks = orange_masks.to(rank, non_blocking=True)
            label = label.to(rank, non_blocking=True)

            optimizer.zero_grad()
            outputs = net(rgb_images, orange_masks)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            torch.distributed.all_reduce(loss)
            train_running_loss += torch.mean(loss).item()

            train_running_size += num_of_gpu

        torch.distributed.barrier()

        net.eval()
        with torch.no_grad():
            val_running_loss = 0.0
            for item in val_loader:
                rgb_images, orange_masks, label = item
                rgb_images = rgb_images.to(rank, non_blocking=True)
                orange_masks = orange_masks.to(rank, non_blocking=True)
                label = label.to(rank, non_blocking=True)

                outputs = net(rgb_images, orange_masks)
                loss = criterion(outputs, label)

                torch.distributed.all_reduce(loss)
                val_running_loss += torch.mean(loss).item()   
            
            if int(rank) == 0:
                print(f"train_running_loss : {train_running_loss / len(train_dataset)}, val_running_loss : {val_running_loss / len(val_dataset)}")
        
        torch.distributed.barrier()
        scheduler.step()
    
    torch.distributed.barrier()
    if int(rank) == 0:
        net = net.to('cpu')
        torch.save(net.module.state_dict(), f"latest_model.pth")

    torch.distributed.destroy_process_group()

if __name__=="__main__":
    train_dataset, val_dataset = getMyDataset()
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '50000'
    os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
    torch.multiprocessing.spawn(train, args=(torch.cuda.device_count(), train_dataset, val_dataset), nprocs = torch.cuda.device_count(), join = True)

解説

if name=="main"部分について

if __name__=="__main__":
    # 事前にDatasetを作成しておく
    # 無用なトラブルを避けるため
    # train関数内に書いても問題ない
    train_dataset, val_dataset = getMyDataset()

    # DDPを行うための儀式的なコード
    # os.environ['TORCH_DISTRIBUTED_DEBUG']はコードに不具合が出てくるとログを出してくれる
    # os.environ['CUDA_VISIBLE_DEVICES']でサーバ上にあるどのGPUを使うか決めることができる
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0,3'とすることも可能.このとき,torch.cuda.device_count()は2を返すようになる
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '50000'
    os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

    # torch.multiprocessing.spawnでtrain関数を並列実行する
    # argsでtrain関数に色々な引数を渡すことができる
    torch.multiprocessing.spawn(train, args=(torch.cuda.device_count(), train_dataset, val_dataset), nprocs = torch.cuda.device_count(), join = True)

学習前のセッティングについて

# mainから呼び出したときにはrankはなかった
# rankにはtrain関数を並列実行するときに,pytorch側で勝手に付け加える値が格納される
# print(rank)などをしてみると良い

def train(rank, num_of_gpu, train_dataset, val_dataset):
    # この辺りは自分でconfigを書いておくと良いだろう
    EPOCHS = 30
    BATCH = 6

    # BATCHの意味が全てのGPUでのBATCHという意味なら以下のコードを実行する
    # BATCHの意味が1GPUでのBATCHという意味な以下のコードを消す
    assert BATCH % num_of_gpu == 0
    BATCH = BATCH // num_of_gpu

    # DDPを行うための儀式的コード
    torch.distributed.init_process_group("nccl", rank = rank, world_size = num_of_gpu)

    # モデルをGPUに移すコード
    # 通常はnet.to("cuda:0)とするが,DDPではnet.to(rank)とする
    # find_unused_parametersは学習に使用していない学習パラメータを見つけたときにエラーとするか決めることができる
    # 自作モデルの場合,find_unused_parameters=Trueにしておくと,どこかにバグがあったときに発見することができる.
    # 他者の書いたコードを改変するときは「find_unused_parameters=Trueにしておき,自分の書いた部分がエラーとして出力されていなければ,find_unused_parameters=Falseにする」ということもできる
    net = MyModel()
    net = net.to(rank)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids = [rank], output_device = rank, find_unused_parameters=False)

    # optimizerなどはいつもと同じコードで良い
    optimizer = torch.optim.Adam(net.parameters(), lr = 0.0001, betas = (0.9, 0.95))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = EPOCHS, eta_min = 0.000001)
    criterion = torch.nn.L1Loss()

    # DDPを使う場合はsamplerを変える必要がある
    # GPUを並列で使うため,データセットから同じサンプルを取得しないようにするためである
    # samplerでshuffleをするかしないか決めることに注意
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas = num_of_gpu, rank = rank, shuffle = True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas = num_of_gpu, rank = rank, shuffle = False)

    # samplerの設定とshuffleをFalseにすることを忘れない
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = BATCH, shuffle = False, pin_memory=True, num_workers = 2, sampler = train_sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 1, shuffle = False, pin_memory=True, num_workers = 2, sampler = val_sampler)

学習時のコードについて

    for epoch in range(1, EPOCHS + 1):
        # GPU間で一斉に学習を始められるように,GPUの待機をしておく
        # DDPのチュートリアルなどには書かれていないので,以下のコードは消しても良いと思われるが,念のため書いてある
        torch.distributed.barrier()

        # rankが0のときに,epoch数の時間を表示しておく
        # ifを消すと,並列して呼び出されたすべてのコードで実行されるようになる
        # 一度実験して試すと良い
        if int(rank) == 0:
            print(f"epoch: {epoch}, datetime: {datetime.datetime.now()}")

        net.train()
        train_running_loss = 0.0
        train_running_size = 0

        # 学習を始める前にsamplerのset_epochを呼び出しておく
        # これをしないとshuffleが正しく行われなくなる
        train_sampler.set_epoch(epoch)

        for item in train_loader:
            # to(rank)にすること
            rgb_image, label = item
            rgb_image = rgb_image.to(rank, non_blocking=True)
            label = label.to(rank, non_blocking=True)

            # ここはいつもと一緒
            # GPUで並列処理をしていると,backwardは正しく処理されるのか気になると思う
            # backward内で,勝手に勾配値をGPU間で共有するため,torch.distributed.barrier()をする必要がない.
            # 勝手に勾配値をGPU間で共有するため,同じ重みになる
            # ただし,DDPとDPで勾配値の共有の仕方が違うかもしれないので,そこはきちんと自分で調べるようにすること
            optimizer.zero_grad()
            outputs = net(rgb_image)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            # GPU間でloss値を共有する
            # torch.mean(loss).item()ではなく,loss.item()で良い
            torch.distributed.all_reduce(loss)
            train_running_loss += torch.mean(loss).item()

            train_running_size += num_of_gpu

        # evaluationではbackwardはすでに終わっているため,以下のコードは消して良い
        torch.distributed.barrier()

        # trainとほぼ一緒
        net.eval()
        with torch.no_grad():
            val_running_loss = 0.0
            for item in val_loader:
                rgb_imagelabel = item
                rgb_image = rgb_image.to(rank, non_blocking=True)
                label = label.to(rank, non_blocking=True)

                outputs = net(rgb_image)
                loss = criterion(outputs, label)

                torch.distributed.all_reduce(loss)
                val_running_loss += torch.mean(loss).item()   
            
            if int(rank) == 0:
                print(f"train_running_loss : {train_running_loss / len(train_dataset)}, val_running_loss : {val_running_loss / len(val_dataset)}")

        # schedulerのstepの呼び出しを忘れないように
        torch.distributed.barrier()
        scheduler.step()

    # rankが0のモデルのみを保存する
    # GPU間で重みの共有はされているため,これで問題ない
    # GPUメモリのまま保存すると,汎用性が低いので,to('cpu')しておくと便利
    # モデルをセーブするときには,net.state_dict()でも保存することは可能
    # ただし,net.module.state_dict()で保存しておくと汎用性が高い
    torch.distributed.barrier()
    if int(rank) == 0:
        net = net.to('cpu')
        torch.save(net.module.state_dict(), f"latest_model.pth")

    # DDPの儀式的コード
    torch.distributed.destroy_process_group()
0
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?