Previous post
In this post, we access a real dataset and work out the differences between Dataset and IterableDataset.
Dataset
The regular dataset class
>>> from datasets import load_dataset
>>> dataset = load_dataset("lerobot/pusht", split="train")
Resolving data files: 100%|██████████████████| 206/206 [00:00<00:00, 997.46it/s]
Resolving data files: 100%|██████████████████| 206/206 [00:00<00:00, 923.33it/s]
>>> dataset
Dataset({
    features: ['observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.reward', 'next.done', 'next.success', 'index', 'task_index'],
    num_rows: 25650
})
Each record can be accessed by its row index
>>> dataset[0]
{'observation.state': [222.0, 97.0], 'action': [233.0, 71.0], 'episode_index': 0, 'frame_index': 0, 'timestamp': 0.0, 'next.reward': 0.19029748439788818, 'next.done': False, 'next.success': False, 'index': 0, 'task_index': 0}
Accessing by column name returns every value in that column, i.e. all the data stored under that label
>>> dataset["observation.state"]
... , [137.7550506591797, 196.30186462402344], [132.9698028564453, 217.26519775390625], [131.4213104248047, 240.771728515625], [131.673828125, 260.8636169433594], [134.3138427734375, 278.7528076171875], [138.9960479736328, 293.6759033203125], [144.55123901367188, 305.2630615234375]]
Combining a row index and a column name gives access to a single value
>>> dataset[0]["observation.state"]
[222.0, 97.0]
>>> dataset["observation.state"][0]
[222.0, 97.0]
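To make the int-vs-str indexing above concrete, here is a toy columnar table in plain Python. ToyTable is a hypothetical class that only mimics the behavior observed above (an int key returns one row as a dict, a str key returns a whole column); it is not how the datasets library is implemented. The sample values are copied from the lerobot/pusht outputs in this post.

```python
from typing import Any, Union


class ToyTable:
    """Hypothetical toy columnar table; NOT the datasets implementation."""

    def __init__(self, columns: dict[str, list[Any]]) -> None:
        self.columns = columns

    def __getitem__(self, key: Union[int, str]) -> Any:
        if isinstance(key, int):
            # Row access: take element `key` from every column -> one dict per row
            return {name: values[key] for name, values in self.columns.items()}
        if isinstance(key, str):
            # Column access: return the whole column as a new list
            return list(self.columns[key])
        raise TypeError(f"unsupported key type: {type(key).__name__}")


table = ToyTable({
    "observation.state": [[222.0, 97.0], [225.2523956298828, 89.31253051757812]],
    "episode_index": [0, 0],
})

print(table[0])                       # {'observation.state': [222.0, 97.0], 'episode_index': 0}
print(table[0]["observation.state"])  # [222.0, 97.0]
print(table["observation.state"][0])  # [222.0, 97.0]
```

Both orders reach the same value; the difference is only in how much data is touched along the way.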
Either order works, but putting the row index first is faster
>>> import time
>>> start_time = time.time()
>>> dataset[0]["observation.state"]
[222.0, 97.0]
>>> end_time = time.time()
>>> print(f"Elapsed time: {end_time - start_time:.4f} seconds")
Elapsed time: 0.0013 seconds
>>> start_time = time.time()
>>> dataset["observation.state"][0]
[222.0, 97.0]
>>> end_time = time.time()
>>> print(f"Elapsed time: {end_time - start_time:.4f} seconds")
Elapsed time: 0.0347 seconds
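This is most likely because accessing by column name first converts the entire column (all 25650 rows) into a Python list and only then takes element 0, whereas index-first access reads just a single row.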
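The gap can be reproduced in plain Python under one assumption: column-first access copies the entire column before indexing, while row-first access touches only one element per column. The sketch below uses dummy data (not the real dataset values) with the same row count as the train split; the two functions are hypothetical, and the timeit numbers will vary by machine.

```python
import timeit

NUM_ROWS = 25650  # same row count as the lerobot/pusht train split above

# Dummy columnar data (NOT the real dataset values)
columns = {
    "observation.state": [[float(i), float(i)] for i in range(NUM_ROWS)],
    "episode_index": [0] * NUM_ROWS,
}


def row_then_column():
    # Build one row dict (one element per column), then pick a field
    row = {name: values[0] for name, values in columns.items()}
    return row["observation.state"]


def column_then_row():
    # Copy the whole column first (stand-in for materializing it), then index
    column = list(columns["observation.state"])
    return column[0]


assert row_then_column() == column_then_row() == [0.0, 0.0]
print("row first:   ", timeit.timeit(row_then_column, number=100))
print("column first:", timeit.timeit(column_then_row, number=100))
```

The row-first function does a constant amount of work per call, while the column-first one grows with the number of rows.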
Slicing also works; the result is a dict that maps each column name to a list of values
>>> dataset[:3]
{'observation.state': [[222.0, 97.0], [225.2523956298828, 89.31253051757812], [227.5923309326172, 84.53437805175781]], 'action': [[233.0, 71.0], [229.0, 83.0], [229.0, 86.0]], 'episode_index': [0, 0, 0], 'frame_index': [0, 1, 2], 'timestamp': [0.0, 0.10000000149011612, 0.20000000298023224], 'next.reward': [0.19029748439788818, 0.19029748439788818, 0.19029748439788818], 'next.done': [False, False, False], 'next.success': [False, False, False], 'index': [0, 1, 2], 'task_index': [0, 0, 0]}
>>> dataset[4:9]
{'observation.state': [[229.04222106933594, 84.95709991455078], [232.16236877441406, 86.34400177001953], [238.84613037109375, 89.35484313964844], [248.09005737304688, 94.0600814819336], [258.14642333984375, 99.59150695800781]], 'action': [[239.0, 89.0], [251.0, 95.0], [263.0, 102.0], [273.0, 108.0], [283.0, 116.0]], 'episode_index': [0, 0, 0, 0, 0], 'frame_index': [4, 5, 6, 7, 8], 'timestamp': [0.4000000059604645, 0.5, 0.6000000238418579, 0.699999988079071, 0.800000011920929], 'next.reward': [0.19029748439788818, 0.19029748439788818, 0.19029748439788818, 0.19029748439788818, 0.19029748439788818], 'next.done': [False, False, False, False, False], 'next.success': [False, False, False, False, False], 'index': [4, 5, 6, 7, 8], 'task_index': [0, 0, 0, 0, 0]}
IterableDataset
For larger datasets, passing streaming=True lets you iterate over the data without waiting for the entire dataset to download
>>> from datasets import load_dataset
>>> iterable_dataset = load_dataset("lerobot/pusht", split="train", streaming=True)
Resolving data files: 100%|████████████████| 206/206 [00:00<00:00, 82856.41it/s]
Resolving data files: 100%|████████████████| 206/206 [00:00<00:00, 62187.03it/s]
Access goes through an iterator
>>> for example in iterable_dataset:
...     print(example)
...     break
{'observation.state': [222.0, 97.0], 'action': [233.0, 71.0], 'episode_index': 0, 'frame_index': 0, 'timestamp': 0.0, 'next.reward': 0.19029748439788818, 'next.done': False, 'next.success': False, 'index': 0, 'task_index': 0}
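Why streaming need not wait for the full download can be illustrated with a generator: data is produced only when the loop asks for it. This is a toy model with hypothetical load_shard/stream functions, assuming the data is split into files (the log above resolves 206 of them); it is not how datasets actually fetches data.

```python
from collections.abc import Iterator


def load_shard(shard_id: int, loaded: list[int]) -> list[dict]:
    # Hypothetical stand-in for downloading and reading one data file
    loaded.append(shard_id)
    return [{"shard": shard_id, "row": r} for r in range(3)]


def stream(num_shards: int, loaded: list[int]) -> Iterator[dict]:
    for shard_id in range(num_shards):
        # A shard is fetched only when the consumer actually reaches it
        yield from load_shard(shard_id, loaded)


loaded: list[int] = []
for example in stream(206, loaded):  # 206 = number of data files resolved above
    print(example)  # {'shard': 0, 'row': 0}
    break
print(loaded)  # [0]  -> only the first shard was ever loaded
```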
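A Dataset that has already been downloaded can also be converted into an IterableDataset. This is faster than streaming, since it reads the local copy instead of going over the network.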
>>> dataset = load_dataset("lerobot/pusht", split="train")
Resolving data files: 100%|███████████████| 206/206 [00:00<00:00, 104099.59it/s]
Resolving data files: 100%|███████████████| 206/206 [00:00<00:00, 114002.72it/s]
>>> dataset.to_iterable_dataset()
IterableDataset({
    features: ['observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.reward', 'next.done', 'next.success', 'index', 'task_index'],
    num_shards: 1
})
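The num_shards: 1 in the output suggests the converted data is read as a sequence of chunks. A toy iteration-only view over rows already in memory can be sketched as follows; ToyIterable and its sharding scheme (contiguous chunks read in order) are hypothetical assumptions, not the datasets implementation.

```python
from collections.abc import Iterator


class ToyIterable:
    """Hypothetical iteration-only view over rows already held in memory."""

    def __init__(self, rows: list[dict], num_shards: int = 1) -> None:
        # Split rows into contiguous chunks; num_shards=1 matches the output above
        size = max(1, -(-len(rows) // num_shards))  # ceiling division
        self.shards = [rows[i:i + size] for i in range(0, len(rows), size)]

    def __iter__(self) -> Iterator[dict]:
        # Walk the shards in order, so row order is preserved
        for shard in self.shards:
            yield from shard


rows = [{"index": i} for i in range(5)]
view = ToyIterable(rows, num_shards=2)
print([len(s) for s in view.shards])  # [3, 2]
print(list(view) == rows)             # True
```

Since the toy view defines only __iter__, it can be read by iterating but not by index, matching the iterator-based access described in this post.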
>>> iterable_dataset = dataset.to_iterable_dataset()
Individual records are accessed through an iterator
>>> next(iter(iterable_dataset))
{'observation.state': [222.0, 97.0], 'action': [233.0, 71.0], 'episode_index': 0, 'frame_index': 0, 'timestamp': 0.0, 'next.reward': 0.19029748439788818, 'next.done': False, 'next.success': False, 'index': 0, 'task_index': 0}
>>> for example in iterable_dataset:
...     print(example)
...
{'observation.state': [81.86576080322266, 380.5188903808594], 'action': [60.0, 397.0], 'episode_index': 22, 'frame_index': 59, 'timestamp': 5.900000095367432, 'next.reward': 0.1000174880027771, 'next.done': False, 'next.success': False, 'index': 2914, 'task_index': 0}
{'observation.state': [73.93565368652344, 386.9354248046875], 'action': [51.0, 415.0], 'episode_index': 22, 'frame_index': 60, 'timestamp': 6.0, 'next.reward': 0.1000174880027771, 'next.done': False, 'next.success': False, 'index': 2915, 'task_index': 0}
{'observation.state': [64.71510314941406, 397.06890869140625], 'action': [53.0, 437.0], 'episode_index': 22, 'frame_index': 61, 'timestamp': 6.099999904632568, 'next.reward': 0.1000174880027771, 'next.done': False, 'next.success': False, 'index': 2916, 'task_index': 0}
{'observation.state': [58.69871139526367, 412.0063171386719], 'action': [72.0, 462.0], 'episode_index': 22, 'frame_index': 62, 'timestamp': 6.199999809265137, 'next.reward': 0.1000174880027771, 'next.done': False, 'next.success': False, 'index': 2917, 'task_index': 0}
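As the outputs suggest, the for loop starts again from the first example: each iter() call on the dataset creates a new iterator, so the earlier next(iter(...)) does not advance the loop (the lines starting at index 2914 appear to be only an excerpt of the full output). A hypothetical plain-Python class showing this iterator behavior:

```python
from collections.abc import Iterator


class ToyIterable:
    """Hypothetical re-iterable container (plain Python, not datasets)."""

    def __init__(self, rows: list[dict]) -> None:
        self.rows = rows

    def __iter__(self) -> Iterator[dict]:
        # Every call returns a brand-new iterator positioned at the first row
        return iter(self.rows)


ds = ToyIterable([{"index": 0}, {"index": 1}, {"index": 2}])
print(next(iter(ds)))  # {'index': 0}
print(next(iter(ds)))  # {'index': 0}  (a fresh iterator starts over)

it = iter(ds)          # keep one iterator to advance through the rows
next(it)
print(next(it))        # {'index': 1}
```

To step through records one at a time, keep a single iterator object and call next() on it repeatedly.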
take also works; it returns a new IterableDataset that yields only the first n examples
>>> list(iterable_dataset.take(5))
[{'observation.state': [222.0, 97.0], 'action': [233.0, 71.0], 'episode_index': 0, 'frame_index': 0, 'timestamp': 0.0, 'next.reward': 0.19029748439788818, 'next.done': False, 'next.success': False, 'index': 0, 'task_index': 0}, {'observation.state': [225.2523956298828, 89.31253051757812], 'action': [229.0, 83.0], 'episode_index': 0, 'frame_index': 1, 'timestamp': 0.10000000149011612, 'next.reward': 0.19029748439788818, 'next.done': False, 'next.success': False, 'index': 1, 'task_index': 0}, {'observation.state': [227.5923309326172, 84.53437805175781], 'action': [229.0, 86.0], 'episode_index': 0, 'frame_index': 2, 'timestamp': 0.20000000298023224, 'next.reward': 0.19029748439788818, 'next.done': False, 'next.success': False, 'index': 2, 'task_index': 0}, {'observation.state': [228.420166015625, 84.27986145019531], 'action': [230.0, 86.0], 'episode_index': 0, 'frame_index': 3, 'timestamp': 0.30000001192092896, 'next.reward': 0.19029748439788818, 'next.done': False, 'next.success': False, 'index': 3, 'task_index': 0}, {'observation.state': [229.04222106933594, 84.95709991455078], 'action': [239.0, 89.0], 'episode_index': 0, 'frame_index': 4, 'timestamp': 0.4000000059604645, 'next.reward': 0.19029748439788818, 'next.done': False, 'next.success': False, 'index': 4, 'task_index': 0}]
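take(n) wraps the original stream and stops after n examples. The same idea can be sketched in plain Python with itertools.islice; ToyTake is a hypothetical name, not the datasets implementation.

```python
from collections.abc import Iterable, Iterator
from itertools import count, islice


class ToyTake:
    """Hypothetical sketch of take(n): a new iterable yielding at most n items."""

    def __init__(self, source: Iterable, n: int) -> None:
        self.source = source
        self.n = n

    def __iter__(self) -> Iterator:
        # islice stops after n items and never pulls the rest of the source
        return islice(iter(self.source), self.n)


rows = [{"index": i} for i in range(10)]
first5 = ToyTake(rows, 5)
print([r["index"] for r in first5])  # [0, 1, 2, 3, 4]
print([r["index"] for r in first5])  # [0, 1, 2, 3, 4]  (re-iterable)
print(list(ToyTake(count(), 3)))     # [0, 1, 2]  (works even on an infinite source)
```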
That's all.