6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【TorchData】PyTorch 1.11と合わせて公開されたTorchDataをとりあえず使ってみる

Last updated at Posted at 2022-03-20

最近PyTorch 1.11がリリースされました.それに合わせて以下のようにトップページにデカデカと TorchDataが使えるようになった と宣伝されています.

Screenshot from 2022-03-20 10-31-43.png

TorchDataは PyTorch 1.11のリリース前からpytorch/dataで開発されていて,リリースタイミングを合わせたようです(ただしまだベータ版).
READMEには,以下のように書かれており,PyTorchのDatasetをTorchDataのDataPipesで置き換えられるような感じがします.

torchdata は、柔軟でパフォーマンスの高いデータパイプラインを簡単に構築するための、共通のモジュラーデータローディングプリミティブのライブラリです。 PyTorchのDataLoaderですぐに使える、DataPipesと呼ばれるIterableスタイルやMapスタイルのビルディングブロックを提供することを目的としています。

そこで本記事では,上記github及び公式ドキュメントを参考に DataPipes を使ってみようと思います.

DataPipesとは

DataPipesには以下の2種類があります.

  • Iterable-style DataPipes
  • Map-style DataPipes

前者は,PyTorchの IterableDataset のサブクラスで,データのランダムな読み込みが高コストな場合やバッチサイズが取得するデータに依存している場合に適しています.例えば,データベースやリモートサーバーから読み込んだデータのストリームや,リアルタイムで生成されたログを返すことができ,これは PyTorch の IterableDataset のアップデート版です.

後者は,一般的に使用されているPyTorchのDatasetと同様で, __getitem____len__ を実装し,インデックスやキーからデータサンプルへのマップを表現します.

つまり,DataPipeは,PyTorch Datasetを構成的に使うために,名前を変えただけのものです.DataPipeはPythonのデータ構造に対するアクセス関数、IterDataPipesなら __iter__ , MapDataPipesなら __getitem__ を受け取り、新しいアクセス関数にちょっとした変換を施して返します。

例えば、この JsonParser はファイル名と生のストリームに対する IterDataPipe を受け取り、ファイル名とデシリアライズされたデータに対する新しいイテレータを生成しています。

import json

class JsonParserIterDataPipe(IterDataPipe):
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe = source_datapipe
        self.kwargs = kwargs

    def __iter__(self):
        for file_name, stream in self.source_datapipe:
            data = stream.read()
            yield file_name, json.loads(data)

    def __len__(self):
        return len(self.source_datapipe)

チュートリアル

このチュートリアルでは,以下の3ステップでCSVファイルからデータを読み込みます.

  • ディレクトリ内のCSVファイルを列挙する
  • CSVファイルを読み込む
  • CSVファイルをパースして,行ごとにyieldする

また,torchdataには以下のbuild-inのDataPipesがあります.

  • FileLister | ディレクトリ内のファイルをリストする
  • Filter | 指定された関数に基づいてDataPipeの要素をフィルターする
  • FileOpener | ファイルパスからオープンしたファイルストリームを返す
  • CSVParser | ファイルストリームからCSVの内容をパースし、パースされた行を1つずつ返す

上記を用いると, CSVParser は以下のようになります.

@functional_datapipe("parse_csv")
class CSVParserIterDataPipe(IterDataPipe):
    def __init__(self, dp, **fmtparams) -> None:
        self.dp = dp
        self.fmtparams = fmtparams

    def __iter__(self) -> Iterator[Union[Str_Or_Bytes, Tuple[str, Str_Or_Bytes]]]:
        for path, file in self.source_datapipe:
            stream = self._helper.skip_lines(file)
            stream = self._helper.strip_newline(stream)
            stream = self._helper.decode(stream)
            yield from self._helper.return_path(stream, path=path)  # Returns 1 line at a time as List[str or bytes]

そして,DataPipesは関数型を使って呼び出すことができ,パイプラインは以下のように組み立てることができます.

import torchdata.datapipes as dp

FOLDER = 'path/2/csv/folder'
datapipe = dp.iter.FileLister([FOLDER]).filter(filter_fn=lambda filename: filename.endswith('.csv'))
datapipe = dp.iter.FileOpener(datapipe, mode='rt')
datapipe = datapipe.parse_csv(delimiter=',')  # @functional_datapipe("parse_csv")で登録済み

for d in datapipe: # Iterating through the data
     pass

いざ実践

トイデータのCSVに対してPyTorchのDatasetの代わりにtorchdataのDataPipesが使えるかを検証してみましょう.
なお,torch==1.11.0かつtorchdata==0.3.0の環境にしないと実行はできませんでしたのでご注意下さい.

import csv
import random
from torch.utils.data import DataLoader
import numpy as np
import torchdata.datapipes as dp

def generate_csv(file_label, num_rows: int = 5000, num_features: int = 20) -> None:
    fieldnames = ['label'] + [f'c{i}' for i in range(num_features)]
    writer = csv.DictWriter(open(f"sample_data{file_label}.csv", "w"), fieldnames=fieldnames)
    writer.writerow({col: col for col in fieldnames})  # writing the header row
    for i in range(num_rows):
        row_data = {col: random.random() for col in fieldnames}
        row_data['label'] = random.randint(0, 9)
        writer.writerow(row_data)

def build_datapipes(root_dir="."):
    datapipe = dp.iter.FileLister(root_dir)
    datapipe = datapipe.filter(filter_fn=lambda filename: "sample_data" in filename and filename.endswith(".csv"))
    datapipe = dp.iter.FileOpener(datapipe, mode='rt')
    datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1)
    datapipe = datapipe.map(lambda row: {"label": np.array(row[0], np.int32),
                                         "data": np.array(row[1:], dtype=np.float64)})
    return datapipe

if __name__ == '__main__':
    num_files_to_generate = 3
    for i in range(num_files_to_generate):
        generate_csv(file_label=i)
    datapipe = build_datapipes()
    dl = DataLoader(dataset=datapipe, batch_size=50, shuffle=True)
    first = next(iter(dl))
    labels, features = first['label'], first['data']
    print(f"Labels batch shape: {labels.size()}")
    print(f"Feature batch shape: {features.size()}")

出力

Labels batch shape: 50
Feature batch shape: torch.Size([50, 20])

無事にDatasetと同様に使うことができていますね.一見めんどくさいDatasetにも見えますが,DataPipeをつかうことで,各プロジェクトごとにサイロ化した前処理を共通化できそうで良さそうです.無論,研究用としてはこれからもDataset一択でしょうが…

▼以上は公式チュートリアルからの抜粋でした.さらに興味がある方は公式Docをご覧下さい.
https://pytorch.org/data/main/tutorial.html

6
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?