1. webdatasetとは
webdatasetとは,データセットをtarアーカイブで読み書きするためのライブラリです.
WebDataset reads dataset that are stored as tar files, with the simple convention that files that belong together and make up a training sample share the same basename.
つまり,tarファイルの中に
n01440764/ILSVRC2012_val_00000293.cls
n01440764/ILSVRC2012_val_00000293.jpg
n01440764/ILSVRC2012_val_00002138.cls
n01440764/ILSVRC2012_val_00002138.jpg
n01440764/ILSVRC2012_val_00003014.cls
n01440764/ILSVRC2012_val_00003014.jpg
というファイルが入っていたら,同じbasenameを持つ複数のファイルを1つの学習サンプルデータとみなします(最初のサンプルの場合にはn01440764/ILSVRC2012_val_00000293
ががbasename).この場合には3つの学習データがあり,jpgとclsというペアが1つのサンプルになっています.
この形式は,PyTorchのtorchdataやNVIDIAのdaliなども読み込みを対応しています.
大量のデータを扱う場合にはtarで固めておいたほうが便利なときも多いので,使い方次第でいろいろ工夫ができます.が,英語でも日本語でもほとんど使い方の情報がありません.見つかったとしても,tarアーカイブを作るにはtar
やtarp
を使えばよい,読み込みもjpgとclsというペアのサンプルしかない,など,工夫を凝らしたshard作成やデコーダを作成するには情報が足りません.
- PyTorchで深層学習データセットを効率的に取り扱うために
- 1ペタバイトのデータセットで機械学習する / WebDataset入門
- git@github.com:tmbdev-archive/webdataset-examples.git
そこでこの記事では,いろいろなshardの生成方法や,デコーダなどの作成方法を紹介することにします.
1.1. ちなみにtarとは?
2. webdatasetのインストール
まずはpipでwebdatasetをインストールします.
!pip install git+https://github.com/webdataset/webdataset.git
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/webdataset/webdataset.git
Cloning https://github.com/webdataset/webdataset.git to /tmp/pip-req-build-up9_m1oz
Running command git clone -q https://github.com/webdataset/webdataset.git /tmp/pip-req-build-up9_m1oz
Installing build dependencies ... [?25l[?25hdone
Getting requirements to build wheel ... [?25l[?25hdone
Installing backend dependencies ... [?25l[?25hdone
Preparing wheel metadata ... [?25l[?25hdone
Requirement already satisfied: braceexpand in /usr/local/lib/python3.7/dist-packages (from webdataset==0.2.20) (0.1.7)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from webdataset==0.2.20) (6.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from webdataset==0.2.20) (1.21.6)
もしくは普通にpipから.
!pip install webdataset
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting webdataset
Downloading webdataset-0.2.20-py3-none-any.whl (49 kB)
[K |████████████████████████████████| 49 kB 3.0 MB/s
[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from webdataset) (1.21.6)
Collecting braceexpand
Downloading braceexpand-0.1.7-py2.py3-none-any.whl (5.9 kB)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from webdataset) (6.0)
Installing collected packages: braceexpand, webdataset
Successfully installed braceexpand-0.1.7 webdataset-0.2.20
webdatasetのインポートではwdsと略します.
import webdataset as wds
3. データの準備
では例題用の学習データを用意します.小さいtoy exampleがよいので,ここではImageNetのサブセットであるimagenetteの最小画像サイズを利用します.
import requests
from pathlib import Path
import tarfile
def download(url):
filename = Path(url).name
with open(filename, 'wb') as save_file:
save_file.write(requests.get(url).content)
return filename
def tar_xzvf(filename, path='.'):
with tarfile.open(filename, 'r:gz') as tar_file:
tar_file.extractall(path=path)
def tar_tvf(filename):
with tarfile.open(filename, 'r:') as tar_file:
for i in range(10):
info = tar_file.next()
print(f'{info.gname}/{info.uname} {info.size:8d} {info.name}')
filename = download('https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz')
tar_xzvf(filename)
これで./imagenette2-160/train/
以下に学習画像が保存されました.
4. pytorch標準のデータローダーの復習
では通常のpytorchのデータローダーの典型例を書いてみましょう.
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
単純にImageFolderを使った場合,root
に画像フォルダを指定すれば,それ以下のサブフォルダはカテゴリとみなされてラベルは自動的に生成されます.
transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
dataset = ImageFolder(
'./imagenette2-160/train/',
transform=transform)
for i, (img, label) in enumerate(dataset):
print(img.shape, label)
if i > 5:
break
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
データローダーを使った場合.こちらが普通ですね.
data_loader = DataLoader(
dataset,
batch_size=4,
shuffle=True,
)
for i, (img, label) in enumerate(data_loader):
print(img.shape, label)
if i > 5:
break
torch.Size([4, 3, 224, 224]) tensor([4, 7, 8, 6])
torch.Size([4, 3, 224, 224]) tensor([6, 5, 4, 0])
torch.Size([4, 3, 224, 224]) tensor([7, 9, 0, 5])
torch.Size([4, 3, 224, 224]) tensor([0, 9, 7, 3])
torch.Size([4, 3, 224, 224]) tensor([2, 8, 6, 9])
torch.Size([4, 3, 224, 224]) tensor([2, 2, 1, 0])
torch.Size([4, 3, 224, 224]) tensor([7, 3, 2, 7])
5. webdatasetのshardの作り方
では同じものをwebdatasetでやってみます.
まずはデータの準備です.上記のような構造のフォルダに保存されている画像を読み込んで,複数のtarファイルにまとめていきます.ちなみにそれぞれのそれぞれのtarファイルをshard(断片)と呼びます.
shardを作ったらローダーを作りますが,まずshardをどのように作成するのかを説明しましょう.準備として,以下のようにjpeg画像ファイルのリストと,カテゴリ名のリストを作成します.
import os
import random
dataset_root = './imagenette2-160/train/'
file_paths = [
path for path in Path(dataset_root).glob('*/*')
if not path.is_dir()
and path.name.endswith((
'.JPEG', '.jpeg', '.jpg',
))
]
random.shuffle(file_paths)
print(file_paths[:2])
category_list = sorted([
path.name for path in Path(dataset_root).glob('*') if path.is_dir()
])
category_index = {
category_name: i
for i, category_name in enumerate(category_list)
}
category_index
[PosixPath('imagenette2-160/train/n02102040/n02102040_3699.JPEG'), PosixPath('imagenette2-160/train/n03425413/n03425413_12249.JPEG')]
{'n01440764': 0,
'n02102040': 1,
'n02979186': 2,
'n03000684': 3,
'n03028079': 4,
'n03394916': 5,
'n03417042': 6,
'n03425413': 7,
'n03445777': 8,
'n03888257': 9}
後にも説明しますが,shardに書き込む時点でランダムにしておいたほうが都合が良いです.そのためそのためfile_paths
をここでシャッフルしています.
5.1. shard作成方法1
ではshardを書き出します.
まずは,画像ファイルを読み込み,ndarrayにしてから,webdatasetのjpegエンコーダを利用する方法です.
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import json
shard_path = './shards_01'
shard_dir_path = Path(shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / 'shards-%05d.tar')
shard_size = int(50 * 1000**2) # 50MB each
with wds.ShardWriter(
shard_filename,
maxsize=shard_size,
) as sink, tqdm(
file_paths
) as pbar:
for file_path in pbar:
category_name = file_path.parent.name
label = category_index[category_name]
key_str = category_name + '/' + file_path.stem
sink.write({
"__key__": key_str,
"jpg": np.array(Image.open(file_path)),
"cls": label,
})
dataset_size = len(shard_filename)
dataset_size_filename = str(
shard_dir_path / 'dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
"dataset size": dataset_size,
"n_classes": len(category_index),
}, fp)
# writing shards_01/shards-00000.tar 0 0.0 GB 0
0%| | 0/9469 [00:00<?, ?it/s]
# writing shards_01/shards-00001.tar 2363 0.1 GB 2363
# writing shards_01/shards-00002.tar 2375 0.1 GB 4738
# writing shards_01/shards-00003.tar 2425 0.1 GB 7163
5.1.1. shard作成の詳しい説明
これ以降の例は上記のコードをベースにしているので,ここで詳しく順を追って説明します.
shard_path = './shards_01'
shard_dir_path = Path(shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / 'shards-%05d.tar')
shardを保存するディレクトリとファイル名を設定します.pathlib
を使っているのでので見慣れないかもしれませんが,ディレクトリを作成して,ファイル名文字列を生成しているだけです.shardファイルは複数生成されるので,文字列には出力指定子を使います.
shard_size = int(50 * 1000**2) # 50MB each
with wds.ShardWriter(
shard_filename,
maxsize=shard_size,
) as sink, tqdm(
file_paths
) as pbar:
with
文でShardWriter
を生成します.
- 1つ目の引数は出力指定子付きの保存ファイル名パターンを指定子ます.
- 2つ目の
maxsize
には,一つのshard(tarファイル)のサイズを指定します(バイト単位).たとえばshardにまとめるデータが130MBあった場合,maxsizeを50MBにすると,50MB, 50MB, 30MBという3つのtarファイルが生成されます.
ファイル保存の進捗を表示したいので,tqdm
も使います.(同時にwith
で指定しています)
for file_path in pbar:
category_name = file_path.parent.name
label = category_index[category_name]
key_str = category_name + '/' + file_path.stem
ここでは文字列操作をしています..
-
category_name
には,フルパスのファイル名から,カテゴリ名であるディレクトリ名を抜き出して保存しています. -
label
には,カテゴリ番号のintを保存します.. -
key_str
には,tarファイル中ののbasenameとなる文字列を指定子ます.file_path.stem
だけでも良いのですが,ここではカテゴリ名も付け加えておきます.
sink.write({
"__key__": key_str,
"jpg": np.array(Image.open(file_path)),
"cls": label,
})
この部分が実際にtarファイルに追記する部分です.write
にはdict型を与えます.dictには"__key__"
というキーが必須です.この値が,tarファイル中のbasenameになります.
それ以外のキー(指定の仕方については後述します)は,tarファイル中の拡張子として使用されます.上記の例で
category_name == "n01440764"
file_path.stem == "ILSVRC2012_val_00000293"
key_str == "n01440764/ILSVRC2012_val_00000293"
となっているとすると,
n01440764/ILSVRC2012_val_00000293.jpg
n01440764/ILSVRC2012_val_00000293.cls
という2つのファイルがtarファイル中に書き込まれることになります.
これを,学習サンプルの文だけ反復します.
tarファイルにどんどんとサンプルを追加していきますが,maxsize
を超えると,そのtarファイルはクローズされて,次からはカウンタが一つ増えた新しいtarファイルに追加していきます.
ここまでがshardであるtarファイルの作成方法です.
dataset_size = len(shard_filename)
dataset_size_filename = str(
shard_dir_path / 'dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
"dataset size": dataset_size,
"n_classes": len(category_index),
}, fp)
最後に,データセット中の学習サンプル数(dataset size)を別途保存しておきます.形式は何でも良いのですが,ここでは例としてjsonで保存しています.
webdatasetのshardをロードするときには,サンプル数を明示的に与える必要があります.もしImageFolder
などを使うディレクトリ中のファイルをカウントすればサンプル数が分かるので,与える必要はありません.しかしwebdatasetの場合には,学習データは複数のshardに分かれていて,さらに各tarファイルの中にはいくつのサンプルが含まれているのかは,すべて読み込んでからでなければ分かりません(これではtarから読み込みながらその都度学習,ということができません).
5.1.2. shardの書き込みに使うキー
write
に与えるdictのキーにはいくつかのルールがあります.
-
__key__
:必須.basenameとして使用されます(前述の通り)
それ以外のキーは,以下のように解釈されます.
5.1.2.1. JPEG, PNGなど画像の拡張子
imgがPIL.Image
やnp.ndarray
の場合には,キーに指定したフォーマットの画像にエンコードされます.例えば
sink.write({
"__key__": key_str,
"jpg": img,
})
ならjpeg画像ファイルとしてtarに書き込まれますし,
sink.write({
"__key__": key_str,
"png": img,
})
ならpng画像ファイルになります.つまりキーは,tarに書き込まれるファイルの拡張子になるだけでなく,エンコーダへのヒントとしても使われます.
5.1.2.2. その他のキー
-
cls
: webdatasetのサンプルコードでよく出てくるのは"cls"
ですが,この場合には値は単に文字列としてファイルに保存されます. -
pickle
: このキーで指定されたオブジェクトはpickleファイルとして書き込まれます.
ドキュメントには説明が見当たらないですが,コードを見ればどんなキーが使えるのかがわかります.
def make_handlers():
"""Create a list of handlers for encoding data."""
handlers = {}
add_handlers(
handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii")
)
add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
add_handlers(handlers, "pyd pickle", pickle.dumps)
add_handlers(handlers, "pth", torch_dumps)
add_handlers(handlers, "npy", numpy_dumps)
add_handlers(handlers, "npz", numpy_npz_dumps)
add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
add_handlers(handlers, "mp msgpack msg", mp_dumps)
add_handlers(handlers, "cbor", cbor_dumps)
add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
return handlers
つまりキーにcls
とかいてもid
と書いても同じ扱いですね.
5.1.2.3. 文字列+拡張子のキー
キーをjpg
とすると,basename + ".jpg"
というjpegファイルがtarに書き込まれますが,キーをimage.jpg
とすると,basename + "image.jpg"
というjpegファイルがtarに書き込まれます.
つまりキー文字列のsuffix(.
以降)がフォーマットのヒントとして使われます.一つの学習サンプルが複数のjpeg画像を持つような場合には,例えばimage1.jpg
とimage2.jpg
という2つのキーで2枚の画像を保存します.
なお.
が複数ある場合には,最初の.
以前がstemとして(つまりキーとして),最後の.
以降が拡張子として扱われるようです.あとの例でも出てきますが,foobar.img.jpg
やfoobar.label.jpg
などのファイルをtarに書き込むと,デコード時には,同じfoobar
をキーとして持つ2つの画像として扱われます.
5.1.2.4. 用意されていないフォーマットの場合
デフォルトで用意されているのは上記のものだけなので,これら以外のデータを書き込む場合には,pickleオブジェクトにします.
もしくは,バイト列として保存します.どんなキーを使ったとしても,書き込む値の方がbyteオブジェクトであれば,そのままバイト列として保存されるようです.コードを見ると以下のようになっています.
if isinstance(data, bytes):
return data
これを使うと,書き込みたいファイルが何であれ(例えば動画.mp4や音声.mp3などでも),byteオブジェクトとして読み込み,そのまま書き込むことができます.
with open(filename, "rb") as f:
movie_bytes = f.read()
sink.write({
"__key__": key_str,
"mp4": movie_bytes,
})
なおwebdatasetのwrite
の書き込みにはTarFile.addを使っているため,上記のmovie_bytes
には一旦ファイル全体を読み込む必要があります.数百GBのファイルを扱うような場合にはメモリ容量に注意してください.
更に別の方法として,encoderを自作してコンストラクタに与えるという方法もあります.説明は省略するので頑張って自作してください.
class TarWriter:
...
def __init__(
...
encoder: Union[None, bool, Callable] = True,
...
):
これでshardができました.tar
コマンドで中身を見てみましょう.
tar_tvf('shards_01/shards-00000.tar')
bigdata/bigdata 1 n02102040/n02102040_3699.cls
bigdata/bigdata 17933 n02102040/n02102040_3699.jpg
bigdata/bigdata 1 n03425413/n03425413_12249.cls
bigdata/bigdata 24201 n03425413/n03425413_12249.jpg
bigdata/bigdata 1 n03888257/n03888257_6468.cls
bigdata/bigdata 18911 n03888257/n03888257_6468.jpg
bigdata/bigdata 1 n01440764/n01440764_769.cls
bigdata/bigdata 26733 n01440764/n01440764_769.jpg
bigdata/bigdata 1 n02979186/n02979186_14197.cls
bigdata/bigdata 22348 n02979186/n02979186_14197.jpg
確認のために画像を1枚読み込んで表示してみます.
import matplotlib.pyplot as plt
%matplotlib inline
shard_filename = 'shards_01/shards-00000.tar'
with tarfile.open(shard_filename, 'r') as tar_file:
filename = tar_file.next().name
key = filename.split('.')[0] # stem
img_buffer = tar_file.extractfile(key + '.jpg')
img = np.array(Image.open(img_buffer))
cls_buffer = tar_file.extractfile(key + '.cls')
label = int(cls_buffer.read())
plt.imshow(img)
print('label:', label)
label: 1
上記の読み込みは,webdatasetのデコーダでは自動的に行われる処理が多いので,実際に使用する際には必要はありません.
-
tar_file.extractfile
で読み込むとio.BufferedReader
が返ってきます -
PIL.Image.open()
ならそれを直接指定して読み込みます -
cls
のほうは,read()
で読んだbytesをintに変換しています
5.2. shard作成方法2
次のshard書き出し方法は,画像ファイルをそのまま何もせずににバイト列のままshardに書き込む方法です.jpegをデコード・エンコードしないので,shard作成が速い・ファイル容量はそのまま,です.
from tqdm.auto import tqdm
import io
shard_path = './shards_02'
shard_dir_path = Path(shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / 'shards-%05d.tar')
print('shards are saved as', shard_filename)
shard_size = int(50 * 1000**2) # 50MB each
with wds.ShardWriter(
shard_filename,
maxsize=shard_size,
) as sink, tqdm(
file_paths
) as pbar:
for file_path in pbar:
category_name = file_path.parent.name
label = category_index[category_name]
key_str = category_name + '/' + file_path.stem
with open(file_path, 'rb') as raw_bytes:
buffer = raw_bytes.read()
sink.write({
"__key__": key_str,
"jpg": buffer,
"cls": label,
})
dataset_size = len(shard_filename)
dataset_size_filename = str(
shard_dir_path / 'dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
"dataset size": dataset_size,
"n_classes": len(category_index),
}, fp)
shards are saved as shards_02/shards-%05d.tar
# writing shards_02/shards-00000.tar 0 0.0 GB 0
0%| | 0/9469 [00:00<?, ?it/s]
# writing shards_02/shards-00001.tar 6267 0.1 GB 6267
tar_tvf('shards_02/shards-00000.tar')
bigdata/bigdata 1 n02102040/n02102040_3699.cls
bigdata/bigdata 6814 n02102040/n02102040_3699.jpg
bigdata/bigdata 1 n03425413/n03425413_12249.cls
bigdata/bigdata 9479 n03425413/n03425413_12249.jpg
bigdata/bigdata 1 n03888257/n03888257_6468.cls
bigdata/bigdata 7363 n03888257/n03888257_6468.jpg
bigdata/bigdata 1 n01440764/n01440764_769.cls
bigdata/bigdata 10216 n01440764/n01440764_769.jpg
bigdata/bigdata 1 n02979186/n02979186_14197.cls
bigdata/bigdata 8663 n02979186/n02979186_14197.jpg
import matplotlib.pyplot as plt
%matplotlib inline
shard_filename = 'shards_02/shards-00000.tar'
with tarfile.open(shard_filename, 'r') as tar_file:
filename = tar_file.next().name
key = filename.split('.')[0] # stem
img_buffer = tar_file.extractfile(key + '.jpg')
img = np.array(Image.open(img_buffer))
cls_buffer = tar_file.extractfile(key + '.cls')
label = int(cls_buffer.read())
plt.imshow(img)
print('label:', label)
label: 1
5.3. shard作成方法3
次のshard書き出し方法は,いろいろな情報を保存する方法です.サンプルコードの多くは(img, cls)ペアの説明ですが,単にtarに保存するだけなので,様々なものをshard作成時に入れることができます.
from tqdm.auto import tqdm
from PIL import Image
import io
import json
shard_path = './shards_03'
shard_dir_path = Path(shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / 'shards-%05d.tar')
print('shards are saved as', shard_filename)
shard_size = int(50 * 1000**2) # 50MB each
with wds.ShardWriter(
shard_filename,
maxsize=shard_size,
) as sink, tqdm(
file_paths
) as pbar:
for file_path in pbar:
category_name = file_path.parent.name
label = category_index[category_name]
key_str = category_name + '/' + file_path.stem
with open(file_path, 'rb') as raw_bytes:
buffer = raw_bytes.read()
img = Image.open(file_path)
path = Path(file_path)
info_dic = {
'label': label,
'width': img.width,
'height': img.height,
'info': img.info,
'format': img.format,
'format_description': img.format_description,
'category name': path.parent.name,
'ext': path.suffix,
'file id': path.stem,
'filesize': path.stat().st_size, # in bytes
}
sink.write({
"__key__": key_str,
"jpg": buffer,
"json": json.dumps(info_dic),
})
dataset_size = len(shard_filename)
dataset_size_filename = str(
shard_dir_path / 'dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
"dataset size": dataset_size,
"n_classes": len(category_index),
}, fp)
shards are saved as shards_03/shards-%05d.tar
# writing shards_03/shards-00000.tar 0 0.0 GB 0
0%| | 0/9469 [00:00<?, ?it/s]
# writing shards_03/shards-00001.tar 6050 0.1 GB 6050
tar_tvf('shards_03/shards-00000.tar')
bigdata/bigdata 6814 n02102040/n02102040_3699.jpg
bigdata/bigdata 281 n02102040/n02102040_3699.json
bigdata/bigdata 9479 n03425413/n03425413_12249.jpg
bigdata/bigdata 282 n03425413/n03425413_12249.json
bigdata/bigdata 7363 n03888257/n03888257_6468.jpg
bigdata/bigdata 281 n03888257/n03888257_6468.json
bigdata/bigdata 10216 n01440764/n01440764_769.jpg
bigdata/bigdata 281 n01440764/n01440764_769.json
bigdata/bigdata 8663 n02979186/n02979186_14197.jpg
bigdata/bigdata 282 n02979186/n02979186_14197.json
import matplotlib.pyplot as plt
%matplotlib inline
shard_filename = 'shards_03/shards-00000.tar'
with tarfile.open(shard_filename, 'r') as tar_file:
filename = tar_file.next().name
key = filename.split('.')[0] # stem
img_buffer = tar_file.extractfile(key + '.jpg')
img = np.array(Image.open(img_buffer))
info_json = tar_file.extractfile(key + '.json')
info_dic = json.loads(info_json.read())
plt.imshow(img)
print(info_dic)
print('label:', info_dic['label'])
{'label': 1, 'width': 213, 'height': 160, 'info': {'jfif': 257, 'jfif_version': [1, 1], 'jfif_unit': 0, 'jfif_density': [1, 1]}, 'format': 'JPEG', 'format_description': 'JPEG (ISO 10918)', 'category name': 'n02102040', 'ext': '.JPEG', 'file id': 'n02102040_3699', 'filesize': 6814}
label: 1
5.4. shard作成方法4
上記の方法はjpeg画像やjsonなど,ファイルにしやすいデータ構造をファイルとして保存していましたが,任意のオブジェクトを保存することもできます.ただしpickleを使うのでpickle化できることが必要ですが.
以下ではいろいろなデータ構造を,一つのpickleファイルとして保存する例です.
from tqdm.auto import tqdm
from PIL import Image
import io
shard_path = './shards_04'
shard_dir_path = Path(shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / 'shards-%05d.tar')
print('shards are saved as', shard_filename)
shard_size = int(50 * 1000**2) # 50MB each
with wds.ShardWriter(
shard_filename,
maxsize=shard_size,
) as sink, tqdm(
file_paths
) as pbar:
for file_path in pbar:
category_name = file_path.parent.name
label = category_index[category_name]
key_str = category_name + '/' + file_path.stem
with open(file_path, 'rb') as raw_bytes:
buffer = raw_bytes.read()
img = Image.open(file_path)
path = Path(file_path)
info_dic = {
'label': label,
'width': img.width,
'height': img.height,
'info': img.info,
'format': img.format,
'format_description': img.format_description,
'category name': path.parent.name,
'ext': path.suffix,
'file id': path.stem,
'filesize': path.stat().st_size, # in bytes
}
sink.write({
"__key__": key_str,
"pickle": (
buffer,
img,
path,
info_dic
)
})
dataset_size = len(shard_filename)
dataset_size_filename = str(
shard_dir_path / 'dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
"dataset size": dataset_size,
"n_classes": len(category_index),
}, fp)
shards are saved as shards_04/shards-%05d.tar
# writing shards_04/shards-00000.tar 0 0.0 GB 0
0%| | 0/9469 [00:00<?, ?it/s]
# writing shards_04/shards-00001.tar 441 0.1 GB 441
# writing shards_04/shards-00002.tar 440 0.1 GB 881
# writing shards_04/shards-00003.tar 442 0.1 GB 1323
# writing shards_04/shards-00004.tar 438 0.1 GB 1761
# writing shards_04/shards-00005.tar 449 0.1 GB 2210
# writing shards_04/shards-00006.tar 438 0.1 GB 2648
# writing shards_04/shards-00007.tar 440 0.1 GB 3088
# writing shards_04/shards-00008.tar 442 0.1 GB 3530
# writing shards_04/shards-00009.tar 438 0.1 GB 3968
# writing shards_04/shards-00010.tar 438 0.1 GB 4406
# writing shards_04/shards-00011.tar 447 0.1 GB 4853
# writing shards_04/shards-00012.tar 440 0.1 GB 5293
# writing shards_04/shards-00013.tar 450 0.1 GB 5743
# writing shards_04/shards-00014.tar 446 0.1 GB 6189
# writing shards_04/shards-00015.tar 444 0.1 GB 6633
# writing shards_04/shards-00016.tar 442 0.1 GB 7075
# writing shards_04/shards-00017.tar 445 0.1 GB 7520
# writing shards_04/shards-00018.tar 447 0.1 GB 7967
# writing shards_04/shards-00019.tar 443 0.1 GB 8410
# writing shards_04/shards-00020.tar 437 0.1 GB 8847
# writing shards_04/shards-00021.tar 452 0.1 GB 9299
tar_tvf('shards_04/shards-00000.tar')
bigdata/bigdata 109563 n02102040/n02102040_3699.pickle
bigdata/bigdata 112230 n03425413/n03425413_12249.pickle
bigdata/bigdata 122112 n03888257/n03888257_6468.pickle
bigdata/bigdata 112963 n01440764/n01440764_769.pickle
bigdata/bigdata 111414 n02979186/n02979186_14197.pickle
bigdata/bigdata 108773 n02102040/n02102040_2493.pickle
bigdata/bigdata 124458 n01440764/n01440764_2144.pickle
bigdata/bigdata 111550 n03425413/n03425413_13973.pickle
bigdata/bigdata 111751 n03028079/n03028079_77417.pickle
bigdata/bigdata 113628 n03394916/n03394916_37186.pickle
import matplotlib.pyplot as plt
%matplotlib inline
import pickle
shard_filename = 'shards_04/shards-00000.tar'
with tarfile.open(shard_filename, 'r') as tar_file:
filename = tar_file.next().name
key = filename.split('.')[0] # stem
buffer = tar_file.extractfile(key + '.pickle')
buffer, img, path, info_dic = pickle.loads(buffer.read())
print(type(buffer), type(img), type(path), type(info_dic))
plt.imshow(img)
print(info_dic)
print('label', info_dic['label'])
<class 'bytes'> <class 'PIL.JpegImagePlugin.JpegImageFile'> <class 'pathlib.PosixPath'> <class 'dict'>
{'label': 1, 'width': 213, 'height': 160, 'info': {'jfif': 257, 'jfif_version': (1, 1), 'jfif_unit': 0, 'jfif_density': (1, 1)}, 'format': 'JPEG', 'format_description': 'JPEG (ISO 10918)', 'category name': 'n02102040', 'ext': '.JPEG', 'file id': 'n02102040_3699', 'filesize': 6814}
label 1
5.5. shard作成方法5
上記の方法は一つのpickleファイルとしていろいろな情報を保存した例ですが,別々のpickleファイルに保存する事もできます.
from tqdm.auto import tqdm
from PIL import Image
import io
shard_path = './shards_05'
shard_dir_path = Path(shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / 'shards-%05d.tar')
print('shards are saved as', shard_filename)
shard_size = int(50 * 1000**2) # 50MB each
with wds.ShardWriter(
shard_filename,
maxsize=shard_size,
) as sink, tqdm(
file_paths
) as pbar:
for file_path in pbar:
category_name = file_path.parent.name
label = category_index[category_name]
key_str = category_name + '/' + file_path.stem
with open(file_path, 'rb') as raw_bytes:
buffer = raw_bytes.read()
img = Image.open(file_path)
path = Path(file_path)
info_dic = {
'label': label,
'width': img.width,
'height': img.height,
'info': img.info,
'format': img.format,
'format_description': img.format_description,
'category name': path.parent.name,
'ext': path.suffix,
'file id': path.stem,
'filesize': path.stat().st_size, # in bytes
}
sink.write({
"__key__": key_str,
"jpg": buffer,
"img.pickle": img,
"path.pickle": path,
"json": json.dumps(info_dic),
})
dataset_size = len(shard_filename)
dataset_size_filename = str(
shard_dir_path / 'dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
"dataset size": dataset_size,
"n_classes": len(category_index),
}, fp)
shards are saved as shards_05/shards-%05d.tar
# writing shards_05/shards-00000.tar 0 0.0 GB 0
0%| | 0/9469 [00:00<?, ?it/s]
# writing shards_05/shards-00001.tar 440 0.1 GB 440
# writing shards_05/shards-00002.tar 440 0.1 GB 880
# writing shards_05/shards-00003.tar 442 0.1 GB 1322
# writing shards_05/shards-00004.tar 438 0.1 GB 1760
# writing shards_05/shards-00005.tar 449 0.1 GB 2209
# writing shards_05/shards-00006.tar 438 0.1 GB 2647
# writing shards_05/shards-00007.tar 440 0.1 GB 3087
# writing shards_05/shards-00008.tar 441 0.1 GB 3528
# writing shards_05/shards-00009.tar 438 0.1 GB 3966
# writing shards_05/shards-00010.tar 438 0.1 GB 4404
# writing shards_05/shards-00011.tar 447 0.1 GB 4851
# writing shards_05/shards-00012.tar 440 0.1 GB 5291
# writing shards_05/shards-00013.tar 449 0.1 GB 5740
# writing shards_05/shards-00014.tar 445 0.1 GB 6185
# writing shards_05/shards-00015.tar 444 0.1 GB 6629
# writing shards_05/shards-00016.tar 442 0.1 GB 7071
# writing shards_05/shards-00017.tar 444 0.1 GB 7515
# writing shards_05/shards-00018.tar 448 0.1 GB 7963
# writing shards_05/shards-00019.tar 443 0.1 GB 8406
# writing shards_05/shards-00020.tar 437 0.1 GB 8843
# writing shards_05/shards-00021.tar 451 0.1 GB 9294
tar_tvf('shards_05/shards-00000.tar')
bigdata/bigdata 102404 n02102040/n02102040_3699.img.pickle
bigdata/bigdata 6814 n02102040/n02102040_3699.jpg
bigdata/bigdata 281 n02102040/n02102040_3699.json
bigdata/bigdata 107 n02102040/n02102040_3699.path.pickle
bigdata/bigdata 102404 n03425413/n03425413_12249.img.pickle
bigdata/bigdata 9479 n03425413/n03425413_12249.jpg
bigdata/bigdata 282 n03425413/n03425413_12249.json
bigdata/bigdata 108 n03425413/n03425413_12249.path.pickle
bigdata/bigdata 114404 n03888257/n03888257_6468.img.pickle
bigdata/bigdata 7363 n03888257/n03888257_6468.jpg
import matplotlib.pyplot as plt
%matplotlib inline
import pickle
shard_filename = 'shards_05/shards-00000.tar'
with tarfile.open(shard_filename, 'r') as tar_file:
filename = tar_file.next().name
key = filename.split('.')[0] # stem
path = pickle.loads(tar_file.extractfile(key + '.path.pickle').read())
img_pil = pickle.loads(tar_file.extractfile(key + '.img.pickle').read())
img = np.array(Image.open(tar_file.extractfile(key + '.jpg')))
info_dic = json.loads(tar_file.extractfile(key + '.json').read())
plt.imshow(img)
plt.show()
plt.imshow(img_pil)
plt.show()
print(info_dic)
print('label', info_dic['label'])
{'label': 1, 'width': 213, 'height': 160, 'info': {'jfif': 257, 'jfif_version': [1, 1], 'jfif_unit': 0, 'jfif_density': [1, 1]}, 'format': 'JPEG', 'format_description': 'JPEG (ISO 10918)', 'category name': 'n02102040', 'ext': '.JPEG', 'file id': 'n02102040_3699', 'filesize': 6814}
label 1
5.6. shard作成方法6
ではもっと実用的な例を紹介します.semantic segmentation用のcamvidを使って,実画像とラベル画像のペアを1サンプルとして保存して見ます.
filename = download('https://s3.amazonaws.com/fast-ai-imagelocal/camvid.tgz')
tar_xzvf(filename)
展開したディレクトリ構造は以下のようになっています.なっています.images/
とlabels/
に,対応するファイルがペアになって保存されています.
camvid
├── images
│ ├── 0001TP_006690.png
│ ├── 0001TP_006720.png
│ ...
└── labels
├── 0001TP_006690_P.png
├── 0001TP_006720_P.png
...
ではファイルリストをそれぞれのディレクトリから取得し,ソートします.この場合にはこれで対応が取れます.
import os
dataset_root_images = './camvid/images/'
dataset_root_labels = './camvid/labels/'
image_paths = sorted([
path for path in Path(dataset_root_images).glob('*')
if not path.is_dir()
and path.name.endswith((
'.PNG', '.png',
))
])
label_paths = sorted([
path for path in Path(dataset_root_labels).glob('*')
if not path.is_dir()
and path.name.endswith((
'.PNG', '.png',
))
])
print(image_paths[:2])
print(label_paths[:2])
[PosixPath('camvid/images/0001TP_006690.png'), PosixPath('camvid/images/0001TP_006720.png')]
[PosixPath('camvid/labels/0001TP_006690_P.png'), PosixPath('camvid/labels/0001TP_006720_P.png')]
zipでペアにして,それからシャッフルします.
path_pair = list(zip(image_paths, label_paths))
random.shuffle(path_pair)
print(path_pair[:2])
[(PosixPath('camvid/images/0006R0_f02070.png'), PosixPath('camvid/labels/0006R0_f02070_P.png')), (PosixPath('camvid/images/0016E5_07050.png'), PosixPath('camvid/labels/0016E5_07050_P.png'))]
ではこれまでと同様にshardに書き込みます.実画像もラベル画像もpngなのですが,実画像はjpegのほうが小さくなるので,jpegでも保存しておきます(この場合には実画像のpngを保存する必要はないのですが,例題としてpngも一緒に保存しています).
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import io
shard_path = './shards_06_camvid'
shard_dir_path = Path(shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / 'shards-%05d.tar')
print('shards are saved as', shard_filename)
shard_size = int(150 * 1000**2) # 150MB each
with wds.ShardWriter(
shard_filename,
maxsize=shard_size,
) as sink, tqdm(
path_pair,
total=len(image_paths)
) as pbar:
for pair in pbar:
img_path, label_path = pair
with open(img_path, 'rb') as raw_bytes:
img_buffer = raw_bytes.read()
with open(label_path, 'rb') as raw_bytes:
label_buffer = raw_bytes.read()
sink.write({
"__key__": img_path.stem,
"img.png": img_buffer,
"img.jpg": np.array(Image.open(img_path)),
"label.png": label_buffer,
"pair.pickle": (img_buffer, label_buffer)
})
dataset_size = len(shard_filename)
dataset_size_filename = str(
shard_dir_path / 'dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
"dataset size": dataset_size,
"n_classes": len(category_index),
}, fp)
shards are saved as shards_06_camvid/shards-%05d.tar
# writing shards_06_camvid/shards-00000.tar 0 0.0 GB 0
0%| | 0/701 [00:00<?, ?it/s]
# writing shards_06_camvid/shards-00001.tar 72 0.2 GB 72
# writing shards_06_camvid/shards-00002.tar 72 0.2 GB 144
# writing shards_06_camvid/shards-00003.tar 72 0.2 GB 216
# writing shards_06_camvid/shards-00004.tar 72 0.2 GB 288
# writing shards_06_camvid/shards-00005.tar 72 0.2 GB 360
# writing shards_06_camvid/shards-00006.tar 71 0.2 GB 431
# writing shards_06_camvid/shards-00007.tar 71 0.2 GB 502
# writing shards_06_camvid/shards-00008.tar 70 0.2 GB 572
# writing shards_06_camvid/shards-00009.tar 72 0.2 GB 644
tar_tvf('shards_06_camvid/shards-00000.tar')
bigdata/bigdata 457593 0006R0_f02070.img.jpg
bigdata/bigdata 903221 0006R0_f02070.img.png
bigdata/bigdata 28395 0006R0_f02070.label.png
bigdata/bigdata 931636 0006R0_f02070.pair.pickle
bigdata/bigdata 413613 0016E5_07050.img.jpg
bigdata/bigdata 867848 0016E5_07050.img.png
bigdata/bigdata 16780 0016E5_07050.label.png
bigdata/bigdata 884648 0016E5_07050.pair.pickle
bigdata/bigdata 333600 0001TP_009450.img.jpg
bigdata/bigdata 724877 0001TP_009450.img.png
import matplotlib.pyplot as plt
%matplotlib inline
import pickle
shard_filename = 'shards_06_camvid/shards-00000.tar'
with tarfile.open(shard_filename, 'r') as tar_file:
filename = tar_file.next().name
key = filename.split('.')[0] # stem
img_png = Image.open(tar_file.extractfile(key + '.img.png'))
img_jpeg = Image.open(tar_file.extractfile(key + '.img.jpg'))
label = Image.open(tar_file.extractfile(key + '.label.png'))
plt.imshow(np.array(img_png))
print(img_png.format, img_png.size)
plt.show()
plt.imshow(np.array(img_jpeg))
print(img_jpeg.format, img_jpeg.size)
plt.show()
plt.imshow(np.array(label))
print(label.format, label.size)
plt.show()
PNG (960, 720)
JPEG (960, 720)
PNG (960, 720)
6. webdatasetのshardの読み込み方1
ではshardをwebdatasetのローダーで読み込みます.
6.1. shardの読み込み:よくあるサンプル
まずはシンプルな読み込み方法です.
import webdataset as wds
from torch.utils.data import DataLoader
import numpy as np
shard_pattern = 'shards_01/shards-{00000..00003}.tar'
dataset = wds.WebDataset(shard_pattern).decode('pil').to_tuple('jpg', 'cls').map_tuple(
lambda pil_img: np.array(pil_img.resize((224, 224))),
lambda label: label,
)
dataloader = DataLoader(
dataset,
batch_size=4)
for img, label in dataloader:
print(img.shape, label)
break
torch.Size([4, 224, 224, 3]) tensor([1, 7, 9, 0])
shard_pattern
は特殊な書き方ですが,ローカルファイルではなくweb上のURLでも同じように扱えるようにこの形式(braceexpand)になっています.
以下のように普通のファイルリストも受け付けます.
上記のコードは1行でdataset
を作っています.このように処理を連結して,処理のパイプラインを作成しています.作成時はまだ処理は行われずにパイプラインが生成されるだけであり,実際にデータを読み込む時点で,パイプラインに設定した各処理が順番に実行されます.
実際に使うことはありませんが,パイプラインの中身であるdataset.pipline
を見ると,順番にリストに処理が追加されている事がわかります.
print('length of the pipeline: ', len(dataset.pipeline))
print()
for i, p in enumerate(dataset.pipeline):
print(i, p)
length of the pipeline: 7
0 <webdataset.shardlists.SimpleShardList object at 0x7f1d4d748650>
1 <function single_node_only at 0x7f1d52f2a050>
2 <function split_by_worker at 0x7f1d52f2a0e0>
3 <tarfile_samples () {'handler': <function reraise_exception at 0x7f1d52f8c560>}>
4 <_map (<webdataset.autodecode.Decoder object at 0x7f1d4d748790>,) {'handler': <function reraise_exception at 0x7f1d52f8c560>}>
5 <_to_tuple ('jpg', 'cls') {'handler': <function reraise_exception at 0x7f1d52f8c560>}>
6 <_map_tuple (<function <lambda> at 0x7f1dc09ed050>, <function <lambda> at 0x7f1d4d6dcb00>) {'handler': <function reraise_exception at 0x7f1d52f8c560>}>
6.1.1. 改行する方法
上記のコードは1行でdataset
を作っていますが,以下のように書いても同じことです.1行が長くなる場合には,改行するか,dataset
に代入を繰り返してもよいです.
import webdataset as wds
from torch.utils.data import DataLoader
import numpy as np
shards_list = [
'shards_01/shards-00000.tar',
'shards_01/shards-00001.tar',
'shards_01/shards-00002.tar',
'shards_01/shards-00003.tar',
]
dataset = (
wds.WebDataset(shards_list)
.decode('pil')
.to_tuple('jpg', 'cls')
.map_tuple(
lambda pil_img: np.array(pil_img.resize((224, 224))),
lambda label: label,
))
dataloader = DataLoader(
dataset,
batch_size=4)
for img, label in dataloader:
print(img.shape, label)
break
torch.Size([4, 224, 224, 3]) tensor([1, 7, 9, 0])
import webdataset as wds
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path
shard_dir = 'shards_01'
shards_list = [
str(path) for path in Path(shard_dir).glob('*.tar')
]
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('pil')
dataset = dataset.to_tuple('jpg', 'cls')
dataset = dataset.map_tuple(
lambda pil_img: np.array(pil_img.resize((224, 224))),
lambda label: label
)
dataloader = DataLoader(
dataset,
batch_size=4)
for img, label in dataloader:
print(img.shape, label)
break
torch.Size([4, 224, 224, 3]) tensor([1, 7, 9, 0])
6.1.2. datasetへ繰り返し代入する方法
以下ではdataset
に代入を繰り返す方法で説明します.そのほうが各ステップで何が起こっているのかがわかりやすいので.
まずは,dataset
をDataloader
で読み込むだけ,他の処理はいっさい省てみます.
dataset = wds.WebDataset(shards_list)
dataloader = DataLoader(
dataset,
batch_size=4)
for sample in dataloader:
print(type(sample))
print(sample)
break
<class 'dict'>
{'__key__': ['n02102040/n02102040_3699', 'n03425413/n03425413_12249', 'n03888257/n03888257_6468', 'n01440764/n01440764_769'], '__url__': ['shards_01/shards-00000.tar', 'shards_01/shards-00000.tar', 'shards_01/shards-00000.tar', 'shards_01/shards-00000.tar'], 'cls': [b'1', b'7', b'9', b'0'], 'jpg': [b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb
つまりバッチでdict型が返ってくることがわかりました.
6.2. dataloaderを使わない例
以下ではDataloader
を使わずに,dataloader
から直接読み込んでみましょう.
import pprint
pp = pprint.PrettyPrinter(indent=4)
dataset = wds.WebDataset(shards_list)
for sample in dataset:
print(type(sample))
sample['jpg'] = sample['jpg'][:30] # 出力が長いのでこれだけ省略
pp.pprint(sample)
break
<class 'dict'>
{ '__key__': 'n02102040/n02102040_3699',
'__url__': 'shards_01/shards-00000.tar',
'cls': b'1',
'jpg': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01'
b'\x00\x01\x00\x00\xff\xdb\x00C\x00\x01\x01\x01\x01\x01'}
これで1サンプル分のdictの内容がわかりました.
-
__key__
にstemの部分(つまりキー)が入っている -
__url__
にはtarファイル名が入っている -
cls
にはバイナリで数値が入っている -
jpg
にはバイナリでJPEGファイルのバイナリ列が入っている
6.2.1. PIL.Imageへのデコード
ではデコードしましょう.
dataset
のパイプラインの途中に
dataset.decode('pil')
を入れると,dictのキーとして画像の拡張子(jpeg, jpg, pngなど)が入っていたら,PIL.Image
型に変換します.
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('pil')
for sample in dataset:
print(type(sample))
pp.pprint(sample)
break
<class 'dict'>
{ '__key__': 'n02102040/n02102040_3699',
'__url__': 'shards_01/shards-00000.tar',
'cls': 1,
'jpg': <PIL.Image.Image image mode=RGB size=213x160 at 0x7F1D4D748A90>}
これで,dict型のキーjpg
の値がPIL.Image
になりました.
6.2.2. ndarrayへのデコード
pil
以外にも,ndarrayやtensorにデコードすることもできます.
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('rgb') # これで(W, H, 3)のndarray
for sample in dataset:
print(type(sample))
print(sample['jpg'].shape)
pp.pprint(sample)
break
<class 'dict'>
(160, 213, 3)
{ '__key__': 'n02102040/n02102040_3699',
'__url__': 'shards_01/shards-00000.tar',
'cls': 1,
'jpg': array([[[0.31764707, 0.2901961 , 0.22745098],
[0.32156864, 0.29411766, 0.23137255],
[0.3254902 , 0.29803923, 0.23529412],
...,
[0.11764706, 0.08627451, 0.07450981],
[0.1254902 , 0.09411765, 0.08235294],
[0.13725491, 0.10588235, 0.09411765]],
[[0.32156864, 0.29411766, 0.23137255],
[0.32156864, 0.29411766, 0.23137255],
[0.3254902 , 0.29803923, 0.23529412],
...,
[0.11764706, 0.08627451, 0.07450981],
[0.11764706, 0.08627451, 0.07450981],
[0.12941177, 0.09803922, 0.08627451]],
[[0.3254902 , 0.29803923, 0.23529412],
[0.32941177, 0.3019608 , 0.23921569],
[0.3372549 , 0.30980393, 0.24705882],
...,
[0.11764706, 0.08627451, 0.07450981],
[0.11764706, 0.08627451, 0.07450981],
[0.11764706, 0.08627451, 0.07450981]],
...,
[[0.30980393, 0.29803923, 0.2627451 ],
[0.20392157, 0.19215687, 0.15686275],
[0.2509804 , 0.23529412, 0.2 ],
...,
[0.54509807, 0.34117648, 0.10588235],
[0.4745098 , 0.27058825, 0.03529412],
[0.5882353 , 0.38431373, 0.14901961]],
[[0.45490196, 0.45490196, 0.41568628],
[0.4 , 0.3882353 , 0.3529412 ],
[0.49411765, 0.47843137, 0.44313726],
...,
[0.6313726 , 0.42745098, 0.19215687],
[0.5411765 , 0.3372549 , 0.10196079],
[0.5254902 , 0.32156864, 0.08627451]],
[[0.49803922, 0.49803922, 0.46666667],
[0.3647059 , 0.3529412 , 0.3254902 ],
[0.40784314, 0.3882353 , 0.3647059 ],
...,
[0.6862745 , 0.48235294, 0.24705882],
[0.627451 , 0.42352942, 0.1882353 ],
[0.49803922, 0.29411766, 0.05882353]]], dtype=float32)}
6.2.3. torch.tensorへのデコード
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('torch') # これで(3, H, W)のtorch.tensorでdtypeはfloat
for sample in dataset:
print(type(sample))
print(sample['jpg'].size())
pp.pprint(sample)
break
<class 'dict'>
torch.Size([3, 160, 213])
{ '__key__': 'n02102040/n02102040_3699',
'__url__': 'shards_01/shards-00000.tar',
'cls': 1,
'jpg': tensor([[[0.3176, 0.3216, 0.3255, ..., 0.1176, 0.1255, 0.1373],
[0.3216, 0.3216, 0.3255, ..., 0.1176, 0.1176, 0.1294],
[0.3255, 0.3294, 0.3373, ..., 0.1176, 0.1176, 0.1176],
...,
[0.3098, 0.2039, 0.2510, ..., 0.5451, 0.4745, 0.5882],
[0.4549, 0.4000, 0.4941, ..., 0.6314, 0.5412, 0.5255],
[0.4980, 0.3647, 0.4078, ..., 0.6863, 0.6275, 0.4980]],
[[0.2902, 0.2941, 0.2980, ..., 0.0863, 0.0941, 0.1059],
[0.2941, 0.2941, 0.2980, ..., 0.0863, 0.0863, 0.0980],
[0.2980, 0.3020, 0.3098, ..., 0.0863, 0.0863, 0.0863],
...,
[0.2980, 0.1922, 0.2353, ..., 0.3412, 0.2706, 0.3843],
[0.4549, 0.3882, 0.4784, ..., 0.4275, 0.3373, 0.3216],
[0.4980, 0.3529, 0.3882, ..., 0.4824, 0.4235, 0.2941]],
[[0.2275, 0.2314, 0.2353, ..., 0.0745, 0.0824, 0.0941],
[0.2314, 0.2314, 0.2353, ..., 0.0745, 0.0745, 0.0863],
[0.2353, 0.2392, 0.2471, ..., 0.0745, 0.0745, 0.0745],
...,
[0.2627, 0.1569, 0.2000, ..., 0.1059, 0.0353, 0.1490],
[0.4157, 0.3529, 0.4431, ..., 0.1922, 0.1020, 0.0863],
[0.4667, 0.3255, 0.3647, ..., 0.2471, 0.1882, 0.0588]]])}
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('torchrgb8') # これで(3, H, W)のtorch.tensorでdtypeはuint8
for sample in dataset:
print(type(sample))
print(sample['jpg'].size())
pp.pprint(sample)
break
<class 'dict'>
torch.Size([3, 160, 213])
{ '__key__': 'n02102040/n02102040_3699',
'__url__': 'shards_01/shards-00000.tar',
'cls': 1,
'jpg': tensor([[[ 81, 82, 83, ..., 30, 32, 35],
[ 82, 82, 83, ..., 30, 30, 33],
[ 83, 84, 86, ..., 30, 30, 30],
...,
[ 79, 52, 64, ..., 139, 121, 150],
[116, 102, 126, ..., 161, 138, 134],
[127, 93, 104, ..., 175, 160, 127]],
[[ 74, 75, 76, ..., 22, 24, 27],
[ 75, 75, 76, ..., 22, 22, 25],
[ 76, 77, 79, ..., 22, 22, 22],
...,
[ 76, 49, 60, ..., 87, 69, 98],
[116, 99, 122, ..., 109, 86, 82],
[127, 90, 99, ..., 123, 108, 75]],
[[ 58, 59, 60, ..., 19, 21, 24],
[ 59, 59, 60, ..., 19, 19, 22],
[ 60, 61, 63, ..., 19, 19, 19],
...,
[ 67, 40, 51, ..., 27, 9, 38],
[106, 90, 113, ..., 49, 26, 22],
[119, 83, 93, ..., 63, 48, 15]]], dtype=torch.uint8)}
6.3. 自動デコードの種類
画像をデコードできる種類はコードを見ると以下のようになっています.
imagespecs = {
"l8": ("numpy", "uint8", "l"),
"rgb8": ("numpy", "uint8", "rgb"),
"rgba8": ("numpy", "uint8", "rgba"),
"l": ("numpy", "float", "l"),
"rgb": ("numpy", "float", "rgb"),
"rgba": ("numpy", "float", "rgba"),
"torchl8": ("torch", "uint8", "l"),
"torchrgb8": ("torch", "uint8", "rgb"),
"torchrgba8": ("torch", "uint8", "rgba"),
"torchl": ("torch", "float", "l"),
"torchrgb": ("torch", "float", "rgb"),
"torch": ("torch", "float", "rgb"),
"torchrgba": ("torch", "float", "rgba"),
"pill": ("pil", None, "l"),
"pil": ("pil", None, "rgb"),
"pilrgb": ("pil", None, "rgb"),
"pilrgba": ("pil", None, "rgba"),
}
上記の例ではcls
キーもバイナリから数値に変換されていました.画像以外のキーは,拡張子に基づいて以下のように自動的にデコードされるようです.
decoders = {
"txt": lambda data: data.decode("utf-8"),
"text": lambda data: data.decode("utf-8"),
"transcript": lambda data: data.decode("utf-8"),
"cls": lambda data: int(data),
"cls2": lambda data: int(data),
"index": lambda data: int(data),
"inx": lambda data: int(data),
"id": lambda data: int(data),
"json": lambda data: json.loads(data),
"jsn": lambda data: json.loads(data),
"pyd": lambda data: pickle.loads(data),
"pickle": lambda data: pickle.loads(data),
"pth": lambda data: torch_loads(data),
"ten": tenbin_loads,
"tb": tenbin_loads,
"mp": msgpack_loads,
"msg": msgpack_loads,
"npy": npy_loads,
"npz": lambda data: np.load(io.BytesIO(data)),
"cbor": cbor_loads,
}
一覧にない拡張子の場合には,何もデコードしないようです.
6.4. dictからタプルへの変換
では値をデコードしたdictを,タプルに変換しましょう.(なぜタプルに変換するかというと,webdatasetのサンプルコードでそうなっているため.dictのままでも以後の処理はできます).
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('pil')
dataset = dataset.to_tuple('jpg', 'cls')
for sample in dataset:
print(type(sample))
print(sample)
break
<class 'tuple'>
(<PIL.Image.Image image mode=RGB size=213x160 at 0x7F1DC74C6E90>, 1)
これでタプルになっているのがわかります.
to_tuple()
で,dictの中から必要なキーを指定して,タプルにします.
to_tuple()
にはキーのリストを与えてもいいですし,webdatasetのサンプルのように空白で区切った1つの文字列でもよいです.
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('pil')
dataset = dataset.to_tuple('jpg cls')
for sample in dataset:
print(type(sample))
print(sample)
break
<class 'tuple'>
(<PIL.Image.Image image mode=RGB size=213x160 at 0x7F1D51F08E10>, 1)
キーには当然__key__
と__url__
も使えます.
また同じキーを複数回指定することもできます(何に使うかはともかく)
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('pil')
dataset = dataset.to_tuple('cls', '__key__', 'jpg', 'cls')
for sample in dataset:
print(type(sample))
print(sample)
break
<class 'tuple'>
(1, 'n02102040/n02102040_3699', <PIL.Image.Image image mode=RGB size=213x160 at 0x7F1D4D6AE490>, 1)
6.5. mapの適用:map_tuple()
そして,map_tuple()
で,タプルのそれぞれの要素に適用する関数を指定します.
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('pil')
dataset = dataset.to_tuple('jpg', 'cls')
dataset = dataset.map_tuple(
lambda pil_img: pil_img.resize((224, 224)),
lambda label: label
)
for sample in dataset:
print(type(sample))
print(sample)
break
<class 'tuple'>
(<PIL.Image.Image image mode=RGB size=224x224 at 0x7F1D4D6AE1D0>, 1)
上の例ではlambda
関数を与えましたが,当然普通の関数も使えます.
def resize_img(x):
return x.resize((224, 224)),
def identity(x):
return x
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('pil')
dataset = dataset.to_tuple('jpg', 'cls')
dataset = dataset.map_tuple(resize_img, identity)
for sample in dataset:
print(type(sample))
print(sample)
break
<class 'tuple'>
((<PIL.Image.Image image mode=RGB size=224x224 at 0x7F1D51F08190>,), 1)
6.5.1. functools.partialを使った例
webdatasetのサンプルコードではfunctools.partial
が使われています.引数をもつ関数を引数に与える場合に便利です.
以下の例はラベルの方にも手を加えてみたものです.
from functools import partial
def resize_img(x, shape):
return x.resize(shape),
def label_transform(x):
return 'label:' + str(x)
resize_func = partial(
resize_img,
shape=(224, 224))
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('pil')
dataset = dataset.to_tuple('jpg', 'cls')
dataset = dataset.map_tuple(resize_func, label_transform)
for sample in dataset:
print(type(sample))
print(sample)
break
<class 'tuple'>
((<PIL.Image.Image image mode=RGB size=224x224 at 0x7F1D4D6D0790>,), 'label:1')
6.5.2. map_dict()の使用例
ドキュメントにはつかい方がのっていませんが,コードを見るとタプルにせずともdictにそのままmap
を適用するmap_dict
が以下のように使えます.
from functools import partial
def resize_img(x, shape):
return x.resize(shape),
def label_transform(x):
return 'label:' + str(x)
resize_func = partial(
resize_img,
shape=(224, 224))
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('pil')
dataset = dataset.map_dict(
jpg=resize_func,
cls=label_transform
)
for sample in dataset:
print(type(sample))
pp.pprint(sample)
break
<class 'dict'>
{ '__key__': 'n02102040/n02102040_3699',
'__url__': 'shards_01/shards-00000.tar',
'cls': 'label:1',
'jpg': (<PIL.Image.Image image mode=RGB size=224x224 at 0x7F1D521A5E50>,)}
dictのまま扱うことになにかメリットがあれば有用でしょう.通常はto_tuple
で必要なキーだけを抜き出してタプルにすればよいです.
6.6. 自作デコーダの作成
6.6.1. デコーダで処理
キーが画像の拡張子やデフォルトで受け付けるもの以外をデコードしたい場合には,自分でデコーダを作成します.そのためには,decode()
にwds.handle_extension
で自作関数を以下のように与えます.
from PIL import Image
import io
import numpy as np
def my_jpg_decoder(x):
img = Image.open(io.BytesIO(x))
img = img.resize((224, 224))
return img
def my_label_decoder(x):
return 'label decode: ' + str(int(x)) # byte->int->str
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode(
wds.handle_extension('jpg', my_jpg_decoder),
wds.handle_extension('cls', my_label_decoder),
)
for sample in dataset:
print(type(sample))
pp.pprint(sample)
break
<class 'dict'>
{ '__key__': 'n02102040/n02102040_3699',
'__url__': 'shards_01/shards-00000.tar',
'cls': 'label decode: 1',
'jpg': <PIL.Image.Image image mode=RGB size=224x224 at 0x7F1D52ADAD90>}
6.6.2. デコーダとmapで処理
ではデコードとmapを両方使ってみましょう.
以下は自作デコーダで画像をPIL.Image
にして,mapのほうでリサイズしてndarrayに変換しています.
from PIL import Image
import io
import numpy as np
def my_jpg_decoder(x):
img = Image.open(io.BytesIO(x))
return img
def my_label_decoder(x):
return 'label decode: ' + str(int(x))
def resize_img(x):
return np.array(x.resize((224, 224)))
def label_transform(x):
return 'label transform: ' + str(x)
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode(
wds.handle_extension('jpg', my_jpg_decoder),
wds.handle_extension('cls', my_label_decoder),
)
dataset = dataset.to_tuple('jpg', 'cls')
dataset = dataset.map_tuple(resize_img, label_transform)
for sample in dataset:
print(type(sample))
print('shape: ', sample[0].shape)
print('max: ', sample[0].max())
print(sample[1])
break
<class 'tuple'>
shape: (224, 224, 3)
max: 253
label transform: label decode: 1
6.6.3. transformの適用
ではもっと実践的な例として,torch.tensorにtransformを適用するという,よくある例を考えてみます.
上記までの例で分かるように,shardから取り出したbyteデータを処理するには,decodeとmapの2箇所があることがわかります.どこで何を処理するのかは設計次第です(decodeでデコードとtransformまでやってしまっても良いし,decodeではデコードだけしてmapでtransformを適用する,など).
from torchvision.io import decode_jpeg
from torchvision import transforms
from torchvision.transforms import (
RandomResizedCrop,
Normalize,
RandomHorizontalFlip,
Compose,
)
from torch import frombuffer
import torch
from functools import partial
def get_transform():
transform_list = [
RandomResizedCrop(224),
RandomHorizontalFlip(),
transforms.Lambda(lambda x: x / 255.),
Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
transform = Compose(transform_list)
return transform
def my_decoder(x, device):
x = decode_jpeg(
frombuffer(x, dtype=torch.uint8),
device=device)
return x
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
decoder_func = partial(
my_decoder,
device=device
)
def transform_img(x, transform):
return transform(x)
transform_func = partial(
transform_img,
transform=get_transform()
)
def identity(x):
return x
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode(
wds.handle_extension('jpg', decoder_func),
)
dataset = dataset.to_tuple('jpg', 'cls')
dataset = dataset.map_tuple(transform_func, identity, identity)
for sample in dataset:
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('device: ', sample[0].device)
print('max: ', sample[0].max())
print(sample[1])
break
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:29: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:1105.)
<class 'tuple'>
<class 'torch.Tensor'>
shape: torch.Size([3, 224, 224])
device: cuda:0
max: tensor(1.4098, device='cuda:0')
1
上記のコードの説明です.
- デコーダではbyte型の
x
をfrombuffer
でtorch.tensorにしてから,torchvisionのdecode_jpeg
を使っています(同時にdeviceに送っています). - mapではtransformを適用しています.すでにtensor型なので,よくある
ToTensor
は不要で,255で割る処理だけが必要です.
6.7. dataloaderでバッチ取得
ではようやくDalaloader
を使ってバッチを取得するところまで来ました.webdatasetオブジェクトの作成は全部一つの関数の中に入れ込んでしまいましょう
from torchvision.io import decode_jpeg
from torchvision import transforms
from torchvision.transforms import (
RandomResizedCrop,
Normalize,
RandomHorizontalFlip,
Compose,
)
from torch import frombuffer
import torch
from functools import partial
from torch.utils.data import DataLoader
def get_dataset(shard_dir, device):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = [
str(path) for path in Path(shard_dir).glob('*.tar')
]
def get_transform():
transform_list = [
RandomResizedCrop(224),
RandomHorizontalFlip(),
transforms.Lambda(lambda x: x / 255.),
Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
transform = Compose(transform_list)
return transform
def my_decoder(x, device):
x = decode_jpeg(
frombuffer(x, dtype=torch.uint8),
device=device)
return x
decoder_func = partial(
my_decoder,
device=device
)
def transform_img(x, transform):
return transform(x)
transform_func = partial(
transform_img,
transform=get_transform()
)
def identity(x):
return x
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode(
wds.handle_extension('jpg', decoder_func),
)
dataset = dataset.to_tuple('jpg', 'cls')
dataset = dataset.map_tuple(transform_func, identity)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(
get_dataset(
shard_dir='shards_01',
device=device),
batch_size=4)
for sample in dataloader:
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('device: ', sample[0].device)
print('max: ', sample[0].max())
print(sample[1])
break
<class 'list'>
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
device: cuda:0
max: tensor(2.6226, device='cuda:0')
tensor([1, 7, 9, 0])
なお,これまではdataset
データセットサイズを指定していませんでした.shardファイル作成時にjsonファイルにサイズを保存しておいたので,上記のコードではそれを読み込んでいます.データセットのサイズを設定するために,パイプラインの最後でdataset.with_length(dataset_size)
で設定しています.
7. wabdatasetのshardの読み込み方2
shard_dir
を変えてもshardの中身が同じなら,変更なしで動きます.
dataloader = DataLoader(
get_dataset(
shard_dir='shards_02',
device=device),
batch_size=4)
for sample in dataloader:
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('device: ', sample[0].device)
print('max: ', sample[0].max())
print(sample[1])
break
<class 'list'>
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
device: cuda:0
max: tensor(2.6400, device='cuda:0')
tensor([1, 7, 9, 0])
8. wabdatasetのshardの読み込み方3
ではshard構造が異なるものを扱ってみましょう.
データがcls
キーではなくjson
になっている場合には,自動的にdictにデコードされます.したがって変更は,to_tuple
でcls
からjson
にするだけです.
from torchvision.io import decode_jpeg
from torchvision import transforms
from torchvision.transforms import (
RandomResizedCrop,
Normalize,
RandomHorizontalFlip,
Compose,
)
from torch import frombuffer
import torch
from functools import partial
from torch.utils.data import DataLoader
import json
def get_dataset(shard_dir, device):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = [
str(path) for path in Path(shard_dir).glob('*.tar')
]
def get_transform():
transform_list = [
RandomResizedCrop(224),
RandomHorizontalFlip(),
transforms.Lambda(lambda x: x / 255.),
Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
transform = Compose(transform_list)
return transform
def my_decoder(x, device):
x = decode_jpeg(
frombuffer(x, dtype=torch.uint8),
device=device)
return x
decoder_func = partial(
my_decoder,
device=device
)
def transform_img(x, transform):
return transform(x)
transform_func = partial(
transform_img,
transform=get_transform()
)
def identity(x):
return x
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode(
wds.handle_extension('jpg', decoder_func),
)
dataset = dataset.to_tuple('jpg', 'json')
dataset = dataset.map_tuple(transform_func, identity)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(
get_dataset(
shard_dir='shards_03',
device=device),
batch_size=4)
for sample in dataloader:
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('device: ', sample[0].device)
print(sample[1])
print('max: ', sample[0].max())
print('label: ', sample[1]['label'])
break
<class 'list'>
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
device: cuda:0
{'label': tensor([1, 7, 9, 0]), 'width': tensor([213, 213, 238, 213]), 'height': tensor([160, 160, 160, 160]), 'info': {'jfif': tensor([257, 257, 257, 257]), 'jfif_version': [tensor([1, 1, 1, 1]), tensor([1, 1, 1, 1])], 'jfif_unit': tensor([0, 0, 0, 0]), 'jfif_density': [tensor([1, 1, 1, 1]), tensor([1, 1, 1, 1])]}, 'format': ['JPEG', 'JPEG', 'JPEG', 'JPEG'], 'format_description': ['JPEG (ISO 10918)', 'JPEG (ISO 10918)', 'JPEG (ISO 10918)', 'JPEG (ISO 10918)'], 'category name': ['n02102040', 'n03425413', 'n03888257', 'n01440764'], 'ext': ['.JPEG', '.JPEG', '.JPEG', '.JPEG'], 'file id': ['n02102040_3699', 'n03425413_12249', 'n03888257_6468', 'n01440764_769'], 'filesize': tensor([ 6814, 9479, 7363, 10216])}
max: tensor(2.1520, device='cuda:0')
label: tensor([1, 7, 9, 0])
9. wabdatasetのshardの読み込み方4
データがpickle
キーだけの場合には,デコードは自動処理にまかせて(decode()
には引数がないですが自動デコードできるものはします),自分でmapで処理します.
from torchvision.io import decode_jpeg
from torchvision import transforms
from torchvision.transforms import (
RandomResizedCrop,
Normalize,
RandomHorizontalFlip,
Compose,
)
from torch import frombuffer
import torch
from functools import partial
from torch.utils.data import DataLoader
import json
def get_dataset(shard_dir, device):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = [
str(path) for path in Path(shard_dir).glob('*.tar')
]
def get_transform():
transform_list = [
RandomResizedCrop(224),
RandomHorizontalFlip(),
transforms.Lambda(lambda x: x / 255.),
Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
transform = Compose(transform_list)
return transform
def transform_pickle(x, transform, device):
buffer, pil_img, path, info_dic = x
tensor_img = decode_jpeg(
frombuffer(buffer, dtype=torch.uint8),
device=device)
tensor_img = transform(tensor_img)
label = info_dic['label']
return tensor_img, label
transform_func = partial(
transform_pickle,
transform=get_transform(),
device=device
)
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode()
dataset = dataset.to_tuple('pickle')
dataset = dataset.map_tuple(transform_func)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(
get_dataset(
shard_dir='shards_04',
device=device),
batch_size=4)
for sample_list in dataloader:
sample = sample_list[0]
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('device: ', sample[0].device)
print('max: ', sample[0].max())
print(sample[1])
break
<class 'list'>
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
device: cuda:0
max: tensor(2.6400, device='cuda:0')
tensor([2, 4, 8, 2])
ちなみにキーがpickle
一つだけが理由なのか,collateがうまく処理されません.そこで,dataloaderで出てくる要素の0を取り出すアドホックな処理を入れています.
もっときれいにやるなら,自作のcollate関数を作って以下のようにすればOKです(どのようなcollate関数を作るのかはいろいろ試すしかない...)
from torchvision.io import decode_jpeg
from torchvision import transforms
from torchvision.transforms import (
RandomResizedCrop,
Normalize,
RandomHorizontalFlip,
Compose,
)
from torch import frombuffer
import torch
from functools import partial
from torch.utils.data import DataLoader
import json
def get_dataset(shard_dir, device):
shards_list = [
str(path) for path in Path(shard_dir).glob('*.tar')
]
def get_transform():
transform_list = [
RandomResizedCrop(224),
RandomHorizontalFlip(),
transforms.Lambda(lambda x: x / 255.),
Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
transform = Compose(transform_list)
return transform
def transform_pickle(x, transform, device):
buffer, pil_img, path, info_dic = x
tensor_img = decode_jpeg(
frombuffer(buffer, dtype=torch.uint8),
device=device)
tensor_img = transform(tensor_img)
label = info_dic['label']
return tensor_img, label
transform_func = partial(
transform_pickle,
transform=get_transform(),
device=device
)
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode()
dataset = dataset.to_tuple('pickle')
dataset = dataset.map_tuple(transform_func)
return dataset
def my_collate(batch):
return torch.utils.data.default_collate([b[0] for b in batch])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(
get_dataset(
shard_dir='shards_04',
device=device),
batch_size=4,
collate_fn=my_collate)
for sample in dataloader:
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('device: ', sample[0].device)
print('max: ', sample[0].max())
print(sample[1])
break
<class 'list'>
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
device: cuda:0
max: tensor(2.6400, device='cuda:0')
tensor([2, 4, 8, 2])
10. wabdatasetのshardの読み込み方5
複数のpickle
キーから画像をデコードする例です.
shard作成時には以下のようにしています.
sink.write({
"__key__": key_str,
"jpg": buffer,
"img.pickle": img,
"path.pickle": path,
"json": json.dumps(info_dic),
})
以下はjpg
キーのbufferからデコードして,jsonからラベルを取り出すだけの処理です.
from torchvision.io import decode_jpeg
from torchvision import transforms
from torchvision.transforms import (
RandomResizedCrop,
Normalize,
RandomHorizontalFlip,
Compose,
)
from torch import frombuffer
import torch
from functools import partial
from torch.utils.data import DataLoader
import json
def get_dataset(shard_dir, device):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = [
str(path) for path in Path(shard_dir).glob('*.tar')
]
def get_transform():
transform_list = [
RandomResizedCrop(224),
RandomHorizontalFlip(),
transforms.Lambda(lambda x: x / 255.),
Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
transform = Compose(transform_list)
return transform
def my_decoder(x, device):
x = decode_jpeg(
frombuffer(x, dtype=torch.uint8),
device=device)
return x
decoder_func = partial(
my_decoder,
device=device
)
def transform_img(x, transform):
return transform(x)
transform_func = partial(
transform_img,
transform=get_transform()
)
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode(
wds.handle_extension('jpg', decoder_func),
)
dataset = dataset.to_tuple('jpg', 'json')
dataset = dataset.map_tuple(
transform_func,
lambda x: x['label']
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(
get_dataset(
shard_dir='shards_05',
device=device),
batch_size=4)
for sample in dataloader:
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('device: ', sample[0].device)
print('max: ', sample[0].max())
print(sample[1])
break
<class 'list'>
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
device: cuda:0
max: tensor(2.6400, device='cuda:0')
tensor([4, 2, 4, 8])
jpg
キーの画像のデコードをおまかせするなら以下のようになります.デコードにtorch
を指定してtensorにすればOKです.ただしこのデコードでfloat32になるので,255で割る処理は不要です.
from torchvision.io import decode_jpeg
from torchvision import transforms
from torchvision.transforms import (
RandomResizedCrop,
Normalize,
RandomHorizontalFlip,
Compose,
)
from torch import frombuffer
import torch
from functools import partial
from torch.utils.data import DataLoader
import json
def get_dataset(shard_dir, device):
shards_list = [
str(path) for path in Path(shard_dir).glob('*.tar')
]
def get_transform():
transform_list = [
RandomResizedCrop(224),
RandomHorizontalFlip(),
# transforms.Lambda(lambda x: x / 255.),
Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
transform = Compose(transform_list)
return transform
def transform_img(x, transform):
return transform(x)
transform_func = partial(
transform_img,
transform=get_transform()
)
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('torch')
dataset = dataset.to_tuple('jpg', 'json')
dataset = dataset.map_tuple(
transform_func,
lambda x: x['label']
)
return dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(
get_dataset(
shard_dir='shards_05',
device=device),
batch_size=4)
for sample in dataloader:
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('device: ', sample[0].device)
print('max: ', sample[0].max())
print(sample[1])
break
<class 'list'>
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
device: cpu
max: tensor(2.6400)
tensor([4, 2, 4, 8])
11. wabdatasetのshardの読み込み方6
ではcamvidのセマンティックセグメンテーション用のデータを固めたshardを読み込んでみましょう.
作成時のshardの構造は以下の通り.
sink.write({
"__key__": img_path.stem,
"img.png": img_buffer,
"img.jpg": np.array(Image.open(img_path)),
"label.png": label_buffer,
"pair.pickle": (img_buffer, label_buffer)
})
ここからimg.png
とlabel.png
の2枚の画像を読み込みます.
ラベル画像は整数値なので,デコード指定にはrgbのuint8を指定し,transformで255で割るかどうかを切り替えます.
from torchvision.io import decode_jpeg
from torchvision import transforms
from torchvision.transforms import (
Resize,
Normalize,
RandomHorizontalFlip,
Compose,
)
import torch
from functools import partial
from torch.utils.data import DataLoader
def get_dataset(shard_dir, device):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = [
str(path) for path in Path(shard_dir).glob('*.tar')
]
def get_transform(is_label=False):
transform_list = [
Resize((224, 224)),
]
if not is_label:
transform_list.extend([
transforms.Lambda(lambda x: x / 255.),
Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
])
else:
transform_list.append(
transforms.Lambda(lambda x: x.to(torch.float)))
transform = Compose(transform_list)
return transform
def transform_img(x, transform):
return transform(x)
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode('torchrgb8')
dataset = dataset.to_tuple('img.png', 'label.png')
dataset = dataset.map_tuple(
partial(transform_img,
transform=get_transform(is_label=False)),
partial(transform_img,
transform=get_transform(is_label=True)),
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(
get_dataset(
shard_dir='shards_06_camvid',
device=device),
batch_size=4)
for sample in dataloader:
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('dtype: ', sample[0].dtype)
print('device: ', sample[0].device)
print('to.device: ', sample[0].to(device).device)
print('max: ', sample[0].max())
print(type(sample[1]))
print('shape: ', sample[1].shape)
print('dtype: ', sample[1].dtype)
print('device: ', sample[1].device)
print('to.device: ', sample[1].to(device).device)
print('max: ', sample[1].max())
break
img = sample[0][0].detach().cpu().permute(1,2,0).numpy()
plt.imshow((img - img.min()) / img.max())
plt.show()
label = sample[1][0].detach().cpu().permute(1,2,0).numpy() / 30
plt.imshow(label)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<class 'list'>
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
dtype: torch.float32
device: cpu
to.device: cuda:0
max: tensor(2.6400)
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
dtype: torch.float32
device: cpu
to.device: cuda:0
max: tensor(31.)
しかし上のようなやり方では,imageとlabelに同じ幾何変換を適用することができません.セグメンテーション用のデータ拡張のときに,単純にランダムにresizeやcropや反転を行うと,imageとlabelで別々の画像処理が行われてしまします.
そこでalbumentationsを使います.このalbumentationsのtransformを使うと,imageとlabelの画像に同じ変換を適用することができます.さらにNormalizeやFloat変換などはlabelには適用されないので便利です.
そのためには,「img.png
とlabel.png
という2つのキーから別々に2枚の画像を読み込む」というやり方を変更します.webdatasetのパイプラインでは,別々のキーから読み込んだデータはいつまでも別々のままで一つにすることができません.そのため,以下のようにpickleでimgとlabelを一つにまとめたキーpair.pickle
から読み込むことにします.
またabumentationsはndarrayしか扱えないようなので,デーコーダではnp.array
を使います.
from torchvision.io import decode_jpeg
from torchvision import transforms
from torchvision.transforms import (
Resize,
Normalize,
RandomHorizontalFlip,
Compose,
)
import torch
from functools import partial
from torch.utils.data import DataLoader
import numpy as np
import albumentations as A
import albumentations.pytorch as Ap
def get_dataset(shard_dir, device):
shards_list = [
str(path) for path in Path(shard_dir).glob('*.tar')
]
def get_transform():
transform_list = [
A.RandomResizedCrop(224, 224),
A.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
Ap.ToTensorV2(), # HWC --> CHW
]
transform = A.Compose(transform_list)
return transform
def decode_png(buf):
return np.array(Image.open(io.BytesIO(buf)).convert('RGB'))
def transform_pair(x, transform, device):
img, label = x
img = decode_png(img)
label = decode_png(label)
transformed = transform(image=img, mask=label)
img = torch.as_tensor(
transformed['image'],
device=device)
label = torch.as_tensor(
transformed['mask'],
device=device)
return (
img, # CHW,float32
label.permute(2,0,1).to(torch.float) # HWC,uint8 --> CHW,float32
)
transform_func = partial(
transform_pair,
transform=get_transform(),
device=device
)
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode()
dataset = dataset.to_tuple('pair.pickle')
dataset = dataset.map_tuple(transform_func)
return dataset
def my_collate(batch):
return torch.utils.data.default_collate([b[0] for b in batch])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(
get_dataset(
shard_dir='shards_06_camvid',
device=device),
batch_size=4,
collate_fn=my_collate)
for sample in dataloader:
print(type(sample))
print(type(sample[0]))
print('shape: ', sample[0].shape)
print('dtype: ', sample[0].dtype)
print('device: ', sample[0].device)
print('to.device: ', sample[0].to(device).device)
print('max: ', sample[0].max())
print(type(sample[1]))
print('shape: ', sample[1].shape)
print('dtype: ', sample[1].dtype)
print('device: ', sample[1].device)
print('to.device: ', sample[1].to(device).device)
print('max: ', sample[1].max())
break
img = sample[0][0].detach().cpu().permute(1,2,0).numpy()
plt.imshow((img - img.min()) / img.max())
plt.show()
label = sample[1][0].detach().cpu().permute(1,2,0).numpy() / 30
plt.imshow(label)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<class 'list'>
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
dtype: torch.float32
device: cuda:0
to.device: cuda:0
max: tensor(2.6400, device='cuda:0')
<class 'torch.Tensor'>
shape: torch.Size([4, 3, 224, 224])
dtype: torch.float32
device: cuda:0
to.device: cuda:0
max: tensor(31., device='cuda:0')
12. シャッフルの方法
webdatasetは複数のshard(tarファイル)の中から順番にデータを取り出します.そのため,学習に用いるサンプルの順番が毎回(エポックごとに)同じになります.もしshard生成時に,クラス順にデータをtarに固めてしまうと,shardから取り出したサンプルは同じクラスのものが連続してしまうことになります.通常これは学習にとって悪影響を与えます(ランダムにサンプルをとってきて学習に使わなければ,性能が落ちます).
そのためwebdatasetではパイプラインの途中にシャッフルする仕組みがあります.
12.1. テスト用のshard作成
その仕組によるシャッフルの様子を見るために,まずは以下のようなテスト用shardを作ります.
tarファイルの中には,shard番号と,そのtarファイル中のサンプルの番号をもつjsonを作ります.
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import json
from pathlib import Path
shard_path = './shards_07_test'
shard_dir_path = Path(shard_path)
shard_dir_path.mkdir(exist_ok=True)
shard_filename = str(shard_dir_path / 'shards-%05d.tar')
n_samples = 57
max_count = 11
max_size = 1000**2
with wds.ShardWriter(
shard_filename,
maxsize=max_size,
maxcount=max_count,
) as sink, tqdm(
range(n_samples)
) as pbar:
print(sink.maxcount)
for i in pbar:
if sink.count >= sink.maxcount:
shard = sink.shard + 1
count = 0
else:
shard = sink.shard # 1-origin
count = sink.count # 0-origin
sample_dic ={
'__key__': str(i),
'json': json.dumps({
'shard': shard - 1, # 0-origin
'sample': count
})
}
sink.write(sample_dic)
print(sample_dic)
dataset_size = sink.total
dataset_size_filename = str(
shard_dir_path / 'dataset-size.json')
with open(dataset_size_filename, 'w') as fp:
json.dump({
"dataset size": dataset_size,
}, fp)
# writing shards_07_test/shards-00000.tar 0 0.0 GB 0
0%| | 0/57 [00:00<?, ?it/s]
11
{'__key__': '0', 'json': '{"shard": 0, "sample": 0}'}
{'__key__': '1', 'json': '{"shard": 0, "sample": 1}'}
{'__key__': '2', 'json': '{"shard": 0, "sample": 2}'}
{'__key__': '3', 'json': '{"shard": 0, "sample": 3}'}
{'__key__': '4', 'json': '{"shard": 0, "sample": 4}'}
{'__key__': '5', 'json': '{"shard": 0, "sample": 5}'}
{'__key__': '6', 'json': '{"shard": 0, "sample": 6}'}
{'__key__': '7', 'json': '{"shard": 0, "sample": 7}'}
{'__key__': '8', 'json': '{"shard": 0, "sample": 8}'}
{'__key__': '9', 'json': '{"shard": 0, "sample": 9}'}
{'__key__': '10', 'json': '{"shard": 0, "sample": 10}'}
# writing shards_07_test/shards-00001.tar 11 0.0 GB 11
{'__key__': '11', 'json': '{"shard": 1, "sample": 0}'}
{'__key__': '12', 'json': '{"shard": 1, "sample": 1}'}
{'__key__': '13', 'json': '{"shard": 1, "sample": 2}'}
{'__key__': '14', 'json': '{"shard": 1, "sample": 3}'}
{'__key__': '15', 'json': '{"shard": 1, "sample": 4}'}
{'__key__': '16', 'json': '{"shard": 1, "sample": 5}'}
{'__key__': '17', 'json': '{"shard": 1, "sample": 6}'}
{'__key__': '18', 'json': '{"shard": 1, "sample": 7}'}
{'__key__': '19', 'json': '{"shard": 1, "sample": 8}'}
{'__key__': '20', 'json': '{"shard": 1, "sample": 9}'}
{'__key__': '21', 'json': '{"shard": 1, "sample": 10}'}
# writing shards_07_test/shards-00002.tar 11 0.0 GB 22
{'__key__': '22', 'json': '{"shard": 2, "sample": 0}'}
{'__key__': '23', 'json': '{"shard": 2, "sample": 1}'}
{'__key__': '24', 'json': '{"shard": 2, "sample": 2}'}
{'__key__': '25', 'json': '{"shard": 2, "sample": 3}'}
{'__key__': '26', 'json': '{"shard": 2, "sample": 4}'}
{'__key__': '27', 'json': '{"shard": 2, "sample": 5}'}
{'__key__': '28', 'json': '{"shard": 2, "sample": 6}'}
{'__key__': '29', 'json': '{"shard": 2, "sample": 7}'}
{'__key__': '30', 'json': '{"shard": 2, "sample": 8}'}
{'__key__': '31', 'json': '{"shard": 2, "sample": 9}'}
{'__key__': '32', 'json': '{"shard": 2, "sample": 10}'}
# writing shards_07_test/shards-00003.tar 11 0.0 GB 33
{'__key__': '33', 'json': '{"shard": 3, "sample": 0}'}
{'__key__': '34', 'json': '{"shard": 3, "sample": 1}'}
{'__key__': '35', 'json': '{"shard": 3, "sample": 2}'}
{'__key__': '36', 'json': '{"shard": 3, "sample": 3}'}
{'__key__': '37', 'json': '{"shard": 3, "sample": 4}'}
{'__key__': '38', 'json': '{"shard": 3, "sample": 5}'}
{'__key__': '39', 'json': '{"shard": 3, "sample": 6}'}
{'__key__': '40', 'json': '{"shard": 3, "sample": 7}'}
{'__key__': '41', 'json': '{"shard": 3, "sample": 8}'}
{'__key__': '42', 'json': '{"shard": 3, "sample": 9}'}
{'__key__': '43', 'json': '{"shard": 3, "sample": 10}'}
# writing shards_07_test/shards-00004.tar 11 0.0 GB 44
{'__key__': '44', 'json': '{"shard": 4, "sample": 0}'}
{'__key__': '45', 'json': '{"shard": 4, "sample": 1}'}
{'__key__': '46', 'json': '{"shard": 4, "sample": 2}'}
{'__key__': '47', 'json': '{"shard": 4, "sample": 3}'}
{'__key__': '48', 'json': '{"shard": 4, "sample": 4}'}
{'__key__': '49', 'json': '{"shard": 4, "sample": 5}'}
{'__key__': '50', 'json': '{"shard": 4, "sample": 6}'}
{'__key__': '51', 'json': '{"shard": 4, "sample": 7}'}
{'__key__': '52', 'json': '{"shard": 4, "sample": 8}'}
{'__key__': '53', 'json': '{"shard": 4, "sample": 9}'}
{'__key__': '54', 'json': '{"shard": 4, "sample": 10}'}
# writing shards_07_test/shards-00005.tar 11 0.0 GB 55
{'__key__': '55', 'json': '{"shard": 5, "sample": 0}'}
{'__key__': '56', 'json': '{"shard": 5, "sample": 1}'}
12.2. シャッフルなしのshardよみこみ
ではこのshardファイルを読み込んでみます.
まずはシャッフルなしの場合.シャッフルの効果を見るために,shardファイルのリストをsortしておきます.
import torch
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
def get_dataset(shard_dir):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = sorted([
str(path) for path in Path(shard_dir).glob('*.tar')
])
dataset = wds.WebDataset(shards_list)
dataset = dataset.decode()
dataset = dataset.to_tuple('json')
dataset = dataset.map_tuple(
lambda x: x
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
dataloader = DataLoader(
get_dataset(
shard_dir='shards_07_test',
),
batch_size=5)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([0, 1, 2, 3, 4])}]
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([5, 6, 7, 8, 9])}]
Epoch 0: [{'shard': tensor([0, 1, 1, 1, 1]), 'sample': tensor([10, 0, 1, 2, 3])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([4, 5, 6, 7, 8])}]
Epoch 0: [{'shard': tensor([1, 1, 2, 2, 2]), 'sample': tensor([ 9, 10, 0, 1, 2])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([3, 4, 5, 6, 7])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 3, 3]), 'sample': tensor([ 8, 9, 10, 0, 1])}]
Epoch 0: [{'shard': tensor([3, 3, 3, 3, 3]), 'sample': tensor([2, 3, 4, 5, 6])}]
Epoch 0: [{'shard': tensor([3, 3, 3, 3, 4]), 'sample': tensor([ 7, 8, 9, 10, 0])}]
Epoch 0: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([1, 2, 3, 4, 5])}]
Epoch 0: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([ 6, 7, 8, 9, 10])}]
Epoch 0: [{'shard': tensor([5, 5]), 'sample': tensor([0, 1])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([0, 1, 2, 3, 4])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([5, 6, 7, 8, 9])}]
Epoch 1: [{'shard': tensor([0, 1, 1, 1, 1]), 'sample': tensor([10, 0, 1, 2, 3])}]
Epoch 1: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([4, 5, 6, 7, 8])}]
Epoch 1: [{'shard': tensor([1, 1, 2, 2, 2]), 'sample': tensor([ 9, 10, 0, 1, 2])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([3, 4, 5, 6, 7])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 3, 3]), 'sample': tensor([ 8, 9, 10, 0, 1])}]
Epoch 1: [{'shard': tensor([3, 3, 3, 3, 3]), 'sample': tensor([2, 3, 4, 5, 6])}]
Epoch 1: [{'shard': tensor([3, 3, 3, 3, 4]), 'sample': tensor([ 7, 8, 9, 10, 0])}]
Epoch 1: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([1, 2, 3, 4, 5])}]
Epoch 1: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([ 6, 7, 8, 9, 10])}]
Epoch 1: [{'shard': tensor([5, 5]), 'sample': tensor([0, 1])}]
shardを順番に読み込み,各shardでは順番にサンプルが取りだされてバッチになっていることがわかります.
12.3. シャッフルありのshard読み込み
次はシャッフルを入れます.dataset.shuffle()
をパイプラインに挿入するだけです.引数には,シャッフル用のバッファサイズを指定します.
import torch
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
def get_dataset(shard_dir, shuffle_buf_size):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = sorted([
str(path) for path in Path(shard_dir).glob('*.tar')
])
dataset = wds.WebDataset(shards_list)
dataset = dataset.shuffle(shuffle_buf_size)
dataset = dataset.decode()
dataset = dataset.to_tuple('json')
dataset = dataset.map_tuple(
lambda x: x
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
print()
print('shuffle buffer size: 3')
dataloader = DataLoader(
get_dataset(
shard_dir='shards_07_test',
shuffle_buf_size=3
),
batch_size=5)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
print()
print('shuffle buffer size: 7')
dataloader = DataLoader(
get_dataset(
shard_dir='shards_07_test',
shuffle_buf_size=7
),
batch_size=5)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
shuffle buffer size: 3
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([1, 0, 2, 4, 3])}]
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([ 7, 5, 8, 10, 9])}]
Epoch 0: [{'shard': tensor([1, 0, 1, 1, 1]), 'sample': tensor([0, 6, 2, 3, 4])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([1, 7, 5, 8, 6])}]
Epoch 0: [{'shard': tensor([1, 2, 2, 1, 2]), 'sample': tensor([ 9, 1, 2, 10, 3])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([0, 6, 5, 4, 9])}]
Epoch 0: [{'shard': tensor([2, 2, 3, 3, 3]), 'sample': tensor([8, 7, 0, 1, 3])}]
Epoch 0: [{'shard': tensor([3, 3, 2, 3, 3]), 'sample': tensor([ 2, 4, 10, 5, 8])}]
Epoch 0: [{'shard': tensor([3, 3, 4, 3, 4]), 'sample': tensor([6, 7, 0, 9, 2])}]
Epoch 0: [{'shard': tensor([4, 4, 3, 4, 4]), 'sample': tensor([ 3, 1, 10, 4, 7])}]
Epoch 0: [{'shard': tensor([4, 4, 4, 4, 5]), 'sample': tensor([ 5, 6, 9, 10, 0])}]
Epoch 0: [{'shard': tensor([4, 5]), 'sample': tensor([8, 1])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([2, 0, 4, 3, 6])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([7, 5, 1, 8, 9])}]
Epoch 1: [{'shard': tensor([0, 1, 1, 1, 1]), 'sample': tensor([10, 0, 1, 3, 2])}]
Epoch 1: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([6, 4, 7, 9, 5])}]
Epoch 1: [{'shard': tensor([2, 1, 2, 1, 2]), 'sample': tensor([ 0, 8, 1, 10, 2])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([3, 5, 4, 6, 7])}]
Epoch 1: [{'shard': tensor([2, 2, 3, 3, 3]), 'sample': tensor([9, 8, 0, 2, 1])}]
Epoch 1: [{'shard': tensor([3, 3, 2, 3, 3]), 'sample': tensor([ 3, 4, 10, 7, 8])}]
Epoch 1: [{'shard': tensor([3, 3, 4, 3, 3]), 'sample': tensor([ 5, 10, 0, 6, 9])}]
Epoch 1: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([3, 4, 1, 5, 6])}]
Epoch 1: [{'shard': tensor([4, 4, 4, 4, 5]), 'sample': tensor([ 2, 7, 10, 9, 1])}]
Epoch 1: [{'shard': tensor([4, 5]), 'sample': tensor([8, 0])}]
shuffle buffer size: 7
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([3, 4, 5, 2, 1])}]
Epoch 0: [{'shard': tensor([0, 1, 1, 0, 0]), 'sample': tensor([ 8, 1, 2, 10, 0])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 0, 1]), 'sample': tensor([0, 5, 7, 7, 8])}]
Epoch 0: [{'shard': tensor([0, 1, 1, 1, 2]), 'sample': tensor([9, 4, 3, 6, 0])}]
Epoch 0: [{'shard': tensor([2, 2, 1, 2, 2]), 'sample': tensor([1, 4, 9, 3, 7])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 2, 3]), 'sample': tensor([ 5, 2, 10, 8, 2])}]
Epoch 0: [{'shard': tensor([2, 0, 3, 3, 3]), 'sample': tensor([9, 6, 5, 3, 7])}]
Epoch 0: [{'shard': tensor([3, 1, 2, 3, 3]), 'sample': tensor([ 6, 10, 6, 10, 4])}]
Epoch 0: [{'shard': tensor([3, 3, 3, 4, 3]), 'sample': tensor([8, 9, 1, 3, 0])}]
Epoch 0: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([0, 8, 4, 6, 7])}]
Epoch 0: [{'shard': tensor([4, 5, 4, 4, 5]), 'sample': tensor([5, 0, 9, 1, 1])}]
Epoch 0: [{'shard': tensor([4, 4]), 'sample': tensor([ 2, 10])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([3, 1, 0, 7, 8])}]
Epoch 1: [{'shard': tensor([1, 0, 0, 0, 0]), 'sample': tensor([ 0, 4, 5, 9, 10])}]
Epoch 1: [{'shard': tensor([1, 1, 0, 1, 1]), 'sample': tensor([2, 1, 2, 8, 6])}]
Epoch 1: [{'shard': tensor([1, 0, 1, 2, 1]), 'sample': tensor([ 5, 6, 9, 2, 10])}]
Epoch 1: [{'shard': tensor([2, 1, 2, 2, 2]), 'sample': tensor([4, 7, 6, 7, 0])}]
Epoch 1: [{'shard': tensor([2, 2, 1, 2, 1]), 'sample': tensor([3, 8, 3, 9, 4])}]
Epoch 1: [{'shard': tensor([3, 3, 3, 3, 3]), 'sample': tensor([3, 2, 5, 4, 1])}]
Epoch 1: [{'shard': tensor([3, 2, 3, 2, 3]), 'sample': tensor([ 8, 10, 7, 5, 6])}]
Epoch 1: [{'shard': tensor([4, 3, 2, 3, 4]), 'sample': tensor([1, 9, 1, 0, 3])}]
Epoch 1: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([ 5, 4, 2, 10, 6])}]
Epoch 1: [{'shard': tensor([4, 5, 4, 4, 5]), 'sample': tensor([7, 1, 0, 8, 0])}]
Epoch 1: [{'shard': tensor([3, 4]), 'sample': tensor([10, 9])}]
シャッフル用のバッファサイズを大きくすると,バッチに詰め込まれる順番がランダムになっていくことがわかります.コードを見ると,バッファサイズのリストにshardから順番に取り出したサンプルを詰め込んで,そのリストからランダムに取り出していることがわかります.
したがって,次のことが言えます.webdatasetでshuffle()
を使っても
- 完全にランダムにサンプルを抽出することはできません(だからこれまでのshard作成例では,作成時点でサンプルをランダムにシャッフルしてshardに詰め込んでいました).シャッフル用のバッファサイズをデータセットサイズより大きくすれば完全にランダムになりますが,データセットが大きくなればバッファを格納するメモリ領域が膨大になるので実用的ではありません.
- shardの順番はランダムにはなりません.最初に指定した順番通りにshardファイルが使われます.
12.4. 読み込むshardのシャッフル
shardの順番をシャッフルするには,WebDataset
のコンストラクタの引数でshardshuffle
をTrueにします.こうすると,エポック毎に読み込むshardリストをシャッフルします.
import torch
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
def get_dataset(shard_dir, shuffle_buf_size, shardshuffle):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = sorted([
str(path) for path in Path(shard_dir).glob('*.tar')
])
dataset = wds.WebDataset(
shards_list,
shardshuffle=shardshuffle
)
dataset = dataset.shuffle(shuffle_buf_size)
dataset = dataset.decode()
dataset = dataset.to_tuple('json')
dataset = dataset.map_tuple(
lambda x: x
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
print()
print('shuffle buffer size: 3')
dataloader = DataLoader(
get_dataset(
shard_dir='shards_07_test',
shuffle_buf_size=3,
shardshuffle=True,
),
batch_size=5)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
shuffle buffer size: 3
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([0, 3, 1, 2, 6])}]
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([ 7, 5, 8, 4, 10])}]
Epoch 0: [{'shard': tensor([4, 4, 0, 4, 4]), 'sample': tensor([0, 2, 9, 1, 3])}]
Epoch 0: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([5, 4, 8, 9, 7])}]
Epoch 0: [{'shard': tensor([4, 5, 3, 3, 3]), 'sample': tensor([6, 0, 0, 1, 2])}]
Epoch 0: [{'shard': tensor([5, 3, 3, 4, 3]), 'sample': tensor([ 1, 4, 3, 10, 5])}]
Epoch 0: [{'shard': tensor([3, 3, 3, 2, 2]), 'sample': tensor([8, 6, 9, 0, 1])}]
Epoch 0: [{'shard': tensor([3, 2, 2, 2, 2]), 'sample': tensor([7, 2, 4, 3, 5])}]
Epoch 0: [{'shard': tensor([2, 2, 3, 2, 2]), 'sample': tensor([ 7, 8, 10, 6, 9])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 2, 1]), 'sample': tensor([ 1, 0, 3, 10, 4])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([ 5, 2, 8, 6, 10])}]
Epoch 0: [{'shard': tensor([1, 1]), 'sample': tensor([9, 7])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([0, 2, 3, 1, 5])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([ 4, 8, 7, 9, 10])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 2]), 'sample': tensor([0, 1, 2, 3, 6])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([4, 5, 6, 9, 8])}]
Epoch 1: [{'shard': tensor([0, 0, 3, 3, 3]), 'sample': tensor([ 7, 10, 2, 1, 3])}]
Epoch 1: [{'shard': tensor([3, 3, 3, 3, 3]), 'sample': tensor([5, 4, 6, 8, 0])}]
Epoch 1: [{'shard': tensor([3, 3, 1, 3, 1]), 'sample': tensor([ 9, 7, 1, 10, 3])}]
Epoch 1: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([2, 5, 0, 4, 8])}]
Epoch 1: [{'shard': tensor([1, 1, 1, 4, 1]), 'sample': tensor([ 9, 6, 10, 0, 7])}]
Epoch 1: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([3, 4, 1, 5, 2])}]
Epoch 1: [{'shard': tensor([4, 4, 4, 5, 5]), 'sample': tensor([7, 6, 9, 0, 1])}]
Epoch 1: [{'shard': tensor([4, 4]), 'sample': tensor([10, 8])}]
したがってwebdatasetでshardを読み込んで学習するときには,完全なランダムサンプリングの代用として,shardshuffle=True
を使い,そこそこのバッファサイズでshuffle()
を併用するのがよいでしょう.
13. 複数ワーカーでの読み込み
それではDataloaderで複数ワーカーを使ってshardを読み込んでみましょう.
どのワーカーがどのshardを読み込んでいるのかを確認するために,pytorchのget_worker_info()
を使います.
import torch
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
def add_worker_id(x):
x['worker'] = torch.utils.data.get_worker_info().id
return x
def get_dataset(shard_dir, shuffle_buf_size=0, shardshuffle=False):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = sorted([
str(path) for path in Path(shard_dir).glob('*.tar')
])
dataset = wds.WebDataset(
shards_list,
shardshuffle=shardshuffle
)
dataset = dataset.shuffle(shuffle_buf_size)
dataset = dataset.decode()
dataset = dataset.to_tuple('json')
dataset = dataset.map_tuple(
add_worker_id
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
dataloader = DataLoader(
get_dataset(
shard_dir='shards_07_test',
),
batch_size=5,
num_workers=3
)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:566: UserWarning: This DataLoader will create 3 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
cpuset_checked))
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([0, 1, 2, 3, 4]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([0, 1, 2, 3, 4]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([0, 1, 2, 3, 4]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([5, 6, 7, 8, 9]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([5, 6, 7, 8, 9]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([5, 6, 7, 8, 9]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 0: [{'shard': tensor([0, 3, 3, 3, 3]), 'sample': tensor([10, 0, 1, 2, 3]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 4, 4, 4, 4]), 'sample': tensor([10, 0, 1, 2, 3]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([2, 5, 5]), 'sample': tensor([10, 0, 1]), 'worker': tensor([2, 2, 2])}]
Epoch 0: [{'shard': tensor([3, 3, 3, 3, 3]), 'sample': tensor([4, 5, 6, 7, 8]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([4, 5, 6, 7, 8]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([3, 3]), 'sample': tensor([ 9, 10]), 'worker': tensor([0, 0])}]
Epoch 0: [{'shard': tensor([4, 4]), 'sample': tensor([ 9, 10]), 'worker': tensor([1, 1])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([0, 1, 2, 3, 4]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([0, 1, 2, 3, 4]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([0, 1, 2, 3, 4]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([5, 6, 7, 8, 9]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([5, 6, 7, 8, 9]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([5, 6, 7, 8, 9]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 1: [{'shard': tensor([0, 3, 3, 3, 3]), 'sample': tensor([10, 0, 1, 2, 3]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([1, 4, 4, 4, 4]), 'sample': tensor([10, 0, 1, 2, 3]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([2, 5, 5]), 'sample': tensor([10, 0, 1]), 'worker': tensor([2, 2, 2])}]
Epoch 1: [{'shard': tensor([3, 3, 3, 3, 3]), 'sample': tensor([4, 5, 6, 7, 8]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([4, 5, 6, 7, 8]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([3, 3]), 'sample': tensor([ 9, 10]), 'worker': tensor([0, 0])}]
Epoch 1: [{'shard': tensor([4, 4]), 'sample': tensor([ 9, 10]), 'worker': tensor([1, 1])}]
これを見ると,3つのワーカーが9個のshardを先頭から順番に担当していることがわかります.
- worker0の担当shardは0, 3, 6, 9
- worker1の担当shardは1, 4, 7
- worker2の担当shardは2, 5, 8
以下のように,バッファサイズを指定してshuffle()
でシャッフルすると分かるように,シャッフルでも各workerの担当shardは変わりません.
import torch
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
def add_worker_id(x):
x['worker'] = torch.utils.data.get_worker_info().id
return x
def get_dataset(shard_dir, shuffle_buf_size=0, shardshuffle=False):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = sorted([
str(path) for path in Path(shard_dir).glob('*.tar')
])
dataset = wds.WebDataset(
shards_list,
shardshuffle=shardshuffle
)
dataset = dataset.shuffle(shuffle_buf_size)
dataset = dataset.decode()
dataset = dataset.to_tuple('json')
dataset = dataset.map_tuple(
add_worker_id
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
dataloader = DataLoader(
get_dataset(
shard_dir='shards_07_test',
shuffle_buf_size=7,
),
batch_size=5,
num_workers=3
)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([4, 1, 8, 0, 2]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([0, 3, 2, 9, 1]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([2, 6, 5, 4, 1]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 0: [{'shard': tensor([0, 0, 0, 3, 3]), 'sample': tensor([ 5, 10, 3, 2, 1]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 4, 4, 4, 1]), 'sample': tensor([ 8, 1, 0, 2, 10]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 5, 2]), 'sample': tensor([0, 7, 8, 0, 3]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 0: [{'shard': tensor([0, 3, 3, 3, 3]), 'sample': tensor([9, 3, 0, 4, 9]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 4, 4]), 'sample': tensor([5, 7, 4, 6, 4]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([5, 2, 2]), 'sample': tensor([ 1, 10, 9]), 'worker': tensor([2, 2, 2])}]
Epoch 0: [{'shard': tensor([0, 3, 0, 3, 3]), 'sample': tensor([6, 5, 7, 6, 7]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 4, 4, 4, 4]), 'sample': tensor([6, 7, 9, 5, 8]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([3, 3]), 'sample': tensor([ 8, 10]), 'worker': tensor([0, 0])}]
Epoch 0: [{'shard': tensor([4, 4]), 'sample': tensor([10, 3]), 'worker': tensor([1, 1])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([3, 0, 6, 2, 9]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([1, 6, 7, 5, 9]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([ 4, 3, 1, 8, 10]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 1: [{'shard': tensor([0, 3, 0, 3, 0]), 'sample': tensor([ 7, 0, 10, 1, 8]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([1, 1, 4, 4, 4]), 'sample': tensor([3, 4, 0, 3, 4]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 5, 2]), 'sample': tensor([7, 2, 6, 0, 5]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 1: [{'shard': tensor([3, 3, 3, 0, 0]), 'sample': tensor([2, 4, 3, 5, 1]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([4, 4, 1, 1, 4]), 'sample': tensor([5, 2, 0, 8, 7]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([2, 5, 2]), 'sample': tensor([9, 1, 0]), 'worker': tensor([2, 2, 2])}]
Epoch 1: [{'shard': tensor([3, 0, 3, 3, 3]), 'sample': tensor([ 9, 4, 7, 6, 10]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([4, 4, 4, 1, 4]), 'sample': tensor([ 1, 6, 9, 10, 8]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([3, 3]), 'sample': tensor([5, 8]), 'worker': tensor([0, 0])}]
Epoch 1: [{'shard': tensor([4, 1]), 'sample': tensor([10, 2]), 'worker': tensor([1, 1])}]
以下のようにshardsuffle=True
でshardファイルをシャッフルしても,各ワーカーの担当shard内でshardのシャッフルが行われるだけで,ワーカーをまたいでサンプルがシャッフルされることはありません.
import torch
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
def add_worker_id(x):
x['worker'] = torch.utils.data.get_worker_info().id
return x
def get_dataset(shard_dir, shuffle_buf_size=0, shardshuffle=False):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = sorted([
str(path) for path in Path(shard_dir).glob('*.tar')
])
dataset = wds.WebDataset(
shards_list,
shardshuffle=shardshuffle
)
dataset = dataset.shuffle(shuffle_buf_size)
dataset = dataset.decode()
dataset = dataset.to_tuple('json')
dataset = dataset.map_tuple(
add_worker_id
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
dataloader = DataLoader(
get_dataset(
shard_dir='shards_07_test',
shuffle_buf_size=7,
shardshuffle=True
),
batch_size=5,
num_workers=3
)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
Epoch 0: [{'shard': tensor([3, 3, 3, 3, 3]), 'sample': tensor([6, 0, 7, 4, 9]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([0, 2, 1, 5, 7]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 5, 5]), 'sample': tensor([0, 3, 2, 0, 1]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 0: [{'shard': tensor([0, 0, 3, 3, 3]), 'sample': tensor([0, 1, 8, 3, 2]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 1, 1, 1, 4]), 'sample': tensor([ 4, 10, 9, 6, 1]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([ 4, 1, 5, 7, 10]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 0: [{'shard': tensor([3, 0, 3, 0, 0]), 'sample': tensor([ 5, 5, 10, 3, 4]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([1, 4, 4, 4, 4]), 'sample': tensor([3, 4, 3, 2, 6]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([2, 2, 2]), 'sample': tensor([6, 8, 9]), 'worker': tensor([2, 2, 2])}]
Epoch 0: [{'shard': tensor([0, 0, 0, 0, 0]), 'sample': tensor([9, 6, 7, 2, 8]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 0: [{'shard': tensor([4, 4, 4, 4, 4]), 'sample': tensor([ 8, 9, 7, 10, 5]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 0: [{'shard': tensor([0, 3]), 'sample': tensor([10, 1]), 'worker': tensor([0, 0])}]
Epoch 0: [{'shard': tensor([4, 1]), 'sample': tensor([0, 8]), 'worker': tensor([1, 1])}]
Epoch 1: [{'shard': tensor([3, 3, 3, 3, 3]), 'sample': tensor([5, 7, 4, 2, 0]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([1, 1, 1, 1, 1]), 'sample': tensor([0, 7, 3, 2, 8]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([2, 5, 1, 3, 0]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 1: [{'shard': tensor([3, 3, 3, 3, 0]), 'sample': tensor([9, 6, 8, 1, 4]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([1, 1, 4, 1, 4]), 'sample': tensor([1, 5, 1, 4, 2]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([2, 2, 2, 2, 2]), 'sample': tensor([ 7, 10, 9, 8, 4]), 'worker': tensor([2, 2, 2, 2, 2])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 0, 3]), 'sample': tensor([ 2, 6, 5, 8, 10]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([1, 4, 1, 4, 4]), 'sample': tensor([ 6, 5, 10, 4, 3]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([5, 2, 5]), 'sample': tensor([0, 6, 1]), 'worker': tensor([2, 2, 2])}]
Epoch 1: [{'shard': tensor([0, 0, 0, 3, 0]), 'sample': tensor([ 1, 10, 3, 3, 7]), 'worker': tensor([0, 0, 0, 0, 0])}]
Epoch 1: [{'shard': tensor([1, 4, 4, 4, 4]), 'sample': tensor([ 9, 10, 0, 8, 6]), 'worker': tensor([1, 1, 1, 1, 1])}]
Epoch 1: [{'shard': tensor([0, 0]), 'sample': tensor([0, 9]), 'worker': tensor([0, 0])}]
Epoch 1: [{'shard': tensor([4, 4]), 'sample': tensor([7, 9]), 'worker': tensor([1, 1])}]
13.1. wds.WebLoaderの利用
ワーカーをまたいでサンプルをシャッフルするには,wds.WebLoader
を使います(WebLoader
はそのためだけのものではありませんが).これはpytorchのDataLoader
に機能を追加したものです.
以下では,dataset
ではnum_worker
だけを指定し,batch_size
にはNoneを指定します.つまりdataset
の段階ではバッチを作成しません.このdataset
を読み込むWebLoader
のほうで,シャッフルとバッチ作成をおこないます.
import torch
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
def add_worker_id(x):
x['worker'] = torch.utils.data.get_worker_info().id
return x
def get_dataset(shard_dir, shuffle_buf_size=0, shardshuffle=False):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = sorted([
str(path) for path in Path(shard_dir).glob('*.tar')
])
dataset = wds.WebDataset(
shards_list,
shardshuffle=shardshuffle
)
dataset = dataset.shuffle(shuffle_buf_size)
dataset = dataset.decode()
dataset = dataset.to_tuple('json')
dataset = dataset.map_tuple(
add_worker_id
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
def my_collate(batch):
return torch.utils.data.default_collate([b[0] for b in batch])
dataloader = wds.WebLoader(
get_dataset(
shard_dir='shards_07_test',
),
batch_size=None,
num_workers=3,
)
dataloader = dataloader.shuffle(7)
dataloader = dataloader.batched(
batchsize=5,
collation_fn=my_collate
)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
Epoch 0: {'shard': tensor([2, 1, 2, 0, 1]), 'sample': tensor([1, 0, 2, 2, 1]), 'worker': tensor([2, 1, 2, 0, 1])}
Epoch 0: {'shard': tensor([1, 2, 0, 0, 0]), 'sample': tensor([3, 0, 1, 3, 5]), 'worker': tensor([1, 2, 0, 0, 0])}
Epoch 0: {'shard': tensor([1, 2, 0, 2, 2]), 'sample': tensor([2, 5, 0, 4, 3]), 'worker': tensor([1, 2, 0, 2, 2])}
Epoch 0: {'shard': tensor([0, 1, 1, 1, 2]), 'sample': tensor([7, 4, 6, 7, 7]), 'worker': tensor([0, 1, 1, 1, 2])}
Epoch 0: {'shard': tensor([0, 0, 0, 0, 1]), 'sample': tensor([8, 4, 6, 9, 8]), 'worker': tensor([0, 0, 0, 0, 1])}
Epoch 0: {'shard': tensor([1, 2, 2, 3, 5]), 'sample': tensor([5, 8, 9, 0, 0]), 'worker': tensor([1, 2, 2, 0, 2])}
Epoch 0: {'shard': tensor([2, 2, 1, 3, 4]), 'sample': tensor([10, 6, 10, 2, 0]), 'worker': tensor([2, 2, 1, 0, 1])}
Epoch 0: {'shard': tensor([0, 3, 5, 3, 4]), 'sample': tensor([10, 3, 1, 1, 4]), 'worker': tensor([0, 0, 2, 0, 1])}
Epoch 0: {'shard': tensor([1, 3, 4, 4, 4]), 'sample': tensor([9, 6, 2, 1, 3]), 'worker': tensor([1, 0, 1, 1, 1])}
Epoch 0: {'shard': tensor([3, 4, 3, 3, 4]), 'sample': tensor([7, 6, 5, 4, 8]), 'worker': tensor([0, 1, 0, 0, 1])}
Epoch 0: {'shard': tensor([4, 4, 4, 4, 3]), 'sample': tensor([ 7, 5, 9, 10, 8]), 'worker': tensor([1, 1, 1, 1, 0])}
Epoch 0: {'shard': tensor([3, 3]), 'sample': tensor([ 9, 10]), 'worker': tensor([0, 0])}
Epoch 1: {'shard': tensor([0, 2, 1, 0, 1]), 'sample': tensor([1, 0, 1, 0, 2]), 'worker': tensor([0, 2, 1, 0, 1])}
Epoch 1: {'shard': tensor([0, 2, 0, 1, 0]), 'sample': tensor([3, 3, 4, 4, 5]), 'worker': tensor([0, 2, 0, 1, 0])}
Epoch 1: {'shard': tensor([2, 2, 0, 2, 1]), 'sample': tensor([1, 5, 2, 4, 0]), 'worker': tensor([2, 2, 0, 2, 1])}
Epoch 1: {'shard': tensor([1, 2, 0, 0, 1]), 'sample': tensor([5, 6, 7, 6, 8]), 'worker': tensor([1, 2, 0, 0, 1])}
Epoch 1: {'shard': tensor([1, 2, 0, 2, 1]), 'sample': tensor([6, 7, 8, 9, 3]), 'worker': tensor([1, 2, 0, 2, 1])}
Epoch 1: {'shard': tensor([2, 1, 2, 2, 1]), 'sample': tensor([ 8, 9, 10, 2, 7]), 'worker': tensor([2, 1, 2, 2, 1])}
Epoch 1: {'shard': tensor([0, 0, 5, 4, 4]), 'sample': tensor([ 9, 10, 1, 0, 2]), 'worker': tensor([0, 0, 2, 1, 1])}
Epoch 1: {'shard': tensor([3, 3, 3, 4, 3]), 'sample': tensor([0, 1, 3, 3, 5]), 'worker': tensor([0, 0, 0, 1, 0])}
Epoch 1: {'shard': tensor([3, 4, 4, 4, 4]), 'sample': tensor([4, 5, 1, 6, 4]), 'worker': tensor([0, 1, 1, 1, 1])}
Epoch 1: {'shard': tensor([3, 4, 3, 1, 5]), 'sample': tensor([ 6, 7, 7, 10, 0]), 'worker': tensor([0, 1, 0, 1, 2])}
Epoch 1: {'shard': tensor([4, 3, 3, 3, 4]), 'sample': tensor([ 9, 9, 10, 2, 10]), 'worker': tensor([1, 0, 0, 0, 1])}
Epoch 1: {'shard': tensor([4, 3]), 'sample': tensor([8, 8]), 'worker': tensor([1, 0])}
なおwebdatasetのサンプルでは,以下のように,datasetが作成したバッチをloaderの方で一度unbatch()
でばらばらにして,それからシャッフルし,再度batched()
でバッチを作成する例があります(バッチサイズを変更することもできます).
loader = wds.WebLoader(dataset, num_workers=4, batch_size=8)
loader = loader.unbatched().shuffle(1000).batched(12)
ただしコードを見るとunbatch()
はタプルを仮定しているようなので,上記のコードのようなdict形式のサンプルはunbatchできません.
13.1.1. drop_last=Trueはpartial=False
ちなみにDataLoader
のdrop_last=True
と同等にするには,dataloader.batched()
の引数にpartial=False
を指定します.
import torch
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
def add_worker_id(x):
x['worker'] = torch.utils.data.get_worker_info().id
return x
def get_dataset(shard_dir, shuffle_buf_size=0, shardshuffle=False):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = sorted([
str(path) for path in Path(shard_dir).glob('*.tar')
])
dataset = wds.WebDataset(
shards_list,
shardshuffle=shardshuffle
)
dataset = dataset.shuffle(shuffle_buf_size)
dataset = dataset.decode()
dataset = dataset.to_tuple('json')
dataset = dataset.map_tuple(
add_worker_id
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
def my_collate(batch):
return torch.utils.data.default_collate([b[0] for b in batch])
dataloader = wds.WebLoader(
get_dataset(
shard_dir='shards_07_test',
),
batch_size=None,
num_workers=3,
)
dataloader = dataloader.shuffle(7)
dataloader = dataloader.batched(
batchsize=5,
collation_fn=my_collate,
partial=False,
)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
Epoch 0: {'shard': tensor([1, 0, 0, 2, 1]), 'sample': tensor([0, 1, 0, 0, 1]), 'worker': tensor([1, 0, 0, 2, 1])}
Epoch 0: {'shard': tensor([2, 0, 0, 2, 1]), 'sample': tensor([1, 3, 4, 3, 4]), 'worker': tensor([2, 0, 0, 2, 1])}
Epoch 0: {'shard': tensor([2, 2, 1, 2, 0]), 'sample': tensor([2, 4, 3, 5, 6]), 'worker': tensor([2, 2, 1, 2, 0])}
Epoch 0: {'shard': tensor([1, 0, 1, 0, 1]), 'sample': tensor([2, 2, 7, 8, 8]), 'worker': tensor([1, 0, 1, 0, 1])}
Epoch 0: {'shard': tensor([0, 2, 2, 2, 0]), 'sample': tensor([7, 7, 6, 9, 9]), 'worker': tensor([0, 2, 2, 2, 0])}
Epoch 0: {'shard': tensor([1, 0, 2, 1, 0]), 'sample': tensor([10, 5, 10, 6, 10]), 'worker': tensor([1, 0, 2, 1, 0])}
Epoch 0: {'shard': tensor([3, 5, 5, 3, 2]), 'sample': tensor([0, 0, 1, 2, 8]), 'worker': tensor([0, 2, 2, 0, 2])}
Epoch 0: {'shard': tensor([3, 4, 3, 3, 1]), 'sample': tensor([3, 3, 4, 1, 9]), 'worker': tensor([0, 1, 0, 0, 1])}
Epoch 0: {'shard': tensor([4, 4, 1, 3, 3]), 'sample': tensor([1, 2, 5, 7, 6]), 'worker': tensor([1, 1, 1, 0, 0])}
Epoch 0: {'shard': tensor([3, 4, 4, 3, 4]), 'sample': tensor([8, 5, 4, 9, 9]), 'worker': tensor([0, 1, 1, 0, 1])}
Epoch 0: {'shard': tensor([4, 3, 3, 4, 4]), 'sample': tensor([10, 10, 5, 0, 6]), 'worker': tensor([1, 0, 0, 1, 1])}
Epoch 1: {'shard': tensor([0, 1, 2, 1, 0]), 'sample': tensor([2, 0, 0, 2, 1]), 'worker': tensor([0, 1, 2, 1, 0])}
Epoch 1: {'shard': tensor([2, 1, 1, 0, 2]), 'sample': tensor([2, 1, 3, 3, 3]), 'worker': tensor([2, 1, 1, 0, 2])}
Epoch 1: {'shard': tensor([0, 2, 1, 2, 0]), 'sample': tensor([4, 4, 5, 1, 5]), 'worker': tensor([0, 2, 1, 2, 0])}
Epoch 1: {'shard': tensor([1, 2, 0, 0, 1]), 'sample': tensor([4, 5, 0, 8, 7]), 'worker': tensor([1, 2, 0, 0, 1])}
Epoch 1: {'shard': tensor([0, 1, 1, 0, 2]), 'sample': tensor([7, 8, 6, 9, 7]), 'worker': tensor([0, 1, 1, 0, 2])}
Epoch 1: {'shard': tensor([1, 2, 3, 2, 0]), 'sample': tensor([10, 6, 0, 8, 10]), 'worker': tensor([1, 2, 0, 2, 0])}
Epoch 1: {'shard': tensor([0, 1, 3, 3, 2]), 'sample': tensor([ 6, 9, 1, 2, 10]), 'worker': tensor([0, 1, 0, 0, 2])}
Epoch 1: {'shard': tensor([5, 4, 3, 2, 3]), 'sample': tensor([1, 3, 3, 9, 5]), 'worker': tensor([2, 1, 0, 2, 0])}
Epoch 1: {'shard': tensor([4, 4, 4, 5, 4]), 'sample': tensor([2, 4, 6, 0, 5]), 'worker': tensor([1, 1, 1, 2, 1])}
Epoch 1: {'shard': tensor([3, 3, 3, 4, 3]), 'sample': tensor([ 6, 8, 9, 1, 10]), 'worker': tensor([0, 0, 0, 1, 0])}
Epoch 1: {'shard': tensor([4, 4, 3, 4, 4]), 'sample': tensor([10, 9, 4, 0, 7]), 'worker': tensor([1, 1, 0, 1, 1])}
したがって複数ワーカーを利用する場合のシャッフルには,
- まずdatasetの方でそこそこのバッファサイズでshuffleし,
- さらにshardshuffleを指定してエポック毎に読み込むshardの順番もシャッフルし,
- さらにWebLoaderでバッチを作成する際にもそこそこのバッファサイズでシャッフルする,
という3重のシャッフルを利用することができます.以下がその例です.
実際にはバッファメモリを確保するのがCPUなのかGPUなのかによって,サイズを考慮したほうよいでしょう.
import torch
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
def add_worker_id(x):
x['worker'] = torch.utils.data.get_worker_info().id
return x
def get_dataset(shard_dir, shuffle_buf_size=0, shardshuffle=False):
def info_from_json(shard_dir):
with open(Path(shard_dir) / 'dataset-size.json', 'r') as f:
info_dic = json.load(f)
return info_dic['dataset size']
shards_list = sorted([
str(path) for path in Path(shard_dir).glob('*.tar')
])
dataset = wds.WebDataset(
shards_list,
shardshuffle=shardshuffle
)
dataset = dataset.shuffle(shuffle_buf_size)
dataset = dataset.decode()
dataset = dataset.to_tuple('json')
dataset = dataset.map_tuple(
add_worker_id
)
dataset_size = info_from_json(shard_dir)
dataset = dataset.with_length(dataset_size)
return dataset
def my_collate(batch):
return torch.utils.data.default_collate([b[0] for b in batch])
dataloader = wds.WebLoader(
get_dataset(
shard_dir='shards_07_test',
shuffle_buf_size=5,
shardshuffle=True
),
batch_size=None,
num_workers=3,
)
dataloader = dataloader.shuffle(7)
dataloader = dataloader.batched(
batchsize=5,
collation_fn=my_collate,
partial=False,
)
n_epochs = 2
for epoch in range(n_epochs):
for sample in dataloader:
print(f'Epoch {epoch}: ', sample)
Epoch 0: {'shard': tensor([3, 3, 2, 4, 2]), 'sample': tensor([4, 5, 2, 4, 3]), 'worker': tensor([0, 0, 2, 1, 2])}
Epoch 0: {'shard': tensor([3, 3, 2, 2, 4]), 'sample': tensor([2, 8, 0, 4, 1]), 'worker': tensor([0, 0, 2, 2, 1])}
Epoch 0: {'shard': tensor([4, 3, 4, 5, 4]), 'sample': tensor([3, 3, 6, 0, 5]), 'worker': tensor([1, 0, 1, 2, 1])}
Epoch 0: {'shard': tensor([2, 3, 5, 0, 3]), 'sample': tensor([6, 7, 1, 1, 1]), 'worker': tensor([2, 0, 2, 0, 0])}
Epoch 0: {'shard': tensor([0, 4, 2, 2, 3]), 'sample': tensor([0, 9, 9, 1, 9]), 'worker': tensor([0, 1, 2, 2, 0])}
Epoch 0: {'shard': tensor([1, 4, 0, 4, 2]), 'sample': tensor([0, 7, 2, 2, 8]), 'worker': tensor([1, 1, 0, 1, 2])}
Epoch 0: {'shard': tensor([4, 2, 2, 3, 0]), 'sample': tensor([ 8, 10, 7, 6, 5]), 'worker': tensor([1, 2, 2, 0, 0])}
Epoch 0: {'shard': tensor([2, 1, 3, 1, 4]), 'sample': tensor([ 5, 1, 0, 4, 10]), 'worker': tensor([2, 1, 0, 1, 1])}
Epoch 0: {'shard': tensor([1, 1, 3, 1, 0]), 'sample': tensor([ 6, 5, 10, 3, 7]), 'worker': tensor([1, 1, 0, 1, 0])}
Epoch 0: {'shard': tensor([0, 1, 0, 1, 0]), 'sample': tensor([ 9, 2, 8, 10, 10]), 'worker': tensor([0, 1, 0, 1, 0])}
Epoch 0: {'shard': tensor([1, 0, 1, 0, 1]), 'sample': tensor([7, 4, 9, 3, 8]), 'worker': tensor([1, 0, 1, 0, 1])}
Epoch 1: {'shard': tensor([0, 2, 2, 0, 1]), 'sample': tensor([3, 3, 6, 0, 6]), 'worker': tensor([0, 2, 2, 0, 1])}
Epoch 1: {'shard': tensor([2, 1, 0, 1, 1]), 'sample': tensor([1, 5, 1, 1, 4]), 'worker': tensor([2, 1, 0, 1, 1])}
Epoch 1: {'shard': tensor([1, 0, 2, 0, 1]), 'sample': tensor([ 0, 9, 8, 10, 9]), 'worker': tensor([1, 0, 2, 0, 1])}
Epoch 1: {'shard': tensor([3, 2, 2, 2, 2]), 'sample': tensor([0, 4, 2, 5, 0]), 'worker': tensor([0, 2, 2, 2, 2])}
Epoch 1: {'shard': tensor([0, 0, 1, 0, 0]), 'sample': tensor([4, 7, 7, 2, 5]), 'worker': tensor([0, 0, 1, 0, 0])}
Epoch 1: {'shard': tensor([4, 5, 3, 4, 2]), 'sample': tensor([ 1, 1, 2, 0, 10]), 'worker': tensor([1, 2, 0, 1, 2])}
Epoch 1: {'shard': tensor([1, 3, 5, 2, 2]), 'sample': tensor([10, 4, 0, 7, 9]), 'worker': tensor([1, 0, 2, 2, 2])}
Epoch 1: {'shard': tensor([3, 1, 4, 0, 3]), 'sample': tensor([7, 8, 7, 6, 5]), 'worker': tensor([0, 1, 1, 0, 0])}
Epoch 1: {'shard': tensor([3, 3, 4, 4, 0]), 'sample': tensor([3, 6, 2, 8, 8]), 'worker': tensor([0, 0, 1, 1, 0])}
Epoch 1: {'shard': tensor([1, 4, 3, 4, 4]), 'sample': tensor([ 3, 6, 9, 4, 10]), 'worker': tensor([1, 1, 0, 1, 1])}
Epoch 1: {'shard': tensor([3, 4, 3, 3, 4]), 'sample': tensor([10, 3, 1, 8, 5]), 'worker': tensor([0, 1, 0, 0, 1])}
14. その他の話題(追記)
shardを並列処理で作成する方法,学習を並列で行うDPとDDPの記事を書きました.