torch.nn.DataParallel
で簡単にPyTorchでGPUを並列に使えることを知ったので、簡単にメモしておきます。
前提
自然言語処理でBERTを使って何らかの分類問題をファインチューニングで解くことを想定します。
簡単なネットワークを下のように定義したとします。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
# 2値分類タスクを仮定します。
self.linear = nn.Linear(768, 2)
def forward(self, input_ids):
vec = self.bert(input_ids)
vec = vec['last_hidden_state']
# 先頭トークンclsのベクトルだけ取得
vec = vec[:,0,:]
vec = vec.view(-1, 768)
vec = self.linear(vec)
# 分類確率を返すようにしてます。
return F.softmax(vec, dim=1)
# ネットワークのインスタンスを宣言する
net = Net()
以下のように事前学習済みのパラメータは小さめ、最後の分類する層はやや大きめの学習率でファインチューニングしたいとします。
# 事前学習済の箇所は学習率小さめ、最後の全結合層は大きめにする。
optimizer = optim.Adam([
{'params': net.bert.parameters(), 'lr': 5e-5},
{'params': net.linear.parameters(), 'lr': 1e-4}
])
torch.nn.DataParallelの使い方
- とても簡単でモデルのインスタンスを
torch.nn.DataParallel
で囲ってやればOK。 -
device_ids
にGPUの番号を配列で指定すれば、指定したGPUを並列でよしなに振り分けてくれる様子 - 例えば
cuda:0
とcuda:1
を並列で使いたいときは以下のようにすればOK。
net = torch.nn.DataParallel(net, device_ids=[0, 1])
これだけでOKと思いきやtorch.nn.DataParallel
でモデルを囲うと、モデルはDataParallel
オブジェクトになって、module
でネットワークが囲われるので、上記のファインチューニングのようにnet.bert
のような参照の仕方はできなくなります。
DataParallel(
(module): Net(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(32000, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
〜省略〜
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(linear): Linear(in_features=768, out_features=2, bias=True)
)
)
ファインチューニングは以下のようにmodule
を介せばOK。
optimizer = optim.Adam([
{'params': net.module.bert.parameters(), 'lr': 5e-5},
{'params': net.module.linear.parameters(), 'lr': 1e-4}
])
あとはいつもどおり、ネットワークをGPUに送ればOK。
cuda
の部分をcuda:0
やcuda:1
と書いてもよしなに並列で処理してくれるようですが、DataParallel
で指定した番号以外(cuda:2
とか)にするとエラーでます。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
これでGPUを並列に使ってくれます。nvidia-smi
で確認すると、指定した0番目と1番目のGPUを使ってくれているようです。
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 32701 C ...virtualenv/dev/bin/python 39057MiB |
| 1 N/A N/A 32701 C ...virtualenv/dev/bin/python 23485MiB |
+-----------------------------------------------------------------------------+
GPUが並列で使えるなら、torch.nn.DataParallel
でネットワークは囲っておく癖つけておこう。
おわり