きっかけ
PyTorchを使ってみて最初によくわからなくなったのが
DataLoader
Dataset
あたりの使い方だった。
サンプルコードでなんとなく動かすことはできたけど、こいつらはいったい何なのか。
調べながらまとめてみる。
ニューラルネットとミニバッチ
ニューラルネットの学習においてはミニバッチ学習がよく用いられる。
ニューラルネット学習においては入力データが膨大になり全部はメモリに載らないことが多い。
そのため指定したバッチサイズごと(ミニバッチ)に区切ってメモリ上にロードし、順番に学習していく手法がよく用いられる。
これをミニバッチ学習という。
機械学習の工程はまず最初にデータを読み込む必要があるわけだが、DataLoader/Datasetは最初の「前処理、データのロード」に関連する話だ。
ミニバッチ用の読み込みの仕組みをゼロから自分で実装するのは大変なので、みんなこの機能をありがたく使っているというわけだ。
MNISTのデータ読み込み
雰囲気を掴むために手書き文字の分類タスクで有名なMNISTのデータを読み込んでみよう。
Datasetの準備
PyTorchでミニバッチ学習する際はやや特殊な型変換が必要となる。
まずはPyTorch向けのライブラリであるtorchvisionからMNISTのデータを取得する。
import torch
import torchvision
import torchvision.transforms as transforms
mnist = torchvision.datasets.MNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
download=True)
print(mnist)
Dataset MNIST
Number of datapoints: 60000
Root location: ./data
Split: Train
StandardTransform
Transform: ToTensor()
引数はそれぞれ
-
root
: データを取得するpath -
train
: Trueの場合はtraining data、Falseの場合はtest data -
transform
: 前処理の設定 -
download
: Trueにするとrootで指定したpathに保存
となる。
ちなみにこのMNISTの型はtorchvision.datasets.mnist.MNIST
となる。
これでデータセットの準備ができた。
DataLoaderを使う
次にDataLoaderを使って先ほどのデータセットを読み込む。
import torch
dl = torch.utils.data.DataLoader(dataset=mnist,
batch_size=10000,
shuffle=True,
num_workers=2)
引数はそれぞれ
-
dataset
で読み込むデータセットを指定 -
batch_size
でミニバッチ学習におけるバッチサイズを指定 -
shuffle=True
にするとランダムに抽出してくれる -
num_workers
はいくつに並列化して計算するか
を表している。
この変数dl
はiterableな形、つまりfor文とかで取り出せる形になっているので、試しに取り出してみよう。
X_list, y_list = [], []
for i , (X, y) in enumerate(dl):
print(i)
X_list.append(X)
y_list.append(y)
print('len:', len(X_list), len(y_list))
0
1
2
3
4
5
len: 6 6
今回は60000あるデータを10000のバッチに区切って読み込んでいるため、6つの要素がlistに格納されている。
このうち一つを取り出して見てみよう。
なお、各要素はtorch.Tensor
になっている。
Tensor型はPyTorchで使う特殊な型で、使い勝手はndarrayと似ているがGPUの計算にも対応している。
for文の中でX.to(device)
のように指定すれば形式を変更できる。
デフォルトはCPUだが、例えばX.to('cuda')
とすればGPUの計算が可能だ。[1]
print(X.size())
print(y.size())
torch.Size([10000, 1, 28, 28])
torch.Size([10000])
imageにはミニバッチサイズ、チャネル、縦ピクセル数、横ピクセル数で値が格納されている。
チャネルは今回は白黒なので1になっている。
labelには0-9のどの文字が正解なのか、の情報が格納されている。
試しに可視化してみる。
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.imshow(X[0][0])
ax.axis('off')
ax.set_title(f'images, label={y[0]}', fontsize=20)
plt.show()
こんな感じ。
ここでは0の画像を一枚取り出して見ている。
Numpy/Pandasでの読み込み
なんとなくイメージは掴めただろうか。
さて、上ではtorchvisionからMNISTのデータを引っ張ってきたが、もう少し実際の状況に合わせて考えてみよう。
inputとなるデータがnumpyの場合、どのようにdatasetを定義してDataLoaderに渡せばいいだろうか。
次はsklearn経由でMNISTのデータをロードしてみよう。
from sklearn.datasets import load_digits
digits=load_digits()
X = digits.data
y = digits.target
print(X.shape, y.shape)
(1797, 64) (1797,)
ここでは1797個のサンプルが含まれており、それぞれndarray
型である。
yは先ほどの例のlabelにあたり、Xはimageにあたる部分だ。
先ほどの例と比べると解像度はやや下がり、8*8=64ピクセルとなっている。
さて、このままだとDataLoaderで読み込めないのでTensor型に変換する必要がある。
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.int64)
mnist = torch.utils.data.TensorDataset(X, y)
あとは先ほどと同じようにDataLoaderを使ってiterableに取り出すことができる。
正解ラベルであるyは整数型、featureとなるXはfloat型になっていることに注意。
Pandasでの読み込みについてはここでは省略するので、[4]などを参考にしてほしい。
カスタム関数の作成
実際にKaggleや実務などでDatasetを用意する場合はカスタム関数を使うことが多い。
カスタム関数はDatasetを継承してクラスを作成することで使うことができる。
class MyDataset(torch.utils.data.Dataset):
def __init__(self, df, features, labels):
self._features_values = df[features].values
self._labels = df[labels].values
def __len__(self):
return len(self._features_values)
def __getitem__(self, idx):
features_x = torch.FloatTensor(self._features_values[idx])
labels = torch.LongTensor(self._labels[idx])
return features_x, labels
ここではtorch.utils.data.Dataset
を継承して新しくMyDatasetというクラスを定義した。
inputのdfはpandasのdataframeを前提としているため、__init__
の中で内部変数に格納している。
__len__
はlen()を呼び出す時に必要になり、__getitem__
は要素を参照する時に必要となるため、この二つの
関数は必須で書く必要がある。
今回は__getitem__
の中でTensor型に変更して取得する処理を入れている。
もちろんカスタム関数なので、自由にいろんな処理を加えることが可能だ。
クラスから実際にインスタンスを作成する処理については[5]などを参照。
まとめ
- PyTorchの学習においては
- Datasetの作成 -> DataLoaderにロードの手順でミニバッチ学習を行う
- Dataset, DataLoaderには特殊な型を指定する必要がある
- 実際に使うときはカスタム関数を使うことが多い
参考資料
- [1] https://note.nkmk.me/python-pytorch-device-to-cuda-cpu/
- [2] https://qiita.com/mathlive/items/241bfb42d852bb801b96
- Tensor型の説明がされているQiitaの記事
- [3]https://qiita.com/takurooo/items/e4c91c5d78059f92e76d
- transforms/Dataset/DataLoaderの説明がされている記事
- [4]https://dreamer-uma.com/pytorch-dataset/
- datasetの説明がされているブログ
- [5] https://dreamer-uma.com/pytorch-dataloader/
- dataloaderの説明がされているブログ
- [6]Kaggleに挑む深層学習プログラミングの極意(書籍)
カジュアルな記事ばっかり載せたけど、本家サイトみるのが一番正確