20
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchでGPUを並列で使えるようにするtorch.nn.DataParallelのメモ

Posted at

torch.nn.DataParallelで簡単にPyTorchでGPUを並列に使えることを知ったので、簡単にメモしておきます。

前提

自然言語処理でBERTを使って何らかの分類問題をファインチューニングで解くことを想定します。

簡単なネットワークを下のように定義したとします。


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
        # 2値分類タスクを仮定します。
        self.linear = nn.Linear(768, 2)

    def forward(self, input_ids):
        vec = self.bert(input_ids)
        vec = vec['last_hidden_state']
        # 先頭トークンclsのベクトルだけ取得
        vec = vec[:,0,:]
        vec = vec.view(-1, 768)
        vec = self.linear(vec)
        # 分類確率を返すようにしてます。
        return F.softmax(vec, dim=1)

# ネットワークのインスタンスを宣言する
net = Net()

以下のように事前学習済みのパラメータは小さめ、最後の分類する層はやや大きめの学習率でファインチューニングしたいとします。

# 事前学習済の箇所は学習率小さめ、最後の全結合層は大きめにする。
optimizer = optim.Adam([
    {'params': net.bert.parameters(), 'lr': 5e-5},
    {'params': net.linear.parameters(), 'lr': 1e-4}
])

torch.nn.DataParallelの使い方

  • とても簡単でモデルのインスタンスをtorch.nn.DataParallelで囲ってやればOK。
  • device_idsにGPUの番号を配列で指定すれば、指定したGPUを並列でよしなに振り分けてくれる様子
  • 例えばcuda:0cuda:1を並列で使いたいときは以下のようにすればOK。
net = torch.nn.DataParallel(net, device_ids=[0, 1])

これだけでOKと思いきやtorch.nn.DataParallelでモデルを囲うと、モデルはDataParallelオブジェクトになって、moduleでネットワークが囲われるので、上記のファインチューニングのようにnet.bertのような参照の仕方はできなくなります。

DataParallel(
  (module): Net(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(32000, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
〜省略〜
      )
      (pooler): BertPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
      )
    )
    (linear): Linear(in_features=768, out_features=2, bias=True)
  )
)

ファインチューニングは以下のようにmoduleを介せばOK。

optimizer = optim.Adam([
    {'params': net.module.bert.parameters(), 'lr': 5e-5},
    {'params': net.module.linear.parameters(), 'lr': 1e-4}
])

あとはいつもどおり、ネットワークをGPUに送ればOK。
cudaの部分をcuda:0cuda:1と書いてもよしなに並列で処理してくれるようですが、DataParallelで指定した番号以外(cuda:2とか)にするとエラーでます。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

これでGPUを並列に使ってくれます。nvidia-smiで確認すると、指定した0番目と1番目のGPUを使ってくれているようです。

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     32701      C   ...virtualenv/dev/bin/python    39057MiB |
|    1   N/A  N/A     32701      C   ...virtualenv/dev/bin/python    23485MiB |
+-----------------------------------------------------------------------------+

GPUが並列で使えるなら、torch.nn.DataParallelでネットワークは囲っておく癖つけておこう。

おわり

20
20
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
20
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?