前回:
今回はLeRobot Datasetについて学ぶ
使用するデータセット
準備
python
>>> from pprint import pprint
>>>
>>> import torch
>>> from huggingface_hub import HfApi
>>>
>>> import lerobot
>>> from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
実践
Datasetを探す
多くのdatasetが提供されている
>>> pprint(lerobot.available_datasets)
['lerobot/aloha_mobile_cabinet',
'lerobot/aloha_mobile_chair',
...
datasetはhugging face api経由でも確認できる
>>> hub_api = HfApi()
>>> repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
>>> pprint(repo_ids)
web siteでも確認できる
Meta Data
実際にMeta dataをダウンロードしてみる
>>> repo_id = "lerobot/utokyo_pr2_opening_fridge"
>>> ds_meta = LeRobotDatasetMetadata(repo_id)
meta/episodes.jsonl: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5.59k/5.59k [00:00<00:00, 17.1MB/s]
meta/tasks.jsonl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48.0/48.0 [00:00<00:00, 180kB/s]
meta/stats.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.86k/4.86k [00:00<00:00, 14.4MB/s]
meta/info.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3.08k/3.08k [00:00<00:00, 11.4MB/s]
Fetching 4 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 4.50it/s]
meta dataのみをダウンロードすると,実際のDatasetをダウンロードすることなく中身の情報を確認できる
>>> print(f"Total number of episodes: {ds_meta.total_episodes}")
Total number of episodes: 80
>>> print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
Average number of frames per episode: 144.025
>>> print(f"Frames per second used during data collection: {ds_meta.fps}")
Frames per second used during data collection: 5
>>> print(f"Robot type: {ds_meta.robot_type}")
Robot type: unknown
>>> print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
keys to access images from cameras: ds_meta.camera_keys=['observation.images.image']
Task設定やfeatureも確認できる(今回はタスク設定がなかった)
>>> print("Tasks:")
Tasks:
>>> print(ds_meta.tasks)
{0: 'opening the fridge'}
>>> print("Features:")
Features:
>>> pprint(ds_meta.features)
{'action': {'dtype': 'float32',
'names': {'motors': ['motor_0',
'motor_1',
'motor_2',
'motor_3',
'motor_4',
'motor_5',
'motor_6',
'motor_7']},
'shape': (8,)},
'episode_index': {'dtype': 'int64', 'names': None, 'shape': (1,)},
'frame_index': {'dtype': 'int64', 'names': None, 'shape': (1,)},
'index': {'dtype': 'int64', 'names': None, 'shape': (1,)},
'language_instruction': {'dtype': 'string', 'names': None, 'shape': (1,)},
'next.done': {'dtype': 'bool', 'names': None, 'shape': (1,)},
'next.reward': {'dtype': 'float32', 'names': None, 'shape': (1,)},
'observation.images.image': {'dtype': 'video',
'names': ['height', 'width', 'channel'],
'shape': (128, 128, 3),
'video_info': {'has_audio': False,
'video.codec': 'av1',
'video.fps': 5.0,
'video.is_depth_map': False,
'video.pix_fmt': 'yuv420p'}},
'observation.state': {'dtype': 'float32',
'names': {'motors': ['motor_0',
'motor_1',
'motor_2',
'motor_3',
'motor_4',
'motor_5',
'motor_6']},
'shape': (7,)},
'task_index': {'dtype': 'int64', 'names': None, 'shape': (1,)},
'timestamp': {'dtype': 'float32', 'names': None, 'shape': (1,)}}
meta data全体の確認
>>> print(ds_meta)
LeRobotDatasetMetadata({
Repository ID: 'lerobot/utokyo_pr2_opening_fridge',
Total episodes: '80',
Total frames: '11522',
Features: '['observation.images.image', 'language_instruction', 'observation.state', 'action', 'timestamp', 'episode_index', 'frame_index', 'next.reward', 'next.done', 'index', 'task_index']',
})'
Download Dataset
実際に特定のepisodeをダウンロードしてみる
>>> dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
Fetching 4 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 626.27it/s]
episode_000023.parquet: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16.5k/16.5k [00:00<00:00, 38.3MB/s]
episode_000000.parquet: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15.5k/15.5k [00:00<00:00, 29.5MB/s]
episode_000011.parquet: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14.8k/14.8k [00:00<00:00, 37.1MB/s]
episode_000010.parquet: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17.0k/17.0k [00:00<00:00, 31.0MB/s]
episode_000000.mp4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 324k/324k [00:00<00:00, 11.0MB/s]
episode_000023.mp4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 345k/345k [00:00<00:00, 11.2MB/s]
episode_000011.mp4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 290k/290k [00:00<00:00, 9.61MB/s]
episode_000010.mp4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 365k/365k [00:00<00:00, 10.3MB/s]
Fetching 8 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 8.69it/s]
Generating train split: 514 examples [00:00, 87862.10 examples/s]
>>> print(f"Selected episodes: {dataset.episodes}")
Selected episodes: [0, 10, 11, 23]
>>> print(f"Number of episodes selected: {dataset.num_episodes}")
Number of episodes selected: 4
>>> print(f"Number of frames selected: {dataset.num_frames}")
Number of frames selected: 514
全体をダウンロードすることもできる
>>> dataset = LeRobotDataset(repo_id)
Fetching 4 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 1080.73it/s]
README.md: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3.75k/3.75k [00:00<00:00, 12.9MB/s]
.gitattributes: 100%|
...
Generating train split: 11522 examples [00:00, 162350.70 examples/s]
>>> print(f"Number of episodes selected: {dataset.num_episodes}")
Number of episodes selected: 80
>>> print(f"Number of frames selected: {dataset.num_frames}")
Number of frames selected: 11522
meta dataの確認
>>> print(dataset.meta)
LeRobotDatasetMetadata({
Repository ID: 'lerobot/utokyo_pr2_opening_fridge',
Total episodes: '80',
Total frames: '11522',
Features: '['observation.images.image', 'language_instruction', 'observation.state', 'action', 'timestamp', 'episode_index', 'frame_index', 'next.reward', 'next.done', 'index', 'task_index']',
})',
実はhugging face datasetの型を踏襲しているので,それを確認することもできる
>>> print(dataset.hf_dataset)
Dataset({
features: ['observation.state', 'action', 'timestamp', 'episode_index', 'frame_index', 'next.reward', 'next.done', 'index', 'task_index'],
num_rows: 11522
})
実はPytorchのデータセットも踏襲しているので,そのやり方に沿って特定のepisodeを取り出すことも可能
1つ目のepisodeのframeのindexを取り出してみる
>>> episode_index = 0
>>> from_idx = dataset.episode_data_index["from"][episode_index].item()
>>> to_idx = dataset.episode_data_index["to"][episode_index].item()
取り出したindexを使ってepisode 1のcamera frameをすべて取得する
>>> camera_key = dataset.meta.camera_keys[0]
>>> frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
各要素はpytorchのTensorの型を持つ
>>> print(type(frames[0]))
<class 'torch.Tensor'>
>>> print(frames[0].shape)
torch.Size([3, 128, 128]) # Tensorなので(channel, height, width)
このshapeはdatasetから取得できる画像のshapeと異なるので注意
>>> pprint(dataset.features[camera_key])
{'dtype': 'video',
'names': ['height', 'width', 'channel'],
'shape': (128, 128, 3),
'video_info': {'has_audio': False,
'video.codec': 'av1',
'video.fps': 5.0,
'video.is_depth_map': False,
'video.pix_fmt': 'yuv420p'}}
>>> print(dataset.features[camera_key]["shape"])
(128, 128, 3) # こちらは(height, width, channel)
How to use
time stampを使って特定のデータを抽出することが可能(但し、timestampの感覚は1/fpsの倍数になっている必要あり)
>>> # fpsの確認
>>> dataset.fps
5
>>> delta_timestamps = {
... # 4つの画像をload : 1 s前, 600 ms前, 200 ms前, 現在の画像
... camera_key: [-1, -0.6, -0.20, 0],
... # 8つの状態ベクトルをload: 1.6 s前, 1 s前, ... 400 ms前, 200 ms前, 現在
... "observation.state": [-1.6, -1, -0.6, 0.4, -0.20, 0],
... # 64個の行動ベクトルをload: 現在, 1 frame先, 2 frames, ... 63 frames先
... "action": [t / dataset.fps for t in range(64)],
... }
delta time stampを使用してDatasetをloadできる
>>> dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
Fetching 4 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 2874.29it/s]
Fetching 166 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 166/166 [00:00<00:00, 726.77it/s]
Resolving data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 597053.95it/s]
>>> print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
dataset[0][camera_key].shape=torch.Size([4, 3, 128, 128])
>>> print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
dataset[0]['observation.state'].shape=torch.Size([6, 7])
>>> print(f"{dataset[0]['action'].shape=}\n") # (64, c)
dataset[0]['action'].shape=torch.Size([64, 8])
Pytorch DataLoaderに移植
>>> for batch in dataloader:
... print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
... print(f"{batch['observation.state'].shape=}") # (32, 5, c)
... print(f"{batch['action'].shape=}") # (32, 64, c)
... break
...
batch[camera_key].shape=torch.Size([32, 4, 3, 128, 128])
batch['observation.state'].shape=torch.Size([32, 6, 7])
batch['action'].shape=torch.Size([32, 64, 8])
以上
目次