LoginSignup
7
3

More than 5 years have passed since last update.

PyTorch の DataLoader でミニバッチを生成すると CUDA テンソルではなくなる

Posted at

現象

DataLoader で CUDA 指定した1次元テンソルからミニバッチを生成すると、生成されたミニバッチから CUDA 指定が外れます。

検証コード
import torch

print(torch.__version__)

x = torch.FloatTensor([[1, 2], [3, 4], [5, 6], [7, 8]]).cuda()
y = torch.LongTensor([0, 0, 1, 1]).cuda()

dataset = torch.utils.data.TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size = 2, shuffle = True)

for batch_x, batch_y in dataloader:
    print("x: {0}, y: {1}".format(type(batch_x), type(batch_y)))
実行結果
0.3.0b0+591e73e
x: <class 'torch.cuda.FloatTensor'>, y: <class 'torch.LongTensor'>  # 本当は y も torch.cuda.FloatTensor になってほしい
x: <class 'torch.cuda.FloatTensor'>, y: <class 'torch.LongTensor'>

原因

DataLoader のソースコードを見ると、ミニバッチのひとつひとつの要素を取り出してリストに格納し、それらを再結合しミニバッチとして返す、という実装になっています。

ミニバッチとして返すデータを選び

indices = next(self.sample_iter)  # may raise StopIteration

ミニバッチのひとつひとつの要素を取り出してリストに格納し

batch = self.collate_fn([self.dataset[i] for i in indices])

それらを再結合しミニバッチとして返す

if torch.is_tensor(batch[0]):
    return torch.stack(batch, 0, out=out)
elif isinstance(batch[0], int_classes):
    return torch.LongTensor(batch)

上記実装の「ミニバッチのひとつひとつの要素を取り出して」の部分が曲者です。

テンソルは要素を一つ取り出すと次元が一つ下がります。2次元テンソルである x は次元が一つ下がっても1次元テンソルになるだけでテンソルのままですが、1次元テンソルである y は次元が一つ下がるとだたの int 型の数値になってしまい、テンソルではなくなります。

しかし、DataLoader が最終的に返す値はテンソルでなければならないので「それらを再結合しミニバッチとして返す」の中で int 型の値を集めてテンソルを作り直しているのですが、その際に torch.LongTensor が使われています。そのため、結果として DataLoader で1次元テンソルからミニバッチを生成すると、生成されたミニバッチから CUDA 指定が外れてしまいます。

対策

CUDA 指定を外したミニバッチを返す、という実装になっている以上、もう一度 cuda() を呼ぶしかありません。

検証コード
for batch_x, batch_y in dataloader:
    print("x: {0}, y: {1}".format(type(batch_x), type(batch_y.cuda())))
実行結果
x: <class 'torch.cuda.FloatTensor'>, y: <class 'torch.cuda.LongTensor'>
x: <class 'torch.cuda.FloatTensor'>, y: <class 'torch.cuda.LongTensor'>

もしくは、DataLoader「それらを再結合しミニバッチとして返す」を実行する関数を collate_fn で指定できるので、CUDA 指定を外さずにミニバッチを返す関数を自分で定義して指定することもできます。

微妙に気が利かない実装だな……

7
3
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
3