最近PyTorch 1.11がリリースされました.それに合わせて以下のようにトップページにデカデカと TorchDataが使えるようになった と宣伝されています.
TorchDataは PyTorch 1.11のリリース前からpytorch/dataで開発されていて,リリースタイミングを合わせたようです(ただしまだベータ版).
READMEには,以下のように書かれており,PyTorchのDatasetをTorchDataのDataPipesで置き換えられるような感じがします.
torchdata は、柔軟でパフォーマンスの高いデータパイプラインを簡単に構築するための、共通のモジュラーデータローディングプリミティブのライブラリです。 PyTorchのDataLoaderですぐに使える、DataPipesと呼ばれるIterableスタイルやMapスタイルのビルディングブロックを提供することを目的としています。
そこで本記事では,上記github及び公式ドキュメントを参考に DataPipes を使ってみようと思います.
DataPipesとは
DataPipesには以下の2種類があります.
- Iterable-style DataPipes
- Map-style DataPipes
前者は,PyTorchの IterableDataset のサブクラスで,データのランダムな読み込みが高コストな場合やバッチサイズが取得するデータに依存している場合に適しています.例えば,データベースやリモートサーバーから読み込んだデータのストリームや,リアルタイムで生成されたログを返すことができ,これは PyTorch の IterableDataset のアップデート版です.
後者は,一般的に使用されているPyTorchのDatasetと同様で, __getitem__
と __len__
を実装し,インデックスやキーからデータサンプルへのマップを表現します.
つまり,DataPipeは,PyTorch Datasetを構成的に使うために,名前を変えただけのものです.DataPipeはPythonのデータ構造に対するアクセス関数、IterDataPipesなら __iter__
, MapDataPipesなら __getitem__
を受け取り、新しいアクセス関数にちょっとした変換を施して返します。
例えば、この JsonParser はファイル名と生のストリームに対する IterDataPipe を受け取り、ファイル名とデシリアライズされたデータに対する新しいイテレータを生成しています。
import json
class JsonParserIterDataPipe(IterDataPipe):
def __init__(self, source_datapipe, **kwargs) -> None:
self.source_datapipe = source_datapipe
self.kwargs = kwargs
def __iter__(self):
for file_name, stream in self.source_datapipe:
data = stream.read()
yield file_name, json.loads(data)
def __len__(self):
return len(self.source_datapipe)
チュートリアル
このチュートリアルでは,以下の3ステップでCSVファイルからデータを読み込みます.
- ディレクトリ内のCSVファイルを列挙する
- CSVファイルを読み込む
- CSVファイルをパースして,行ごとにyieldする
また,torchdataには以下のbuild-inのDataPipesがあります.
-
FileLister
| ディレクトリ内のファイルをリストする -
Filter
| 指定された関数に基づいてDataPipeの要素をフィルターする -
FileOpener
| ファイルパスからオープンしたファイルストリームを返す -
CSVParser
| ファイルストリームからCSVの内容をパースし、パースされた行を1つずつ返す
上記を用いると, CSVParser
は以下のようになります.
@functional_datapipe("parse_csv")
class CSVParserIterDataPipe(IterDataPipe):
def __init__(self, dp, **fmtparams) -> None:
self.dp = dp
self.fmtparams = fmtparams
def __iter__(self) -> Iterator[Union[Str_Or_Bytes, Tuple[str, Str_Or_Bytes]]]:
for path, file in self.source_datapipe:
stream = self._helper.skip_lines(file)
stream = self._helper.strip_newline(stream)
stream = self._helper.decode(stream)
yield from self._helper.return_path(stream, path=path) # Returns 1 line at a time as List[str or bytes]
そして,DataPipesは関数型を使って呼び出すことができ,パイプラインは以下のように組み立てることができます.
import torchdata.datapipes as dp
FOLDER = 'path/2/csv/folder'
datapipe = dp.iter.FileLister([FOLDER]).filter(filter_fn=lambda filename: filename.endswith('.csv'))
datapipe = dp.iter.FileOpener(datapipe, mode='rt')
datapipe = datapipe.parse_csv(delimiter=',') # @functional_datapipe("parse_csv")で登録済み
for d in datapipe: # Iterating through the data
pass
いざ実践
トイデータのCSVに対してPyTorchのDatasetの代わりにtorchdataのDataPipesが使えるかを検証してみましょう.
なお,torch==1.11.0かつtorchdata==0.3.0の環境にしないと実行はできませんでしたのでご注意下さい.
import csv
import random
from torch.utils.data import DataLoader
import numpy as np
import torchdata.datapipes as dp
def generate_csv(file_label, num_rows: int = 5000, num_features: int = 20) -> None:
fieldnames = ['label'] + [f'c{i}' for i in range(num_features)]
writer = csv.DictWriter(open(f"sample_data{file_label}.csv", "w"), fieldnames=fieldnames)
writer.writerow({col: col for col in fieldnames}) # writing the header row
for i in range(num_rows):
row_data = {col: random.random() for col in fieldnames}
row_data['label'] = random.randint(0, 9)
writer.writerow(row_data)
def build_datapipes(root_dir="."):
datapipe = dp.iter.FileLister(root_dir)
datapipe = datapipe.filter(filter_fn=lambda filename: "sample_data" in filename and filename.endswith(".csv"))
datapipe = dp.iter.FileOpener(datapipe, mode='rt')
datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1)
datapipe = datapipe.map(lambda row: {"label": np.array(row[0], np.int32),
"data": np.array(row[1:], dtype=np.float64)})
return datapipe
if __name__ == '__main__':
num_files_to_generate = 3
for i in range(num_files_to_generate):
generate_csv(file_label=i)
datapipe = build_datapipes()
dl = DataLoader(dataset=datapipe, batch_size=50, shuffle=True)
first = next(iter(dl))
labels, features = first['label'], first['data']
print(f"Labels batch shape: {labels.size()}")
print(f"Feature batch shape: {features.size()}")
出力
Labels batch shape: 50
Feature batch shape: torch.Size([50, 20])
無事にDatasetと同様に使うことができていますね.一見めんどくさいDatasetにも見えますが,DataPipeをつかうことで,各プロジェクトごとにサイロ化した前処理を共通化できそうで良さそうです.無論,研究用としてはこれからもDataset一択でしょうが…
▼以上は公式チュートリアルからの抜粋でした.さらに興味がある方は公式Docをご覧下さい.
https://pytorch.org/data/main/tutorial.html