この章では、PytorchのDatasetとDataLoaderについて解説していきます。
この章はhttps://gotutiyan.hatenablog.com/entry/2020/04/21/182937 を参考に記述されています。
Pytorchでは、DatasetとDataLoaderを用いることで、簡単にミニバッチ化をすることができます。
Datasetの実装
DataSetを実装する際には、クラスのメンバ関数として__len__()と__getitem__()を必ず作ります。
len()は、len()を使ったときに呼ばれる関数です。
getitem()は、array[i]のように[ ]を使って要素を参照するときに呼ばれる関数です。これが呼ばれる際には、必ず何かしらのindexが指定されているので、引数にindexの情報を取ります。また、入出力のペアを返すように設計します。
以上を踏まえて、Datasetを作成してみましょう。
class DataSet:
def __init__(self):
self.X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # 入力
self.t = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1] # 出力
def __len__(self):
return len(self.X) # データ数(10)を返す
def __getitem__(self, index):
# index番目の入出力ペアを返す
return self.X[index], self.t[index]
さて、実際にこのDataSetがどのような振る舞いをするか試してみましょう。
dataset = DataSet()
print('全データ数:',len(dataset)) # 全データ数: 10
print('3番目のデータ:',dataset[3]) # 3番目のデータ: (3, 1)
print('5~6番目のデータ:',dataset[5:7]) # 5~6番目のデータ: ([5, 6], [1, 0])
DataLoaderの実装
バッチサイズを2、訓練時のデータのシャッフルをFalseとした実装は以下のようになります。
# さっき作ったDataSetクラスのインスタンスを作成
dataset = DataSet()
# datasetをDataLoaderの引数とすることでミニバッチを作成.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False)
これでミニバッチ学習をする準備が整いました。
ミニバッチ用のデータはfor文で取り出すことができます。
for data in dataloader:
print(data)
'''
出力:
[tensor([0, 1]), tensor([0, 1])]
[tensor([2, 3]), tensor([0, 1])]
[tensor([4, 5]), tensor([0, 1])]
[tensor([6, 7]), tensor([0, 1])]
[tensor([8, 9]), tensor([0, 1])]
'''
上記のdataloaderを用いて、10epoch学習をする場合には以下のように書けます。
epoch = 10
model = #何かしらのモデル
for _ in range(epoch):
for data in dataloader:
X = data[0]
t = data[1]
y = model(X)
# lossの計算とか
PytorchのDatasetとDataloaderについての説明は、以上になります。