LoginSignup
0
2

More than 3 years have passed since last update.

深層学習(Pytorch)を用いた、Kaggle Titanic実践 PART 4 (Pytorchのdataloaderとdataset)

Posted at

この章では、PytorchのDatasetとDataLoaderについて解説していきます。
この章はhttps://gotutiyan.hatenablog.com/entry/2020/04/21/182937 を参考に記述されています。
Pytorchでは、DatasetとDataLoaderを用いることで、簡単にミニバッチ化をすることができます。

Datasetの実装

DataSetを実装する際には、クラスのメンバ関数としてlen()とgetitem()を必ず作ります。

len()は、len()を使ったときに呼ばれる関数です。
getitem()は、array[i]のように[ ]を使って要素を参照するときに呼ばれる関数です。これが呼ばれる際には、必ず何かしらのindexが指定されているので、引数にindexの情報を取ります。また、入出力のペアを返すように設計します。

以上を踏まえて、Datasetを作成してみましょう。

class DataSet:
    def __init__(self):
        self.X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # 入力
        self.t = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1] # 出力

    def __len__(self):
        return len(self.X) # データ数(10)を返す

    def __getitem__(self, index):
        # index番目の入出力ペアを返す
        return self.X[index], self.t[index]

さて、実際にこのDataSetがどのような振る舞いをするか試してみましょう。

dataset = DataSet()
print('全データ数:',len(dataset))  # 全データ数: 10
print('3番目のデータ:',dataset[3]) # 3番目のデータ: (3, 1)
print('5~6番目のデータ:',dataset[5:7]) # 5~6番目のデータ: ([5, 6], [1, 0])

DataLoaderの実装

バッチサイズを2、訓練時のデータのシャッフルをFalseとした実装は以下のようになります。

# さっき作ったDataSetクラスのインスタンスを作成
dataset = DataSet()
# datasetをDataLoaderの引数とすることでミニバッチを作成.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False)

これでミニバッチ学習をする準備が整いました。
ミニバッチ用のデータはfor文で取り出すことができます。

for data in dataloader:
    print(data)

'''
出力:
[tensor([0, 1]), tensor([0, 1])]
[tensor([2, 3]), tensor([0, 1])]
[tensor([4, 5]), tensor([0, 1])]
[tensor([6, 7]), tensor([0, 1])]
[tensor([8, 9]), tensor([0, 1])]
'''

上記のdataloaderを用いて、10epoch学習をする場合には以下のように書けます。

epoch = 10
model = #何かしらのモデル
for _ in range(epoch):
    for data in dataloader:
        X = data[0]
        t = data[1]
        y = model(X)
        # lossの計算とか

PytorchのDatasetとDataloaderについての説明は、以上になります。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2