PyTorch Dataset と DataLoader の使い方
PyTorchを使うと、データセットの処理や学習データのバッチ処理が非常に簡単になります。その中心的な要素として、Dataset
と DataLoader
があります。このチュートリアルでは、これらの基本的な使い方について段階的に説明していきます。
目次
- PyTorchの
Dataset
クラスを使ってデータを準備する -
DataLoader
を使ってデータをバッチ処理する - サンプルコードを使って実際にデータをロードする
-
Tensor
データを扱う場合の例
Step 1: PyTorchのDataset
クラスを使ってデータを準備する
Dataset
クラスは、データを効率的に処理するための基本的な構造を提供します。PyTorchのDataset
クラスを継承して、カスタムデータセットを作成します。
以下のコードでは、簡単なカスタムデータセットを作成します。このデータセットは、入力データとラベルを受け取り、それらを効率的にアクセスできるようにします。
import torch
from torch.utils.data import Dataset
import numpy as np
# カスタムデータセットの作成
class CustomDataset(Dataset):
def __init__(self, data, labels):
# データとラベルを受け取るコンストラクタ
self.data = data
self.labels = labels
def __len__(self):
# データセットのサイズを返す
return len(self.data)
def __getitem__(self, idx):
# 指定したインデックスのデータとラベルを返す
sample = {'data': self.data[idx], 'label': self.labels[idx]}
return sample
解説
-
__init__
: コンストラクタでデータとラベルを受け取り、内部で保持します。このメソッドは、データセットのインスタンスを作成するときに一度だけ呼び出されます。 -
__len__
: データセットのサイズを返すメソッドです。これはデータが何サンプルあるのかを教えてくれます。 -
__getitem__
: 特定のインデックスのデータとラベルを返します。データローダーはこのメソッドを使ってデータを取得します。
Step 2: データセットのインスタンス化
次に、このカスタムデータセットを使って、実際のデータを処理します。ここでは、簡単な例として100個のサンプルデータと、そのラベルを準備します。
# ダミーデータを作成
data = np.array([[i] for i in range(100)], dtype=np.float32)
labels = np.array([i % 2 for i in range(100)], dtype=np.float32)
# データセットのインスタンスを作成
custom_dataset = CustomDataset(data, labels)
ここでは、data
は0から99までの値を含む配列で、labels
は0か1のラベル(偶数・奇数を判別するため)です。
Step 3: DataLoaderを使ってデータをバッチ処理する
DataLoader
は、データセットを小さなバッチに分割して扱うためのクラスです。データをシャッフルしたり、並列処理を使ってデータを効率的に読み込むことができます。
from torch.utils.data import DataLoader
# DataLoaderの作成
data_loader = DataLoader(dataset=custom_dataset, batch_size=10, shuffle=True)
解説
-
dataset
: 先ほど作成したカスタムデータセットを指定します。 -
batch_size
: 1回にロードするデータの数を指定します。ここでは10個ずつデータをロードします。 -
shuffle
: データをシャッフルするかどうかを指定します。True
にすると、エポックごとにデータの順序がランダムになります。
Step 4: DataLoaderでデータをイテレートする
DataLoaderを使うと、バッチ単位でデータを扱うことができます。以下のコードは、DataLoaderを使ってデータを繰り返し処理する例です。
# DataLoaderを使ってデータをイテレート
for batch in data_loader:
print(f"Data: {batch['data']}, Label: {batch['label']}")
このコードでは、DataLoaderを使ってデータをバッチ単位で取得し、データとラベルを出力します。バッチ処理により、モデルのトレーニングや推論に使いやすい形式でデータを提供します。
Step 5: DatasetとDataLoaderの解説
-
Dataset
: データとラベルを保持し、特定のインデックスのサンプルを取得するためのメソッドを提供します。データセット全体のサイズを返すメソッドも含まれます。 -
DataLoader
: データセットを小さなバッチに分割し、効率的にデータを読み込むクラスです。データをシャッフルしたり、並列処理でデータの読み込みを行えます。
Step 6: Tensorデータを扱う場合
次に、torch.Tensor
を使ったデータセットの例を紹介します。numpy
ではなく、直接PyTorchのテンソルを使用してデータを管理する場合に役立ちます。
class TensorDataset(Dataset):
def __init__(self, tensor_data, tensor_labels):
self.tensor_data = tensor_data
self.tensor_labels = tensor_labels
def __len__(self):
return self.tensor_data.size(0)
def __getitem__(self, idx):
return self.tensor_data[idx], self.tensor_labels[idx]
# Tensorを作成
tensor_data = torch.arange(100, dtype=torch.float32).view(-1, 1)
tensor_labels = torch.arange(100, dtype=torch.float32) % 2
# テンソルデータセットのインスタンスを作成
tensor_dataset = TensorDataset(tensor_data, tensor_labels)
tensor_loader = DataLoader(dataset=tensor_dataset, batch_size=10, shuffle=True)
# Tensorデータセットを使ってイテレート
for data, label in tensor_loader:
print(f"Tensor Data: {data}, Tensor Label: {label}")
ここでは、torch.arange
を使ってテンソル形式のデータを生成し、同様にDataLoader
を使ってバッチ処理しています。
まとめ
このチュートリアルでは、PyTorchのDataset
とDataLoader
の基本的な使い方を紹介しました。カスタムデータセットを作成し、バッチ処理やシャッフルなどをDataLoader
を通じて簡単に実現できることを確認しました。また、テンソルデータを使った例も紹介し、PyTorchの柔軟性を活かして様々なデータセット形式を扱う方法も解説しました。