0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

実務でよく使うPyTorchの機能まとめ①|DataLoader

Last updated at Posted at 2025-01-20

前提

使用する環境

以下の環境を前提とします。

CUDA Version: 12.6
Python 3.12.7
torch==2.5.1+cu124

CUDAやPyTorchのインストール手順の例については下記で取りまとめましたので、インストールがまだの方は下記などをご確認ください。

DataLoaderの基本コード

DataLoaderPyTorchでバッチ処理を行う際によく用いる機能です。下記のようなサンプルコードに基づいてバッチの作成をおこないます。

まず、サンプルデータにMNISTを用いるにあたって下記を実行します。

import torch
import torchvision
import torchvision.transforms as transforms

mnist = torchvision.datasets.MNIST(root='./data',
                                        train=True,
                                        transform=transforms.ToTensor(),
                                        download=True)
print(mnist)

・実行結果

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

次に作成したMNISTのオブジェクトのmnistを元に下記のようにバッチの作成を行うことができます。

from torch.utils.data import DataLoader

data_loader = DataLoader(dataset=mnist, batch_size=10, shuffle=True)
data_iter = iter(data_loader)
x, y = next(data_iter)

print(y)

・実行結果

tensor([0, 8, 5, 9, 9, 1, 9, 4, 0, 8])

実行結果よりbatch_sizeに対応する数のラベルが出力されていることが確認できます。

DataLoaderで用いるDataSetの構築

前項では予め用意されたMNISTDataLoaderの引数に与えましたが、下記のようにDataSetのクラスを構築し、DataLoaderの引数に与えることも可能です。

import torch

class DataSet:
    def __init__(self):
        self.X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        self.t = [0, 1, 0, 1, 0, 1, 1, 0, 0, 1]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return self.X[index], self.t[index]
        
        
dataset = DataSet()

torch.manual_seed(10) 

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

for data in dataloader:
    print(data)

・実行結果

[tensor([6, 2]), tensor([0, 0])]
[tensor([8, 1]), tensor([0, 1])]
[tensor([5, 3]), tensor([1, 1])]
[tensor([7, 4]), tensor([1, 0])]
[tensor([0, 9]), tensor([0, 1])]

上記より、DataLoaderからサンプリングを行う際の挙動は__getitem__メソッドで定義できることが確認できます。また、結果の再現性を持たせるにあたっては、torch.manual_seed(10)のように乱数を固定すると良いです。乱数を変えた場合について下記で確認します。

class DataSet:
    def __init__(self):
        self.X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        self.t = [0, 1, 0, 1, 0, 1, 1, 0, 0, 1]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return self.X[index], self.t[index]
        
        
dataset = DataSet()

torch.manual_seed(20) 

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

for data in dataloader:
    print(data)

・実行結果

[tensor([2, 0]), tensor([0, 0])]
[tensor([8, 5]), tensor([0, 1])]
[tensor([4, 1]), tensor([0, 1])]
[tensor([3, 7]), tensor([1, 1])]
[tensor([9, 6]), tensor([1, 0])]

上記の実行結果より、torch.manual_seed(20)の出力結果はtorch.manual_seed(10)と異なることが確認できます。

バッチ作成におけるサンプルの順番

shuffleオプションの挙動

まずshuffle=FalseでDataLoaderを動かす場合について確認します。

mnist_loader = DataLoader(dataset=mnist, batch_size=10000, shuffle=False)

for i in range(2):
    mnist_iter = iter(mnist_loader)
    for j in range(6):        
        x, y = next(mnist_iter)

        if j==0:
            print(y[:10])

・実行結果

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4])
tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4])

shuffle=Falseでバッチの作成を行う場合、上記のような結果が得られます。結果より、同じ順番でバッチの作成が行われることが確認できます。次に、shuffleオプションを指定しない場合について確認します。

mnist_loader = DataLoader(dataset=mnist, batch_size=10000)

for i in range(2):
    mnist_iter = iter(mnist_loader)
    for j in range(6):        
        x, y = next(mnist_iter)

        if j==0:
            print(y[:10])

・実行結果

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4])
tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4])

上記より、shuffleオプションのデフォルトがFalseであることが確認できます。同様にshuffle=Trueの場合は下記のような結果が得られます。

mnist_loader = DataLoader(dataset=mnist, batch_size=10000, shuffle=True)

for i in range(10):
    mnist_iter = iter(mnist_loader)
    for j in range(6):
        x, y = next(mnist_iter)

        if j==0:
            print(y[:10])

・実行結果

tensor([3, 4, 8, 7, 5, 0, 0, 8, 9, 6])
tensor([3, 7, 2, 4, 4, 9, 1, 1, 8, 7])
tensor([4, 3, 8, 2, 2, 1, 2, 3, 0, 1])
tensor([2, 1, 6, 0, 4, 3, 2, 3, 3, 0])
tensor([2, 9, 7, 5, 9, 6, 6, 4, 4, 0])
tensor([8, 3, 6, 9, 9, 3, 3, 2, 2, 0])
tensor([4, 1, 1, 2, 5, 3, 5, 1, 4, 1])
tensor([9, 9, 6, 9, 5, 4, 5, 3, 6, 8])
tensor([9, 1, 3, 9, 8, 9, 8, 9, 8, 0])
tensor([4, 1, 5, 1, 3, 1, 6, 1, 7, 3])

shuffle=Trueを指定した場合は上記のように出力結果が得られます。バッチの作成にあたって、サンプルの抽出順がシャッフルされていることが結果より確認できます。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?