これはwebdatasetの使い方の続編です.
webdatasetは複数のshardと呼ばれるtarファイルにデータを固めて読み込むライブラリです.shardの作り方は上の記事に書いたとおりですが,大量のデータをshardに固める場合に並列処理をしたくなります.
ということで,この記事ではshardをmultiprocessingで作成する方法を紹介します.概要は以下の通り.
- 画像がカテゴリ名のサブフォルダ以下に保存されているとします.
- 複数のworkerをmultiprocessingで作成し,shardを作成します.
- 画像の読み込み処理などは別々のworkerで行いますが,shardの書き込みは単一のwriterインスタンスで扱います.
- こうするとshard管理が一箇所で行えます.(ファイル名インクリメントなど)
- まあ別々のファイル名にしても問題ないので,別々のworkerで別々のwriterオブジェクトを作成してもいいのですが.
全体のコードは長いので折りたたんでおきます.
from functools import partial
from pathlib import Path
from tqdm import tqdm
import argparse
import json
import os
import random
from webdataset import ShardWriter
from multiprocessing import Pool, current_process
from multiprocessing.managers import SyncManager
from PIL import Image
class MyShardWriter(ShardWriter):
def __init__(self, pattern, maxcount=100000, maxsize=3e9, post=None, start_shard=0, **kw):
super().__init__(pattern, maxcount, maxsize, post, start_shard)
self.verbose = False
def get_shards(self):
return self.shard
def get_count(self):
return self.count if self.count < self.maxcount else 0
def get_total(self):
return self.total
class MyManager(SyncManager):
pass
def worker(file_path, lock, pbar, sink, class_to_idx):
# 'ForkPoolWorker-21' --> 21
# https://docs.python.org/ja/3/library/multiprocessing.html#multiprocessing.Process.name
worker_id = int(current_process().name.split('-')[-1])
#
# process
#
# when file_path == 'dataset/cats_dogs/PetImages/Dog/10247.jpg'
category_name = file_path.parent.name # 'Dog': str
label = class_to_idx[category_name] # 1: int
key_str = category_name + '/' + file_path.stem # 'Dog/10247': str
try:
Image.open(str(file_path)) # check if corrupted
except Exception:
return # skip when error
with open(str(file_path), 'rb') as raw_bytes:
buffer = raw_bytes.read()
#
# write
#
with lock:
sample_dic = {
'__key__': key_str,
'json': json.dumps({
'write worker id': current_process().name,
'count': sink.get_count(),
'category': category_name,
'label': label,
}),
'jpg': buffer
}
sink.write(sample_dic)
pbar.update(1)
pbar.set_postfix_str(
f'shard {sink.get_shards()} '
f'worker {worker_id}'
)
def make_shards(args):
file_paths = [
path for path in Path(args.data_path).glob('**/*')
if not path.is_dir()
]
if args.shuffle:
random.shuffle(file_paths)
n_samples = len(file_paths)
# https://github.com/pytorch/vision/blob/a8bde78130fd8c956780d85693d0f51912013732/torchvision/datasets/folder.py#L36
class_list = sorted(
entry.name for entry in os.scandir(args.data_path)
if entry.is_dir())
class_to_idx = {cls_name: i for i, cls_name in enumerate(class_list)}
shard_dir_path = Path(args.shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / f'{args.shard_prefix}-%05d.tar')
# https://qiita.com/tttamaki/items/96b65e6555f9d255ffd9
MyManager.register('Tqdm', tqdm)
MyManager.register('Sink', MyShardWriter)
with MyManager() as manager:
#
# prepare manager objects
#
lock = manager.Lock()
pbar = manager.Tqdm(
total=n_samples,
position=0,
)
pbar.set_description('Main process')
sink = manager.Sink(
pattern=shard_filename,
maxsize=args.max_size,
maxcount=args.max_count,
)
#
# create worker pool
#
worker_with_args = partial(
worker, lock=lock, pbar=pbar, sink=sink,
class_to_idx=class_to_idx
)
with Pool(processes=args.num_workers) as pool:
# https://stackoverflow.com/questions/26520781/multiprocessing-pool-whats-the-difference-between-map-async-and-imap
# for _ in pool.imap_unordered(
# worker_with_args,
# file_paths,
# chunksize=n_samples // args.num_workers
# ):
# pass
pool.map(
worker_with_args,
file_paths,
chunksize=n_samples // args.num_workers
)
#
# write json of dataset size
#
dataset_size_filename = str(
shard_dir_path / f'{args.shard_prefix}-dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
'dataset size': sink.get_total(),
'num_classes': len(class_to_idx),
'class_to_idx': class_to_idx,
}, fp)
sink.close()
pbar.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data_path', action='store',
help='Path to the dataset dir with category subdirs.')
parser.add_argument('-s', '--shard_path', action='store',
default='./test_shards/',
help='Path to the dir to store shard tar files.')
parser.add_argument('-p', '--shard_prefix', action='store',
default='test',
help='Prefix of shard tar files.')
parser.add_argument('--shuffle', dest='shuffle', action='store_true',
help='use shuffle')
parser.add_argument('--no_shuffle', dest='shuffle', action='store_false',
help='do not use shuffle')
parser.set_defaults(shuffle=True)
parser.add_argument('--max_size', type=float, default=100000,
help='Max size [B] of each shard tar file. '
'default to 100000 bytes.')
parser.add_argument('--max_count', type=int, default=100,
help='Max number of entries in each shard tar file. '
'default to 100.')
parser.add_argument('-w', '--num_workers', type=int, default=4,
help='Number of workers. '
'default to 4.')
args = parser.parse_args()
make_shards(args)
ではコードを解説していきます.
まずはmain()
argparseの引数処理です.一般的な引数を設定してます.
-
-d
:データセットのディレクトリ.jpeg画像ファイルがカテゴリ名毎のサブフォルダに保存されている. -
-s
:shard保存先.
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data_path', action='store',
help='Path to the dataset dir with category subdirs.')
parser.add_argument('-s', '--shard_path', action='store',
default='./test_shards/',
help='Path to the dir to store shard tar files.')
parser.add_argument('-p', '--shard_prefix', action='store',
default='test',
help='Prefix of shard tar files.')
parser.add_argument('--shuffle', dest='shuffle', action='store_true',
help='use shuffle')
parser.add_argument('--no_shuffle', dest='shuffle', action='store_false',
help='do not use shuffle')
parser.set_defaults(shuffle=True)
parser.add_argument('--max_size', type=float, default=100000,
help='Max size [B] of each shard tar file. '
'default to 100000 bytes.')
parser.add_argument('--max_count', type=int, default=100,
help='Max number of entries in each shard tar file. '
'default to 100.')
parser.add_argument('-w', '--num_workers', type=int, default=4,
help='Number of workers. '
'default to 4.')
args = parser.parse_args()
make_shards(args)
全ワーカーと共有するための準備
今回は,メインプロセスとワーカープロセスで
- shardの書き込みを管理するwebdataset.ShardWriter
- 進捗を表示するtqdm
のインスタンスをmultiprocessing.managers.SyncManagerを使って共有します.そのためにまず
- ShardWriterの派生クラスを定義.SyncManagerで共有するとattributeにアクセスできないようなので,shard番号やtotalサイズなどを返すだけのメソッドを追加しただけです.
- SyncManagerの派生クラスを定義.これを使ってワーカー間でオブジェクトを共有します.これについては以下の記事を参照.
class MyShardWriter(ShardWriter):
def __init__(self, pattern, maxcount=100000, maxsize=3e9, post=None, start_shard=0, **kw):
super().__init__(pattern, maxcount, maxsize, post, start_shard)
self.verbose = False
def get_shards(self):
return self.shard
def get_count(self):
return self.count if self.count < self.maxcount else 0
def get_total(self):
return self.total
class MyManager(SyncManager):
pass
メインプロセス:shard情報・共有オブジェクトの準備とworkerの起動
それではデータセットを書き込む準備をします.
def make_shards(args):
file_paths = [
path for path in Path(args.data_path).glob('**/*')
if not path.is_dir()
]
if args.shuffle:
random.shuffle(file_paths)
n_samples = len(file_paths)
- args.data_pathで与えられたディレクトリ以下のファイル(サブディレクトリを除く)をすべて列挙してfile_pathsに追加します.
- globで取得されるファイルの順序は任意なので,shuffleするならこの時点でシャッフルしておきます.
# https://github.com/pytorch/vision/blob/a8bde78130fd8c956780d85693d0f51912013732/torchvision/datasets/folder.py#L36
class_list = sorted(
entry.name for entry in os.scandir(args.data_path)
if entry.is_dir())
class_to_idx = {cls_name: i for i, cls_name in enumerate(class_list)}
- args.data_pathで与えられたディレクトリ以下のサブディレクトリはカテゴリ名であると仮定しているので,これらを列挙してenumerateでカテゴリ番号に変換します.
- 変換テーブルがclass_to_idx
shard_dir_path = Path(args.shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / f'{args.shard_prefix}-%05d.tar')
- shard保存先のディレクトリを設定します(なければ作成)
- shardであるtarファイル名を設定します
# https://qiita.com/tttamaki/items/96b65e6555f9d255ffd9
MyManager.register('Tqdm', tqdm)
MyManager.register('Sink', MyShardWriter)
with MyManager() as manager:
- multiprocessing.managers.SyncManagerで共有するクラス名を登録します.ここでは
- 複数ワーカーの進捗を一つのtqdmプログレスバーで表すために,tqdmを共有登録します
- 複数ワーカーが一つのsharedWriterで書き込むため,MyShardWriterを共有登録します
- with文以降でmanagerオブジェクトが起動します.
lock = manager.Lock()
pbar = manager.Tqdm(
total=n_samples,
position=0,
)
pbar.set_description('Main process')
sink = manager.Sink(
pattern=shard_filename,
maxsize=args.max_size,
maxcount=args.max_count,
)
実際に共有するインスタンスを生成します.
- lock:ロックオブジェクト.一つのインスタンスに複数ワーカーがアクセスするために必要.
- pbar:tqdmのプログレスバー.workerがjpegファイルを読み込んで,shard writerで書き込んだらupdateされる.
- sink:shard writerインスタンス.
ちなみにtqdmにはmultiprocessingと組み合わせるためのtqdm.set_lock()
があるのですが,これは複数のtqdmが一つのlockを共有する(つまり複数のプログレスバーが同時に更新されないようにする)もので,ここでやっている一つのtqdmを複数プロセスが更新するというのとは異なります.
worker_with_args = partial(
worker, lock=lock, pbar=pbar, sink=sink,
class_to_idx=class_to_idx
)
mapでworkerに渡せる引数は1つなのですが,それはメインプロセスからワーカーへ送るjpegファイル名に使ってしまいます.そうすると追加の引数として共有インスタンスを渡せません.
そこでfunctools.partialを使って,引数付きの関数オブジェクトを作ってしまいます.こうすると,引数を固定した関数オブジェクトをmapに渡せばOKです.
with Pool(processes=args.num_workers) as pool:
# for _ in pool.imap_unordered(
# worker_with_args,
# file_paths,
# chunksize=n_samples // args.num_workers
# ):
# pass
pool.map(
worker_with_args,
file_paths,
chunksize=n_samples // args.num_workers
)
では実際にワーカーを生成します.
- Pool()でワーカープールpoolを作成します.
- pool.map()でワーカーを起動します.
- 一般的にはpool.imap_unorderedを使うことが多いようなのですが,今回はmapを使います.その理由は
- ワーカーに渡すのはファイル名リストfile_pathsなので遅延評価は必要ない(つまりimapではなくてmapでOK)
- ワーカーはshardに書き込むだけなのでメインプロセスに返ってくるものはないし,メインがブロックされてもよい(つまりmap_asyncやimap_unorderedではなくて単に同期mapでOK)
- worker_with_argsが並列に処理される関数
- file_pathsの各要素がワーカーに渡される
- chunksizeはワーカーに渡される要素チャンクの長さ.ここでは全データ数をワーカー数で割った数(商)にしています.
- 10ワーカー,100サンプルならchunksize=10で,各ワーカーが10サンプルずつで終了.
- 10ワーカー,103サンプルならchunksize=19で,各ワーカーに10サンプルずつ渡されて,残った3サンプルは処理が終わったワーカーに追加で渡される.
- 一般的にはpool.imap_unorderedを使うことが多いようなのですが,今回はmapを使います.その理由は
dataset_size_filename = str(
shard_dir_path / f'{args.shard_prefix}-dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
'dataset size': sink.get_total(),
'num_classes': len(class_to_idx),
'class_to_idx': class_to_idx,
}, fp)
sink.close()
pbar.close()
全ワーカーの処理が終わったら後処理です.
- jsonファイルにファイル数やクラス数,クラス番号対応表を保存しておきます.shardを読み込むときにこれらの情報があると便利です.
- sinkとpbarをclose()しておきます.(managerオブジェクトにはwith文が使えないのが不便)
workerプロセス:
では各ワーカーの処理を解説します.
def worker(file_path, lock, pbar, sink, class_to_idx):
- file_path:メインプロセスから渡されたjpegファイル名
- lock, pbar, sink:managerオブジェクトで共有されているインスタンス
- class_to_idx:メインプロセスから渡されたクラス番号対応辞書
# 'ForkPoolWorker-21' --> 21
# https://docs.python.org/ja/3/library/multiprocessing.html#multiprocessing.Process.name
worker_id = int(current_process().name.split('-')[-1])
- 自分のワーカー番号を知りたい場合には,
multiprocessing.current_process()を使います.- ワーカーIDの文字列はProcess.nameで取得できるので,その最後の番号をintに変換しています.
# when file_path == 'dataset/cats_dogs/PetImages/Dog/10247.jpg'
category_name = file_path.parent.name # 'Dog': str
label = class_to_idx[category_name] # 1: int
key_str = category_name + '/' + file_path.stem # 'Dog/10247': str
- ファイル名のフルパスからカテゴリ名を抜き出す
- クラス番号labelに変換
- 0-originのクラス番号はたぶん学習に使います
- shard用のkeyに変換
- 一意ならなんでもOKですが,「カテゴリ名/サンプルID」にするのがおすすめ.shardのtarファイルを展開したときにそのディレクトリ構造が保持されたままになるので,後でなにか作業をしたいときに便利.
try:
Image.open(str(file_path)) # check if corrupted
except Exception:
return # skip when error
ここでjpegファイルが破損しているかどうかチェック.
PIL.Image.open()でopenできなければ多分破損してるので,shardに書き込む前にここで除外してしまいます.
with open(str(file_path), 'rb') as raw_bytes:
buffer = raw_bytes.read()
jpegファイルをバイナリのまま読み込みます.これは以下の記事を参照.
with lock:
sample_dic = {
'__key__': key_str,
'json': json.dumps({
'write worker id': current_process().name,
'count': sink.get_count(),
'category': category_name,
'label': label,
}),
'jpg': buffer
}
sink.write(sample_dic)
pbar.update(1)
pbar.set_postfix_str(
f'shard {sink.get_shards()} '
f'worker {worker_id}'
)
ここがshard書き込み部分です.
- 他のワーカーと競合しないように,まずlockします.
- shardに書き込む辞書を作成します.
-
'__key__'
:必須キー. -
'json'
:いろいろな情報を__key__.json
という名称のjsonファイルとしてtarに書き込みます. -
'jpg'
:jpegバイナリを__key__.jpg
という名称のjpegファイル(バイナリファイル)としてtarに書き込みます.
-
- tqdmプログレスバーを更新します
- jpegファイルを1つ処理したので
update(1)
-
set_postfix_str()
で,現在のshard番号とworker idをプログレスバーに表示.
- jpegファイルを1つ処理したので
では実行
データセットのディレクトリ構造は以下のようなものを仮定します.
├── dataset
│ ├── category1
│ │ ├── sample1-1.jpeg
│ │ ├── sample1-2.jpeg
...
│ ├── category2
│ │ ├── sample2-1.jpeg
│ │ ├── sample2-2.jpeg
...
│ ├── category3
│ │ ├── sample3-1.jpeg
│ │ ├── sample3-2.jpeg
...
では実行します.8ワーカー並列の例です.
python make_jpeg_shards.py -d ./dataset/ -s ./shards/ --max_size 100000000 --max_count 1000 -w 8