1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorchのDatasetで動画データセットを扱えるようにする

Last updated at Posted at 2024-10-30

この記事は

PyTorchで動画のデータセットを扱えるようにします。

PyTorchの Dataset クラスは、データセットをミニバッチで学習するときに前処理やシャッフル、ミニバッチ化をやってくれる便利なクラスです。PyTorchで提供されているデータセットを利用するときは、 torchvision.datasets から直接 Dataset を継承したクラスのインスタンスという形で取得できます。独自のデータセットを準備している場合も、画像データであれば torchvision.datasets.ImageFolder クラスを利用することで気軽に Dataset クラスのインスタンスを取得できます。しかし動画のデータセットの場合は、Datasetクラスを直接利用することはできず、クラスを継承してカスタマイズが必要です。この記事では、独自に用意した動画データに対応したDatasetクラスのカスタマイズ方法を説明します。

Dataset クラスをカスタマイズする

Dataset クラスを継承して、動画データを扱えるようにする VideoDataset クラスを定義します。まずはDatasetクラスの関数うち、オーバーライドが必要な関数の概要を説明します。その次にオーバーライドの実装方法を説明します。

オーバーライドする関数の概要

Dataset クラスからオーバーライドする関数は __init__, __len__, __getitem__ の3つです。ここではオーバーライドした関数でどんな処理が必要なのか概略を説明して、次のセクションで実際のコードを記述します。

Datasetを継承するVideoDatasetクラスの概略
class VideoDataset(Dataset):
    def __init__(self):
        """
        必要なデータを取得するなど初期化処理を行います。
        """

    def __len__(self):
        """
        データセットの総数を返します。
        """

    def __getitem__(self, index):
        """
        indexで指定された位置の動画データをtorch.tensor形式で返します。
        """

__init__はインスタンスの初期化処理を行います。__len__は保持しているデータの総数を返します。__getitem__は、リスト・タプル・辞書のような、いわゆるコンテナ型のクラスで使用される特殊関数です。DatasetクラスはDataLoaderクラスに渡されてデータが1つずつ取り出されます。__getitem__をオーバーライドすることでDataLoaderクラスの処理に対応することができます。

オーバーライドする関数

それでは、VideoDataSet を以下のように実装します。

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

class VideoDataset(Dataset):
    def __init__(self, paths, labels, clip_length):
        # 動画ファイルパスのリスト
        self.paths = paths
        # 正解ラベルのリストをテンソルに変換
        self.labels = [int(label) for label in labels]
        self.labels = torch.tensor(self.labels)
        # サンプリングするフレームクリップ数
        self.clip_length = clip_length

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        label = self.labels[index]

        # OpenCVで動画ファイルを開く
        cap = cv2.VideoCapture(path)
        # 動画の全フレーム数を取得
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # 動画の全フレーム数からclip_lengthの数だけ均等なインデックスを取得
        indices = np.linspace(1, frame_count, num=self.clip_length, dtype=int)

        inputs = []
        count = 1
        while True:
            # フレームを順番に取得
            ret, frame = cap.read()
            # 正しく取得できればretにTrueが返される
            if ret:
                # 均等に取得したindicesリスト内のインデックスのときだけフレームを保存
                if count in indices:
                    inputs.append(frame)
            else:
                break
            count += 1

        # 取得したフレームのリストをテンソルに変換
        inputs = torch.tensor(inputs)

        return inputs, label

次に、個別の関数について見ていきます。

__init__関数

__init__関数はシンプルで、引数をインスタンス変数として代入しているだけです。paths, labelsはそれぞれ後述の「ファイル名からラベルを作成」で作成した動画ファイルパスとラベルのリストを渡します。その際、ラベルのリストはPyTorchが読み込めるように 文字列リスト → intリスト → テンソルリスト に変換しています。clip_lengthは、動画全体のフレームから何フレームをサンプリングするかのフレームの数です。この数は動画認識を行う機械学習モデルの入力次元に依存します。

__init__関数
def __init__(self, paths, labels, clip_length):
    # 動画ファイルパスのリスト
    self.paths = paths
    # 正解ラベルのリストをテンソルに変換
    self.labels = [int(label) for label in labels]
    self.labels = torch.tensor(self.labels)
    # サンプリングするフレームクリップ数
    self.clip_length = clip_length

__len__関数

続いて__len__関数もシンプルで、データセットの全数を返しています。

__len__関数
def __len__(self):
    return len(self.paths)

__getitem__関数

__getitem__関数は引数indexに対応するデータを1つ返せばいいだけですが、学習対象の機械学習モデルの入力次元と形式に合わせて動画を加工する必要があります。

__getitem__関数
def __getitem__(self, index):
    path = self.paths[index]
    label = self.labels[index]

    # (1) OpenCVで動画ファイルを開く
    cap = cv2.VideoCapture(path)
    # (2) 動画の全フレーム数を取得
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # (3) 動画の全フレーム数からclip_lengthの数だけ均等なインデックスを取得
    indices = np.linspace(1, frame_count, num=self.clip_length, dtype=int)

    inputs = []
    count = 1
    while True:
        # (4) フレームを順番に取得
        ret, frame = cap.read()
        # 正しく取得できればretにTrueが返される
        if ret:
            # (5) カウントが均等に取得したindicesリスト内と一致するときだけフレームを保存
            if count in indices:
                inputs.append(frame)
        else:
            break
        count += 1

    # (6) 取得したフレームのリストをテンソルに変換
    inputs = torch.tensor(inputs)

    return inputs, label

まず(1)はOpenCVで動画ファイルを開き、(2)は開いた動画の全フレーム数を取得しています。(3)では、1frame_countのフレームインデックスから、均等な間隔でself.clip_length個のインデックスを抜き出してリストindicesに代入しています。(4)のcap.read()はOpenCVの機能で、関数が成功するとretTrueframeに現在の位置のフレームがそれぞれ返されます。(5)では、取り出しのカウントとindicesが一致するときにそのフレームをリストinputsへ保管し、(6)でテンソルに変換してからラベルとともに関数の戻り値として返されます。

VideoDatasetを使用する

作成したVideoDatasetを使用する方法は、以下に示すようにPyTorchでの通常のデータセットの使用方法と同じです。

データセットの準備

動画データを準備します。動画は10種類(10ラベル)の動作に分類できる全3,200の動画ファイルを用意したとします。以下のような規則でファイル名を付けます。

./dataset/AA_BBB.mp4

AA は動画の分類ラベル(01, 02, ..., 10)、BBB はその分類ラベルでの動画のインデックス(001, 002, 003...)です。

ファイル名からラベルを作成

動画ファイルが保存されている./datasetから、全ての動画ファイル名を読み取って

  • 動画ファイルパスのリスト
  • 正解ラベルのリスト

を作成します。以下のプログラムを実行すると、

  • video_pathsに動画ファイルパスのリスト
  • labels に正解ラベルのリスト

が作成されます。

ファイル名からデータラベルを作成
import os
from pathlib2 import Path

video_dir = './dataset'
video_dir_path = Path(video_dir)

paths = []
labels = []
for video_file in video_dir_path.glob('*.mp4'):
    file_name = video_file.stem
    label_id = file_name[:2]

    full_path = os.path.join(video_dir, video_file.name)
    paths.append(full_path)
    labels.append(label_id)

データセットのロード

準備ができたところで、VideoDatasetを利用してデータセットをロードします。video_pathslabelsは先ほど作成したもの、clip_lengthは機械学習モデルの入力に拠ります。

data_set = VideoDataset(video_paths, labels, clip_length=32)
print(f'{len(data_set)=}')

# 出力
# len(data_set)=3200

出力は用意したデータセットのサイズと一致していて、正しくロードできています。

データローダの取得

作成したVideoDatasetクラスのインスタンスで、データローダを作成します。そのあと、最初のミニバッチの分だけ入力データを取得して、シェイプとラベルを表示しています。

train_loader = DataLoader(data_set, batch_size=10, shuffle=True)

for inputs, labels in train_loader:
    print(f'{inputs.shape=}')
    print(f'labels=')
    break

# 出力
# inputs.shape=torch.Size([10, 32, 224, 224, 3])
# labels=tensor([3, 4, 6, 9, 5, 1, 3, 2, 8, 2])

得られた出力のうち

  • inputs.shape10はミニバッチのサイズ、32はフレーム数、224, 224はフレームのサイズ、3はフレームのチャネル数
  • labelsはミニバッチのサイズ分のラベルのリスト

となっていて、正しくデータが読み込まれています。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?