import h5py
import torch
from torch.utils.data import DataLoader, Dataset
class H5Dataset(Dataset):
    """Map-style dataset that serves an HDF5 file in chunks of consecutive rows.

    Each item is one chunk of up to ``preload_factor`` rows sliced from the
    ``'features'`` and ``'labels'`` datasets of ``h5_file``; the last chunk
    may be shorter. The file is opened anew on every ``__len__`` /
    ``__getitem__`` call, so the dataset object itself holds no open handle.

    Args:
        h5_file: Path to the HDF5 file.
        preload_factor: Number of rows per chunk; must be >= 1.
        transform: Optional callable applied to each chunk's features before
            they are converted to a tensor.

    Raises:
        ValueError: If ``preload_factor`` is less than 1.
    """

    def __init__(self, h5_file, preload_factor, transform=None):
        if preload_factor < 1:
            # A non-positive chunk size would make __len__ divide by zero
            # (or report a nonsensical length) and yield empty chunks.
            raise ValueError(
                f"preload_factor must be >= 1, got {preload_factor!r}"
            )
        self.h5_file = h5_file
        self.preload_factor = preload_factor
        self.transform = transform

    def _num_chunks(self, num_rows):
        """Return how many chunks are needed to cover ``num_rows`` rows."""
        # Ceiling division; 0 rows -> 0 chunks.
        return (num_rows + self.preload_factor - 1) // self.preload_factor

    def _read_chunk(self, f, idx):
        """Read chunk ``idx`` from an already-open file-like mapping ``f``.

        ``f`` only needs ``f['features']`` / ``f['labels']`` supporting
        ``len()`` and slicing, so an open ``h5py.File`` or a dict of arrays
        both work.

        Returns:
            ``(features, labels)`` as tensors.

        Raises:
            IndexError: If ``idx`` is outside ``[-num_chunks, num_chunks)``.
        """
        # NOTE(review): the chunk count is derived from 'features' only; this
        # assumes 'labels' has the same number of rows — confirm for your data.
        num_chunks = self._num_chunks(len(f['features']))
        if idx < 0:
            idx += num_chunks
        if not 0 <= idx < num_chunks:
            raise IndexError(
                f"chunk index out of range (dataset has {num_chunks} chunks)"
            )
        start = idx * self.preload_factor
        end = start + self.preload_factor
        features = f['features'][start:end]
        labels = f['labels'][start:end]
        if self.transform is not None:
            features = self.transform(features)
        # as_tensor avoids an extra copy (and torch.tensor's warning) when the
        # transform already returned a tensor.
        return torch.as_tensor(features), torch.as_tensor(labels)

    def __len__(self):
        """Return the number of chunks in the file."""
        with h5py.File(self.h5_file, 'r') as f:
            return self._num_chunks(len(f['features']))

    def __getitem__(self, idx):
        """Return chunk ``idx`` as a ``(features, labels)`` tensor pair."""
        with h5py.File(self.h5_file, 'r') as f:
            return self._read_chunk(f, idx)
def collate_fn(batch):
    """Unwrap a DataLoader batch that holds exactly one dataset item.

    Intended for a DataLoader built with ``batch_size=1`` over a dataset whose
    items are already whole chunks: the single item is returned unchanged
    instead of being stacked by the default collate function.

    Args:
        batch: The list of items the DataLoader collected for one batch.

    Returns:
        The sole item of ``batch``.

    Raises:
        ValueError: If ``batch`` does not contain exactly one item; returning
            ``batch[0]`` in that case would silently drop the other items.
    """
    if len(batch) != 1:
        raise ValueError(
            "collate_fn expects batches of exactly one item "
            f"(use DataLoader(batch_size=1)), got {len(batch)}"
        )
    return batch[0]
# Example driver: stream the HDF5 file one preloaded chunk at a time and walk
# each chunk in fixed-size mini-batches.
h5_file = 'your_h5_file_path.h5'  # placeholder path — replace with a real file
batch_size = 1000  # rows per mini-batch inside each preloaded chunk
preload_factor = 1000000  # rows read per DataLoader item
dataset = H5Dataset(h5_file, preload_factor=preload_factor)
# batch_size=1 because each dataset item is already a whole chunk;
# collate_fn returns that single item instead of stacking it.
dataloader = DataLoader(
    dataset, batch_size=1, shuffle=False, collate_fn=collate_fn
)

if __name__ == "__main__":
    # Guarded so that merely importing this module (e.g. by DataLoader workers
    # if num_workers > 0 is used with the "spawn" start method) does not run
    # the training loop.
    for features, labels in dataloader:
        # Each item is a (features, labels) pair, so unpack it instead of
        # calling .shape on the tuple, then slice both in lockstep.
        for i in range(0, features.shape[0], batch_size):
            mini_batch_features = features[i:i + batch_size]
            mini_batch_labels = labels[i:i + batch_size]
            # Train the deep-learning model on this mini-batch here.