pytorchでGPU2つ以上使う方法
日本語で書いてる記事が少なかったから投稿
multiGPU.py
import torch.nn as nn
#modelの定義した後に
model = nn.DataParallel(model).cuda()
modelにAlexnetなりResnetなりを読み込ませた後nn.DataParallelでラップしてあげるだけでマルチGPU使用可能
More than 5 years have passed since last update.
pytorchでGPU2つ以上使う方法
日本語で書いてる記事が少なかったから投稿
import torch.nn as nn
#modelの定義した後に
model = nn.DataParallel(model).cuda()
modelにAlexnetなりResnetなりを読み込ませた後nn.DataParallelでラップしてあげるだけでマルチGPU使用可能
Register as a new user and use Qiita more conveniently