9
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【PyTorch】DataLoaderのshuffleとは

Last updated at Posted at 2022-03-28

なんとなく利用していたDataLoader。さらになんとなく利用していた引数shuffle。本記事では引数shuffleにより、サンプル抽出がどのように変わるのかをコードともに残しておく。
下記の質問に回答できればスルーでOK。

dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = True)
Trainer.fit(model, dataloader)

Shuffle Trueの場合

  • dataloader定義時のみサンプルはシャッフルされる?
  • Trainer.fit実行すると、Epoch毎にサンプルはシャッフルされる?

結論

DataLoaderのshuffleは、データセットからサンプルを抽出する際の挙動を決める引数である。DataLoader定義時ではなく、DataLoaderが呼び出されるたびにサンプルはシャッフルされる。Trainer.fit実行すると、Epoch毎にDataLoaderが呼び出され、サンプルはシャッフルされる。

Shuffle Falseの場合

  • データセットの上から順番に、サンプルを抽出

Shuffle Trueの場合

  • データセットからランダムに、サンプルを抽出

詳細

ShuffleをTrueにすることで、すべてのバッチのサンプル抽出はランダムに行われる。
Trainer.fitで学習を進める際には、下記のようにDataLoaderを実装する。

dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = True)
Trainer.fit(model, dataloader)

私が勘違いしていたこと

(誤)

dataloader作成時に、サンプルはランダム抽出される。
Trainer.fitにはサンプル抽出後のdataloaderが代入されている。
そのため、Shuffle Trueでも、1epoch目と2epoch目のサンプルの組み合わせは同じ。

(正)

dataloaderは、呼び出されるたびにサンプルをランダムに抽出する。
Trainer.fit内部では、epochが変わるたびにdataloaderを呼び出し、サンプルをランダムに抽出している。
そのため、Shuffle Trueであれば、1epoch目と2epochでもサンプルの組み合わせは異なる。

実装

  1. 事前準備

  2. DataLoader検証
    2.1 Shuffle Falseの場合
    2.2 Shuffle Trueの場合

    • dataloader定義時のみサンプルはシャッフルされる? -> 呼び出し時に実行される
    • Trainer.fit実行すると、Epoch毎にサンプルはシャッフルされる? -> される

    2.3 おまけ drop_last

 事前準備

# ライブラリの読込
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

# サンプルデータの作成

#- 入力値: 3変数、11sample
x = torch.tensor([
                              [0, 0, 0],
                              [1, 1, 1],
                              [2, 2, 2], 
                              [3, 3, 3], 
                              [4, 4, 4], 
                              [5, 5, 5],
                              [6, 6, 6],
                              [7, 7, 7], 
                              [8, 8, 8], 
                              [9, 9, 9], 
                              [10, 10, 10]])

#- 目標値: 要素数11, 1次元ベクトル
t = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# datasetの作成
dataset = torch.utils.data.TensorDataset(x, t)

#バッチサイズ定義
batch_size  = 5

DataLoader検証

Shuffle Falseの場合

dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = False, drop_last = True)
for tmp in iter(dataloader):
    print(tmp)

実行結果
0-4, 5-9と上から順番にサンプルが抽出されている。

[tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]]), tensor([0, 1, 2, 3, 4])]
[tensor([[5, 5, 5],
        [6, 6, 6],
        [7, 7, 7],
        [8, 8, 8],
        [9, 9, 9]]), tensor([5, 6, 7, 8, 9])]

Shuffle Trueの場合

dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = True, drop_last = True)
for tmp in iter(dataloader):
    print(tmp)

実行結果
ランダムにサンプルが抽出されている。

[tensor([[3, 3, 3],
        [5, 5, 5],
        [4, 4, 4],
        [2, 2, 2],
        [7, 7, 7]]), tensor([3, 5, 4, 2, 7])]
[tensor([[6, 6, 6],
        [8, 8, 8],
        [0, 0, 0],
        [9, 9, 9],
        [1, 1, 1]]), tensor([6, 8, 0, 9, 1])]

学習の際の挙動、Epoch毎にミニバッチの内訳は異なる?

->異なる

for i in range(5):
#表示用のprint文
    print('##############')
    print('###  epoch{}  #'.format(i))
    print('##############\n')

# Trainer.fit内部の挙動
    dataloader_tmp = dataloader
    for tmp in iter(dataloader):
        print(tmp)

実行結果
Epoch毎にサンプルがランダムに抽出されている。
dataloaderがTrainer.fit内部で何度もよびだされているため。

##############
###  epoch0  #
##############

[tensor([[ 8,  8,  8],
        [ 3,  3,  3],
        [ 2,  2,  2],
        [10, 10, 10],
        [ 6,  6,  6]]), tensor([ 8,  3,  2, 10,  6])]
[tensor([[7, 7, 7],
        [9, 9, 9],
        [0, 0, 0],
        [1, 1, 1],
        [5, 5, 5]]), tensor([7, 9, 0, 1, 5])]


##############
###  epoch1  #
##############

[tensor([[1, 1, 1],
        [4, 4, 4],
        [7, 7, 7],
        [6, 6, 6],
        [0, 0, 0]]), tensor([1, 4, 7, 6, 0])]
[tensor([[10, 10, 10],
        [ 3,  3,  3],
        [ 9,  9,  9],
        [ 8,  8,  8],
        [ 5,  5,  5]]), tensor([10,  3,  9,  8,  5])]



##############
###  epoch2  #
##############

[tensor([[7, 7, 7],
        [4, 4, 4],
        [3, 3, 3],
        [9, 9, 9],
        [2, 2, 2]]), tensor([7, 4, 3, 9, 2])]
[tensor([[ 1,  1,  1],
        [ 8,  8,  8],
        [10, 10, 10],
        [ 5,  5,  5],
        [ 6,  6,  6]]), tensor([ 1,  8, 10,  5,  6])]


##############
###  epoch3  ##
##############

[tensor([[5, 5, 5],
        [2, 2, 2],
        [1, 1, 1],
        [7, 7, 7],
        [9, 9, 9]]), tensor([5, 2, 1, 7, 9])]
[tensor([[ 0,  0,  0],
        [ 4,  4,  4],
        [10, 10, 10],
        [ 8,  8,  8],
        [ 6,  6,  6]]), tensor([ 0,  4, 10,  8,  6])]

##############
###  epoch4  #
##############

[tensor([[ 3,  3,  3],
        [10, 10, 10],
        [ 0,  0,  0],
        [ 2,  2,  2],
        [ 7,  7,  7]]), tensor([ 3, 10,  0,  2,  7])]
[tensor([[8, 8, 8],
        [6, 6, 6],
        [4, 4, 4],
        [1, 1, 1],
        [9, 9, 9]]), tensor([8, 6, 4, 1, 9])]


おまけ(drop_last)

drop_last = False

# drop last False
dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = False, drop_last = False)
for tmp in iter(dataloader):
    print(tmp)

実行結果

[tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]]), tensor([0, 1, 2, 3, 4])]
[tensor([[5, 5, 5],
        [6, 6, 6],
        [7, 7, 7],
        [8, 8, 8],
        [9, 9, 9]]), tensor([5, 6, 7, 8, 9])]
[tensor([[10, 10, 10]]), tensor([10])]

drop_last = True

# drop last True
dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = False, drop_last = True)
for tmp in iter(dataloader):
    print(tmp)

実行結果

[tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]]), tensor([0, 1, 2, 3, 4])]
[tensor([[5, 5, 5],
        [6, 6, 6],
        [7, 7, 7],
        [8, 8, 8],
        [9, 9, 9]]), tensor([5, 6, 7, 8, 9])]

参照

9
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?