この記事は
PyTorchで動画のデータセットを扱えるようにします。
PyTorchの Dataset クラスは、データセットをミニバッチで学習するときに前処理やシャッフル、ミニバッチ化をやってくれる便利なクラスです。PyTorchで提供されているデータセットを利用するときは、 torchvision.datasets から直接 Dataset を継承したクラスのインスタンスという形で取得できます。独自のデータセットを準備している場合も、画像データであれば torchvision.datasets.ImageFolder クラスを利用することで気軽に Dataset クラスのインスタンスを取得できます。しかし動画のデータセットの場合は、Datasetクラスを直接利用することはできず、クラスを継承してカスタマイズが必要です。この記事では、独自に用意した動画データに対応したDatasetクラスのカスタマイズ方法を説明します。
Dataset クラスをカスタマイズする
Dataset クラスを継承して、動画データを扱えるようにする VideoDataset クラスを定義します。まずはDatasetクラスの関数うち、オーバーライドが必要な関数の概要を説明します。その次にオーバーライドの実装方法を説明します。
オーバーライドする関数の概要
Dataset クラスからオーバーライドする関数は __init__, __len__, __getitem__ の3つです。ここではオーバーライドした関数でどんな処理が必要なのか概略を説明して、次のセクションで実際のコードを記述します。
class VideoDataset(Dataset):
def __init__(self):
"""
必要なデータを取得するなど初期化処理を行います。
"""
def __len__(self):
"""
データセットの総数を返します。
"""
def __getitem__(self, index):
"""
indexで指定された位置の動画データをtorch.tensor形式で返します。
"""
__init__はインスタンスの初期化処理を行います。__len__は保持しているデータの総数を返します。__getitem__は、リスト・タプル・辞書のような、いわゆるコンテナ型のクラスで使用される特殊関数です。DatasetクラスはDataLoaderクラスに渡されてデータが1つずつ取り出されます。__getitem__をオーバーライドすることでDataLoaderクラスの処理に対応することができます。
オーバーライドする関数
それでは、VideoDataSet を以下のように実装します。
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
class VideoDataset(Dataset):
def __init__(self, paths, labels, clip_length):
# 動画ファイルパスのリスト
self.paths = paths
# 正解ラベルのリストをテンソルに変換
self.labels = [int(label) for label in labels]
self.labels = torch.tensor(self.labels)
# サンプリングするフレームクリップ数
self.clip_length = clip_length
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
label = self.labels[index]
# OpenCVで動画ファイルを開く
cap = cv2.VideoCapture(path)
# 動画の全フレーム数を取得
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# 動画の全フレーム数からclip_lengthの数だけ均等なインデックスを取得
indices = np.linspace(1, frame_count, num=self.clip_length, dtype=int)
inputs = []
count = 1
while True:
# フレームを順番に取得
ret, frame = cap.read()
# 正しく取得できればretにTrueが返される
if ret:
# 均等に取得したindicesリスト内のインデックスのときだけフレームを保存
if count in indices:
inputs.append(frame)
else:
break
count += 1
# 取得したフレームのリストをテンソルに変換
inputs = torch.tensor(inputs)
return inputs, label
次に、個別の関数について見ていきます。
__init__関数
__init__関数はシンプルで、引数をインスタンス変数として代入しているだけです。paths, labelsはそれぞれ後述の「ファイル名からラベルを作成」で作成した動画ファイルパスとラベルのリストを渡します。その際、ラベルのリストはPyTorchが読み込めるように 文字列リスト → intリスト → テンソルリスト に変換しています。clip_lengthは、動画全体のフレームから何フレームをサンプリングするかのフレームの数です。この数は動画認識を行う機械学習モデルの入力次元に依存します。
def __init__(self, paths, labels, clip_length):
# 動画ファイルパスのリスト
self.paths = paths
# 正解ラベルのリストをテンソルに変換
self.labels = [int(label) for label in labels]
self.labels = torch.tensor(self.labels)
# サンプリングするフレームクリップ数
self.clip_length = clip_length
__len__関数
続いて__len__関数もシンプルで、データセットの全数を返しています。
def __len__(self):
return len(self.paths)
__getitem__関数
__getitem__関数は引数indexに対応するデータを1つ返せばいいだけですが、学習対象の機械学習モデルの入力次元と形式に合わせて動画を加工する必要があります。
def __getitem__(self, index):
path = self.paths[index]
label = self.labels[index]
# (1) OpenCVで動画ファイルを開く
cap = cv2.VideoCapture(path)
# (2) 動画の全フレーム数を取得
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# (3) 動画の全フレーム数からclip_lengthの数だけ均等なインデックスを取得
indices = np.linspace(1, frame_count, num=self.clip_length, dtype=int)
inputs = []
count = 1
while True:
# (4) フレームを順番に取得
ret, frame = cap.read()
# 正しく取得できればretにTrueが返される
if ret:
# (5) カウントが均等に取得したindicesリスト内と一致するときだけフレームを保存
if count in indices:
inputs.append(frame)
else:
break
count += 1
# (6) 取得したフレームのリストをテンソルに変換
inputs = torch.tensor(inputs)
return inputs, label
まず(1)はOpenCVで動画ファイルを開き、(2)は開いた動画の全フレーム数を取得しています。(3)では、1~frame_countのフレームインデックスから、均等な間隔でself.clip_length個のインデックスを抜き出してリストindicesに代入しています。(4)のcap.read()はOpenCVの機能で、関数が成功するとretにTrue、frameに現在の位置のフレームがそれぞれ返されます。(5)では、取り出しのカウントとindicesが一致するときにそのフレームをリストinputsへ保管し、(6)でテンソルに変換してからラベルとともに関数の戻り値として返されます。
VideoDatasetを使用する
作成したVideoDatasetを使用する方法は、以下に示すようにPyTorchでの通常のデータセットの使用方法と同じです。
データセットの準備
動画データを準備します。動画は10種類(10ラベル)の動作に分類できる全3,200の動画ファイルを用意したとします。以下のような規則でファイル名を付けます。
./dataset/AA_BBB.mp4
AA は動画の分類ラベル(01, 02, ..., 10)、BBB はその分類ラベルでの動画のインデックス(001, 002, 003...)です。
ファイル名からラベルを作成
動画ファイルが保存されている./datasetから、全ての動画ファイル名を読み取って
- 動画ファイルパスのリスト
- 正解ラベルのリスト
を作成します。以下のプログラムを実行すると、
-
video_pathsに動画ファイルパスのリスト -
labelsに正解ラベルのリスト
が作成されます。
import os
from pathlib2 import Path
video_dir = './dataset'
video_dir_path = Path(video_dir)
paths = []
labels = []
for video_file in video_dir_path.glob('*.mp4'):
file_name = video_file.stem
label_id = file_name[:2]
full_path = os.path.join(video_dir, video_file.name)
paths.append(full_path)
labels.append(label_id)
データセットのロード
準備ができたところで、VideoDatasetを利用してデータセットをロードします。video_pathsとlabelsは先ほど作成したもの、clip_lengthは機械学習モデルの入力に拠ります。
data_set = VideoDataset(video_paths, labels, clip_length=32)
print(f'{len(data_set)=}')
# 出力
# len(data_set)=3200
出力は用意したデータセットのサイズと一致していて、正しくロードできています。
データローダの取得
作成したVideoDatasetクラスのインスタンスで、データローダを作成します。そのあと、最初のミニバッチの分だけ入力データを取得して、シェイプとラベルを表示しています。
train_loader = DataLoader(data_set, batch_size=10, shuffle=True)
for inputs, labels in train_loader:
print(f'{inputs.shape=}')
print(f'labels=')
break
# 出力
# inputs.shape=torch.Size([10, 32, 224, 224, 3])
# labels=tensor([3, 4, 6, 9, 5, 1, 3, 2, 8, 2])
得られた出力のうち
-
inputs.shapeの10はミニバッチのサイズ、32はフレーム数、224, 224はフレームのサイズ、3はフレームのチャネル数 -
labelsはミニバッチのサイズ分のラベルのリスト
となっていて、正しくデータが読み込まれています。