22
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【基礎】PyTorchのDataLoader/Datasetの使い方【MNIST】

Last updated at Posted at 2023-03-27

きっかけ

PyTorchを使ってみて最初によくわからなくなったのが

  • DataLoader
  • Dataset

あたりの使い方だった。
サンプルコードでなんとなく動かすことはできたけど、こいつらはいったい何なのか
調べながらまとめてみる。

ニューラルネットとミニバッチ

ニューラルネットの学習においてはミニバッチ学習がよく用いられる。

ニューラルネット学習においては入力データが膨大になり全部はメモリに載らないことが多い。
そのため指定したバッチサイズごと(ミニバッチ)に区切ってメモリ上にロードし、順番に学習していく手法がよく用いられる。
これをミニバッチ学習という。
機械学習の工程はまず最初にデータを読み込む必要があるわけだが、DataLoader/Datasetは最初の「前処理、データのロード」に関連する話だ。
ミニバッチ用の読み込みの仕組みをゼロから自分で実装するのは大変なので、みんなこの機能をありがたく使っているというわけだ。

MNISTのデータ読み込み

雰囲気を掴むために手書き文字の分類タスクで有名なMNISTのデータを読み込んでみよう。

Datasetの準備

PyTorchでミニバッチ学習する際はやや特殊な型変換が必要となる。
まずはPyTorch向けのライブラリであるtorchvisionからMNISTのデータを取得する。

import torch
import torchvision
import torchvision.transforms as transforms

mnist = torchvision.datasets.MNIST(root='./data',
                                        train=True,
                                        transform=transforms.ToTensor(),
                                        download=True)
print(mnist)
出力
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

引数はそれぞれ

  • root: データを取得するpath
  • train: Trueの場合はtraining data、Falseの場合はtest data
  • transform: 前処理の設定
  • download: Trueにするとrootで指定したpathに保存

となる。
ちなみにこのMNISTの型はtorchvision.datasets.mnist.MNISTとなる。
これでデータセットの準備ができた。

DataLoaderを使う

次にDataLoaderを使って先ほどのデータセットを読み込む。

import torch
dl = torch.utils.data.DataLoader(dataset=mnist, 
                                         batch_size=10000, 
                                         shuffle=True, 
                                         num_workers=2)

引数はそれぞれ

  • datasetで読み込むデータセットを指定
  • batch_sizeでミニバッチ学習におけるバッチサイズを指定
  • shuffle=Trueにするとランダムに抽出してくれる
  • num_workersはいくつに並列化して計算するか

を表している。
この変数dlはiterableな形、つまりfor文とかで取り出せる形になっているので、試しに取り出してみよう。

X_list, y_list = [], []
for i , (X, y) in enumerate(dl):
    print(i)
    X_list.append(X)
    y_list.append(y)
print('len:', len(X_list), len(y_list))
出力
0
1
2
3
4
5
len: 6 6

今回は60000あるデータを10000のバッチに区切って読み込んでいるため、6つの要素がlistに格納されている。
このうち一つを取り出して見てみよう。
なお、各要素はtorch.Tensorになっている。
Tensor型はPyTorchで使う特殊な型で、使い勝手はndarrayと似ているがGPUの計算にも対応している。
for文の中でX.to(device)のように指定すれば形式を変更できる。
デフォルトはCPUだが、例えばX.to('cuda')とすればGPUの計算が可能だ。[1]

print(X.size())
print(y.size())
出力
torch.Size([10000, 1, 28, 28])
torch.Size([10000])

imageにはミニバッチサイズ、チャネル、縦ピクセル数、横ピクセル数で値が格納されている。
チャネルは今回は白黒なので1になっている。
labelには0-9のどの文字が正解なのか、の情報が格納されている。
試しに可視化してみる。

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.imshow(X[0][0])
ax.axis('off')
ax.set_title(f'images, label={y[0]}', fontsize=20)
plt.show()

ダウンロード.png

こんな感じ。
ここでは0の画像を一枚取り出して見ている。

Numpy/Pandasでの読み込み

なんとなくイメージは掴めただろうか。
さて、上ではtorchvisionからMNISTのデータを引っ張ってきたが、もう少し実際の状況に合わせて考えてみよう。
inputとなるデータがnumpyの場合、どのようにdatasetを定義してDataLoaderに渡せばいいだろうか。
次はsklearn経由でMNISTのデータをロードしてみよう。

from sklearn.datasets import load_digits
digits=load_digits()
X = digits.data
y = digits.target
print(X.shape, y.shape)
出力
(1797, 64) (1797,)

ここでは1797個のサンプルが含まれており、それぞれndarray型である。
yは先ほどの例のlabelにあたり、Xはimageにあたる部分だ。
先ほどの例と比べると解像度はやや下がり、8*8=64ピクセルとなっている。

さて、このままだとDataLoaderで読み込めないのでTensor型に変換する必要がある。

X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.int64) 
mnist = torch.utils.data.TensorDataset(X, y)

あとは先ほどと同じようにDataLoaderを使ってiterableに取り出すことができる。
正解ラベルであるyは整数型、featureとなるXはfloat型になっていることに注意。
Pandasでの読み込みについてはここでは省略するので、[4]などを参考にしてほしい。

カスタム関数の作成

実際にKaggleや実務などでDatasetを用意する場合はカスタム関数を使うことが多い。
カスタム関数はDatasetを継承してクラスを作成することで使うことができる。

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, df, features, labels):
        self._features_values = df[features].values
        self._labels = df[labels].values

    def __len__(self):
        return len(self._features_values)

    def __getitem__(self, idx):
        features_x = torch.FloatTensor(self._features_values[idx])
        labels = torch.LongTensor(self._labels[idx])
        return features_x, labels

ここではtorch.utils.data.Datasetを継承して新しくMyDatasetというクラスを定義した。
inputのdfはpandasのdataframeを前提としているため、__init__の中で内部変数に格納している。
__len__はlen()を呼び出す時に必要になり、__getitem__は要素を参照する時に必要となるため、この二つの
関数は必須で書く必要がある。
今回は__getitem__の中でTensor型に変更して取得する処理を入れている。
もちろんカスタム関数なので、自由にいろんな処理を加えることが可能だ。
クラスから実際にインスタンスを作成する処理については[5]などを参照。

まとめ

  • PyTorchの学習においては
    • Datasetの作成 -> DataLoaderにロードの手順でミニバッチ学習を行う
  • Dataset, DataLoaderには特殊な型を指定する必要がある
  • 実際に使うときはカスタム関数を使うことが多い

参考資料

カジュアルな記事ばっかり載せたけど、本家サイトみるのが一番正確

22
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
22
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?