LoginSignup
0
0

More than 1 year has passed since last update.

webdatasetの使い方上級編1:multiprocessingでshard作成

Posted at

これはwebdatasetの使い方の続編です.

webdatasetは複数のshardと呼ばれるtarファイルにデータを固めて読み込むライブラリです.shardの作り方は上の記事に書いたとおりですが,大量のデータをshardに固める場合に並列処理をしたくなります.

ということで,この記事ではshardをmultiprocessingで作成する方法を紹介します.概要は以下の通り.

  • 画像がカテゴリ名のサブフォルダ以下に保存されているとします.
  • 複数のworkerをmultiprocessingで作成し,shardを作成します.
  • 画像の読み込み処理などは別々のworkerで行いますが,shardの書き込みは単一のwriterインスタンスで扱います.
    • こうするとshard管理が一箇所で行えます.(ファイル名インクリメントなど)
    • まあ別々のファイル名にしても問題ないので,別々のworkerで別々のwriterオブジェクトを作成してもいいのですが.
全体のコードは長いので折りたたんでおきます.
全体コードのmake_jpeg_shards.py
from functools import partial
from pathlib import Path
from tqdm import tqdm
import argparse
import json
import os
import random
from webdataset import ShardWriter
from multiprocessing import Pool, current_process
from multiprocessing.managers import SyncManager
from PIL import Image


class MyShardWriter(ShardWriter):

    def __init__(self, pattern, maxcount=100000, maxsize=3e9, post=None, start_shard=0, **kw):
        super().__init__(pattern, maxcount, maxsize, post, start_shard)
        self.verbose = False

    def get_shards(self):
        return self.shard

    def get_count(self):
        return self.count if self.count < self.maxcount else 0

    def get_total(self):
        return self.total


class MyManager(SyncManager):
    pass


def worker(file_path, lock, pbar, sink, class_to_idx):

    # 'ForkPoolWorker-21' --> 21
    # https://docs.python.org/ja/3/library/multiprocessing.html#multiprocessing.Process.name
    worker_id = int(current_process().name.split('-')[-1])

    #
    # process
    #

    # when file_path == 'dataset/cats_dogs/PetImages/Dog/10247.jpg'
    category_name = file_path.parent.name  # 'Dog': str
    label = class_to_idx[category_name]  # 1: int
    key_str = category_name + '/' + file_path.stem  # 'Dog/10247': str

    try:
        Image.open(str(file_path))  # check if corrupted
    except Exception:
        return  # skip when error

    with open(str(file_path), 'rb') as raw_bytes:
        buffer = raw_bytes.read()

    #
    # write
    #

    with lock:

        sample_dic = {
            '__key__': key_str,
            'json': json.dumps({
                'write worker id': current_process().name,
                'count': sink.get_count(),
                'category': category_name,
                'label': label,
            }),
            'jpg': buffer
        }
        sink.write(sample_dic)

        pbar.update(1)
        pbar.set_postfix_str(
            f'shard {sink.get_shards()} '
            f'worker {worker_id}'
        )


def make_shards(args):

    file_paths = [
        path for path in Path(args.data_path).glob('**/*')
        if not path.is_dir()
    ]
    if args.shuffle:
        random.shuffle(file_paths)
    n_samples = len(file_paths)

    # https://github.com/pytorch/vision/blob/a8bde78130fd8c956780d85693d0f51912013732/torchvision/datasets/folder.py#L36
    class_list = sorted(
        entry.name for entry in os.scandir(args.data_path)
        if entry.is_dir())
    class_to_idx = {cls_name: i for i, cls_name in enumerate(class_list)}

    shard_dir_path = Path(args.shard_path)
    shard_dir_path.mkdir(exist_ok=True)
    shard_filename = str(shard_dir_path / f'{args.shard_prefix}-%05d.tar')

    # https://qiita.com/tttamaki/items/96b65e6555f9d255ffd9
    MyManager.register('Tqdm', tqdm)
    MyManager.register('Sink', MyShardWriter)

    with MyManager() as manager:

        #
        # prepare manager objects
        #

        lock = manager.Lock()
        pbar = manager.Tqdm(
            total=n_samples,
            position=0,
        )
        pbar.set_description('Main process')
        sink = manager.Sink(
            pattern=shard_filename,
            maxsize=args.max_size,
            maxcount=args.max_count,
        )

        #
        # create worker pool
        #

        worker_with_args = partial(
            worker, lock=lock, pbar=pbar, sink=sink,
            class_to_idx=class_to_idx
        )
        with Pool(processes=args.num_workers) as pool:
            # https://stackoverflow.com/questions/26520781/multiprocessing-pool-whats-the-difference-between-map-async-and-imap
            # for _ in pool.imap_unordered(
            #         worker_with_args,
            #         file_paths,
            #         chunksize=n_samples // args.num_workers
            # ):
            #     pass
            pool.map(
                worker_with_args,
                file_paths,
                chunksize=n_samples // args.num_workers
            )

        #
        # write json of dataset size
        #

        dataset_size_filename = str(
            shard_dir_path / f'{args.shard_prefix}-dataset-size.json')
        with open(dataset_size_filename, 'w') as fp:
            json.dump({
                'dataset size': sink.get_total(),
                'num_classes': len(class_to_idx),
                'class_to_idx': class_to_idx,
            }, fp)

        sink.close()
        pbar.close()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('-d', '--data_path', action='store',
                        help='Path to the dataset dir with category subdirs.')
    parser.add_argument('-s', '--shard_path', action='store',
                        default='./test_shards/',
                        help='Path to the dir to store shard tar files.')
    parser.add_argument('-p', '--shard_prefix', action='store',
                        default='test',
                        help='Prefix of shard tar files.')
    parser.add_argument('--shuffle', dest='shuffle', action='store_true',
                        help='use shuffle')
    parser.add_argument('--no_shuffle', dest='shuffle', action='store_false',
                        help='do not use shuffle')
    parser.set_defaults(shuffle=True)
    parser.add_argument('--max_size', type=float, default=100000,
                        help='Max size [B] of each shard tar file. '
                        'default to 100000 bytes.')
    parser.add_argument('--max_count', type=int, default=100,
                        help='Max number of entries in each shard tar file. '
                        'default to 100.')
    parser.add_argument('-w', '--num_workers', type=int, default=4,
                        help='Number of workers. '
                        'default to 4.')
    args = parser.parse_args()

    make_shards(args)

ではコードを解説していきます.

まずはmain()

argparseの引数処理です.一般的な引数を設定してます.

  • -d:データセットのディレクトリ.jpeg画像ファイルがカテゴリ名毎のサブフォルダに保存されている.
  • -s:shard保存先.
main部分
if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('-d', '--data_path', action='store',
                        help='Path to the dataset dir with category subdirs.')
    parser.add_argument('-s', '--shard_path', action='store',
                        default='./test_shards/',
                        help='Path to the dir to store shard tar files.')
    parser.add_argument('-p', '--shard_prefix', action='store',
                        default='test',
                        help='Prefix of shard tar files.')
    parser.add_argument('--shuffle', dest='shuffle', action='store_true',
                        help='use shuffle')
    parser.add_argument('--no_shuffle', dest='shuffle', action='store_false',
                        help='do not use shuffle')
    parser.set_defaults(shuffle=True)
    parser.add_argument('--max_size', type=float, default=100000,
                        help='Max size [B] of each shard tar file. '
                        'default to 100000 bytes.')
    parser.add_argument('--max_count', type=int, default=100,
                        help='Max number of entries in each shard tar file. '
                        'default to 100.')
    parser.add_argument('-w', '--num_workers', type=int, default=4,
                        help='Number of workers. '
                        'default to 4.')
    args = parser.parse_args()

    make_shards(args)

全ワーカーと共有するための準備

今回は,メインプロセスとワーカープロセスで

  • shardの書き込みを管理するwebdataset.ShardWriter
  • 進捗を表示するtqdm

のインスタンスをmultiprocessing.managers.SyncManagerを使って共有します.そのためにまず

  • ShardWriterの派生クラスを定義.SyncManagerで共有するとattributeにアクセスできないようなので,shard番号やtotalサイズなどを返すだけのメソッドを追加しただけです.
  • SyncManagerの派生クラスを定義.これを使ってワーカー間でオブジェクトを共有します.これについては以下の記事を参照.

ヘルパークラスの定義

class MyShardWriter(ShardWriter):

    def __init__(self, pattern, maxcount=100000, maxsize=3e9, post=None, start_shard=0, **kw):
        super().__init__(pattern, maxcount, maxsize, post, start_shard)
        self.verbose = False

    def get_shards(self):
        return self.shard

    def get_count(self):
        return self.count if self.count < self.maxcount else 0

    def get_total(self):
        return self.total


class MyManager(SyncManager):
    pass

メインプロセス:shard情報・共有オブジェクトの準備とworkerの起動

それではデータセットを書き込む準備をします.

jpegファイル名の取得
def make_shards(args):

    file_paths = [
        path for path in Path(args.data_path).glob('**/*')
        if not path.is_dir()
    ]
    if args.shuffle:
        random.shuffle(file_paths)
    n_samples = len(file_paths)
  • args.data_pathで与えられたディレクトリ以下のファイル(サブディレクトリを除く)をすべて列挙してfile_pathsに追加します.
  • globで取得されるファイルの順序は任意なので,shuffleするならこの時点でシャッフルしておきます.
クラス番号の取得
    # https://github.com/pytorch/vision/blob/a8bde78130fd8c956780d85693d0f51912013732/torchvision/datasets/folder.py#L36
    class_list = sorted(
        entry.name for entry in os.scandir(args.data_path)
        if entry.is_dir())
    class_to_idx = {cls_name: i for i, cls_name in enumerate(class_list)}
  • args.data_pathで与えられたディレクトリ以下のサブディレクトリはカテゴリ名であると仮定しているので,これらを列挙してenumerateでカテゴリ番号に変換します.
  • 変換テーブルがclass_to_idx
shard保存先の準備
    shard_dir_path = Path(args.shard_path)
    shard_dir_path.mkdir(exist_ok=True)
    shard_filename = str(shard_dir_path / f'{args.shard_prefix}-%05d.tar')
  • shard保存先のディレクトリを設定します(なければ作成)
  • shardであるtarファイル名を設定します
共有オブジェクトの登録
    # https://qiita.com/tttamaki/items/96b65e6555f9d255ffd9
    MyManager.register('Tqdm', tqdm)
    MyManager.register('Sink', MyShardWriter)

    with MyManager() as manager:
  • multiprocessing.managers.SyncManagerで共有するクラス名を登録します.ここでは
    • 複数ワーカーの進捗を一つのtqdmプログレスバーで表すために,tqdmを共有登録します
    • 複数ワーカーが一つのsharedWriterで書き込むため,MyShardWriterを共有登録します
  • with文以降でmanagerオブジェクトが起動します.
共有インスタンスの生成
        lock = manager.Lock()
        pbar = manager.Tqdm(
            total=n_samples,
            position=0,
        )
        pbar.set_description('Main process')
        sink = manager.Sink(
            pattern=shard_filename,
            maxsize=args.max_size,
            maxcount=args.max_count,
        )

実際に共有するインスタンスを生成します.

  • lock:ロックオブジェクト.一つのインスタンスに複数ワーカーがアクセスするために必要.
  • pbar:tqdmのプログレスバー.workerがjpegファイルを読み込んで,shard writerで書き込んだらupdateされる.
  • sink:shard writerインスタンス.

ちなみにtqdmにはmultiprocessingと組み合わせるためのtqdm.set_lock()があるのですが,これは複数のtqdmが一つのlockを共有する(つまり複数のプログレスバーが同時に更新されないようにする)もので,ここでやっている一つのtqdmを複数プロセスが更新するというのとは異なります.

        worker_with_args = partial(
            worker, lock=lock, pbar=pbar, sink=sink,
            class_to_idx=class_to_idx
        )

mapでworkerに渡せる引数は1つなのですが,それはメインプロセスからワーカーへ送るjpegファイル名に使ってしまいます.そうすると追加の引数として共有インスタンスを渡せません.

そこでfunctools.partialを使って,引数付きの関数オブジェクトを作ってしまいます.こうすると,引数を固定した関数オブジェクトをmapに渡せばOKです.

        with Pool(processes=args.num_workers) as pool:
            # for _ in pool.imap_unordered(
            #         worker_with_args,
            #         file_paths,
            #         chunksize=n_samples // args.num_workers
            # ):
            #     pass
            pool.map(
                worker_with_args,
                file_paths,
                chunksize=n_samples // args.num_workers
            )

では実際にワーカーを生成します.

  • Pool()でワーカープールpoolを作成します.
  • pool.map()でワーカーを起動します.
    • 一般的にはpool.imap_unorderedを使うことが多いようなのですが,今回はmapを使います.その理由は
      • ワーカーに渡すのはファイル名リストfile_pathsなので遅延評価は必要ない(つまりimapではなくてmapでOK)
      • ワーカーはshardに書き込むだけなのでメインプロセスに返ってくるものはないし,メインがブロックされてもよい(つまりmap_asyncimap_unorderedではなくて単に同期mapでOK)
    • worker_with_argsが並列に処理される関数
    • file_pathsの各要素がワーカーに渡される
    • chunksizeはワーカーに渡される要素チャンクの長さ.ここでは全データ数をワーカー数で割った数(商)にしています.
      • 10ワーカー,100サンプルならchunksize=10で,各ワーカーが10サンプルずつで終了.
      • 10ワーカー,103サンプルならchunksize=19で,各ワーカーに10サンプルずつ渡されて,残った3サンプルは処理が終わったワーカーに追加で渡される.
後処理.
        dataset_size_filename = str(
            shard_dir_path / f'{args.shard_prefix}-dataset-size.json')
        with open(dataset_size_filename, 'w') as fp:
            json.dump({
                'dataset size': sink.get_total(),
                'num_classes': len(class_to_idx),
                'class_to_idx': class_to_idx,
            }, fp)

        sink.close()
        pbar.close()

全ワーカーの処理が終わったら後処理です.

  • jsonファイルにファイル数やクラス数,クラス番号対応表を保存しておきます.shardを読み込むときにこれらの情報があると便利です.
  • sinkとpbarをclose()しておきます.(managerオブジェクトにはwith文が使えないのが不便)

workerプロセス:

では各ワーカーの処理を解説します.

workerの引数
def worker(file_path, lock, pbar, sink, class_to_idx):
  • file_path:メインプロセスから渡されたjpegファイル名
  • lock, pbar, sink:managerオブジェクトで共有されているインスタンス
  • class_to_idx:メインプロセスから渡されたクラス番号対応辞書
    # 'ForkPoolWorker-21' --> 21
    # https://docs.python.org/ja/3/library/multiprocessing.html#multiprocessing.Process.name
    worker_id = int(current_process().name.split('-')[-1])
  • 自分のワーカー番号を知りたい場合には,
    multiprocessing.current_process()を使います.
    • ワーカーIDの文字列はProcess.nameで取得できるので,その最後の番号をintに変換しています.
    # when file_path == 'dataset/cats_dogs/PetImages/Dog/10247.jpg'
    category_name = file_path.parent.name  # 'Dog': str
    label = class_to_idx[category_name]  # 1: int
    key_str = category_name + '/' + file_path.stem  # 'Dog/10247': str
  • ファイル名のフルパスからカテゴリ名を抜き出す
  • クラス番号labelに変換
    • 0-originのクラス番号はたぶん学習に使います
  • shard用のkeyに変換
    • 一意ならなんでもOKですが,「カテゴリ名/サンプルID」にするのがおすすめ.shardのtarファイルを展開したときにそのディレクトリ構造が保持されたままになるので,後でなにか作業をしたいときに便利.
    try:
        Image.open(str(file_path))  # check if corrupted
    except Exception:
        return  # skip when error

ここでjpegファイルが破損しているかどうかチェック.
PIL.Image.open()でopenできなければ多分破損してるので,shardに書き込む前にここで除外してしまいます.

    with open(str(file_path), 'rb') as raw_bytes:
        buffer = raw_bytes.read()

jpegファイルをバイナリのまま読み込みます.これは以下の記事を参照.

    with lock:

        sample_dic = {
            '__key__': key_str,
            'json': json.dumps({
                'write worker id': current_process().name,
                'count': sink.get_count(),
                'category': category_name,
                'label': label,
            }),
            'jpg': buffer
        }
        sink.write(sample_dic)

        pbar.update(1)
        pbar.set_postfix_str(
            f'shard {sink.get_shards()} '
            f'worker {worker_id}'
        )

ここがshard書き込み部分です.

  • 他のワーカーと競合しないように,まずlockします.
  • shardに書き込む辞書を作成します.
    • '__key__':必須キー.
    • 'json':いろいろな情報を__key__.jsonという名称のjsonファイルとしてtarに書き込みます.
    • 'jpg':jpegバイナリを__key__.jpgという名称のjpegファイル(バイナリファイル)としてtarに書き込みます.
  • tqdmプログレスバーを更新します
    • jpegファイルを1つ処理したのでupdate(1)
    • set_postfix_str()で,現在のshard番号とworker idをプログレスバーに表示.

では実行

データセットのディレクトリ構造は以下のようなものを仮定します.

データセットのディレクトリ構造
├── dataset
│   ├── category1
│   │   ├── sample1-1.jpeg
│   │   ├── sample1-2.jpeg
...
│   ├── category2
│   │   ├── sample2-1.jpeg
│   │   ├── sample2-2.jpeg
...
│   ├── category3
│   │   ├── sample3-1.jpeg
│   │   ├── sample3-2.jpeg
...

では実行します.8ワーカー並列の例です.

python make_jpeg_shards.py -d ./dataset/ -s ./shards/ --max_size 100000000 --max_count 1000 -w 8
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0