この記事は
PyTorchで動画のデータセットを扱えるようにします。
PyTorchの Dataset
クラスは、データセットをミニバッチで学習するときに前処理やシャッフル、ミニバッチ化をやってくれる便利なクラスです。PyTorchで提供されているデータセットを利用するときは、 torchvision.datasets
から直接 Dataset
を継承したクラスのインスタンスという形で取得できます。独自のデータセットを準備している場合も、画像データであれば torchvision.datasets.ImageFolder
クラスを利用することで気軽に Dataset
クラスのインスタンスを取得できます。しかし動画のデータセットの場合は、Dataset
クラスを直接利用することはできず、クラスを継承してカスタマイズが必要です。この記事では、独自に用意した動画データに対応したDataset
クラスのカスタマイズ方法を説明します。
Dataset
クラスをカスタマイズする
Dataset
クラスを継承して、動画データを扱えるようにする VideoDataset
クラスを定義します。まずはDataset
クラスの関数うち、オーバーライドが必要な関数の概要を説明します。その次にオーバーライドの実装方法を説明します。
オーバーライドする関数の概要
Dataset
クラスからオーバーライドする関数は __init__
, __len__
, __getitem__
の3つです。ここではオーバーライドした関数でどんな処理が必要なのか概略を説明して、次のセクションで実際のコードを記述します。
class VideoDataset(Dataset):
def __init__(self):
"""
必要なデータを取得するなど初期化処理を行います。
"""
def __len__(self):
"""
データセットの総数を返します。
"""
def __getitem__(self, index):
"""
indexで指定された位置の動画データをtorch.tensor形式で返します。
"""
__init__
はインスタンスの初期化処理を行います。__len__
は保持しているデータの総数を返します。__getitem__
は、リスト・タプル・辞書のような、いわゆるコンテナ型のクラスで使用される特殊関数です。Dataset
クラスはDataLoader
クラスに渡されてデータが1つずつ取り出されます。__getitem__
をオーバーライドすることでDataLoader
クラスの処理に対応することができます。
オーバーライドする関数
それでは、VideoDataSet
を以下のように実装します。
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
class VideoDataset(Dataset):
def __init__(self, paths, labels, clip_length):
# 動画ファイルパスのリスト
self.paths = paths
# 正解ラベルのリストをテンソルに変換
self.labels = [int(label) for label in labels]
self.labels = torch.tensor(self.labels)
# サンプリングするフレームクリップ数
self.clip_length = clip_length
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
label = self.labels[index]
# OpenCVで動画ファイルを開く
cap = cv2.VideoCapture(path)
# 動画の全フレーム数を取得
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# 動画の全フレーム数からclip_lengthの数だけ均等なインデックスを取得
indices = np.linspace(1, frame_count, num=self.clip_length, dtype=int)
inputs = []
count = 1
while True:
# フレームを順番に取得
ret, frame = cap.read()
# 正しく取得できればretにTrueが返される
if ret:
# 均等に取得したindicesリスト内のインデックスのときだけフレームを保存
if count in indices:
inputs.append(frame)
else:
break
count += 1
# 取得したフレームのリストをテンソルに変換
inputs = torch.tensor(inputs)
return inputs, label
次に、個別の関数について見ていきます。
__init__
関数
__init__
関数はシンプルで、引数をインスタンス変数として代入しているだけです。paths
, labels
はそれぞれ後述の「ファイル名からラベルを作成」で作成した動画ファイルパスとラベルのリストを渡します。その際、ラベルのリストはPyTorchが読み込めるように 文字列リスト → intリスト → テンソルリスト に変換しています。clip_length
は、動画全体のフレームから何フレームをサンプリングするかのフレームの数です。この数は動画認識を行う機械学習モデルの入力次元に依存します。
def __init__(self, paths, labels, clip_length):
# 動画ファイルパスのリスト
self.paths = paths
# 正解ラベルのリストをテンソルに変換
self.labels = [int(label) for label in labels]
self.labels = torch.tensor(self.labels)
# サンプリングするフレームクリップ数
self.clip_length = clip_length
__len__
関数
続いて__len__
関数もシンプルで、データセットの全数を返しています。
def __len__(self):
return len(self.paths)
__getitem__
関数
__getitem__
関数は引数index
に対応するデータを1つ返せばいいだけですが、学習対象の機械学習モデルの入力次元と形式に合わせて動画を加工する必要があります。
def __getitem__(self, index):
path = self.paths[index]
label = self.labels[index]
# (1) OpenCVで動画ファイルを開く
cap = cv2.VideoCapture(path)
# (2) 動画の全フレーム数を取得
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# (3) 動画の全フレーム数からclip_lengthの数だけ均等なインデックスを取得
indices = np.linspace(1, frame_count, num=self.clip_length, dtype=int)
inputs = []
count = 1
while True:
# (4) フレームを順番に取得
ret, frame = cap.read()
# 正しく取得できればretにTrueが返される
if ret:
# (5) カウントが均等に取得したindicesリスト内と一致するときだけフレームを保存
if count in indices:
inputs.append(frame)
else:
break
count += 1
# (6) 取得したフレームのリストをテンソルに変換
inputs = torch.tensor(inputs)
return inputs, label
まず(1)はOpenCVで動画ファイルを開き、(2)は開いた動画の全フレーム数を取得しています。(3)では、1
~frame_count
のフレームインデックスから、均等な間隔でself.clip_length
個のインデックスを抜き出してリストindices
に代入しています。(4)のcap.read()
はOpenCVの機能で、関数が成功するとret
にTrue
、frame
に現在の位置のフレームがそれぞれ返されます。(5)では、取り出しのカウントとindices
が一致するときにそのフレームをリストinputs
へ保管し、(6)でテンソルに変換してからラベルとともに関数の戻り値として返されます。
VideoDataset
を使用する
作成したVideoDataset
を使用する方法は、以下に示すようにPyTorchでの通常のデータセットの使用方法と同じです。
データセットの準備
動画データを準備します。動画は10種類(10ラベル)の動作に分類できる全3,200の動画ファイルを用意したとします。以下のような規則でファイル名を付けます。
./dataset/AA_BBB.mp4
AA
は動画の分類ラベル(01, 02, ..., 10)、BBB
はその分類ラベルでの動画のインデックス(001, 002, 003...)です。
ファイル名からラベルを作成
動画ファイルが保存されている./dataset
から、全ての動画ファイル名を読み取って
- 動画ファイルパスのリスト
- 正解ラベルのリスト
を作成します。以下のプログラムを実行すると、
-
video_paths
に動画ファイルパスのリスト -
labels
に正解ラベルのリスト
が作成されます。
import os
from pathlib2 import Path
video_dir = './dataset'
video_dir_path = Path(video_dir)
paths = []
labels = []
for video_file in video_dir_path.glob('*.mp4'):
file_name = video_file.stem
label_id = file_name[:2]
full_path = os.path.join(video_dir, video_file.name)
paths.append(full_path)
labels.append(label_id)
データセットのロード
準備ができたところで、VideoDataset
を利用してデータセットをロードします。video_paths
とlabels
は先ほど作成したもの、clip_length
は機械学習モデルの入力に拠ります。
data_set = VideoDataset(video_paths, labels, clip_length=32)
print(f'{len(data_set)=}')
# 出力
# len(data_set)=3200
出力は用意したデータセットのサイズと一致していて、正しくロードできています。
データローダの取得
作成したVideoDataset
クラスのインスタンスで、データローダを作成します。そのあと、最初のミニバッチの分だけ入力データを取得して、シェイプとラベルを表示しています。
train_loader = DataLoader(data_set, batch_size=10, shuffle=True)
for inputs, labels in train_loader:
print(f'{inputs.shape=}')
print(f'labels=')
break
# 出力
# inputs.shape=torch.Size([10, 32, 224, 224, 3])
# labels=tensor([3, 4, 6, 9, 5, 1, 3, 2, 8, 2])
得られた出力のうち
-
inputs.shape
の10
はミニバッチのサイズ、32
はフレーム数、224, 224
はフレームのサイズ、3
はフレームのチャネル数 -
labels
はミニバッチのサイズ分のラベルのリスト
となっていて、正しくデータが読み込まれています。