1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

日本音響学会 学生・若手フォーラムAdvent Calendar 2023

Day 13

ニューラルネットワークで可変長バッチを作りたい!! (PyTorch)

Last updated at Posted at 2023-12-12

音に関わってないやん!解散!

まあ待ってください。
人の話は最後まで聞くものです。

コスパ (タイパ?) 重視のあなたにもいいことがあるはずです。

皆さん、ニューラルネットワークの学習してますか?
様々なタスクでニューラルネットワークが用いられていて、音声認識とか (精一杯の音要素) 画像処理、果てはChatGPTにも使われています。

ある程度慣れてくると自前のタスクで学習したい!とか思うはずです。

しかし、いざ学習をするとこう思ったことはないですか?

「学習おっせえ...2週間とかマジ無理...」

そんな時一つの解決策として「バッチサイズ増やせば?」という声が聞こえてきます。
→「いや、メモリいっぱいですが?」

基本的な画像処理だとまた別の原因の可能性がありますが、音やテキスト情報などは入力データによって長さが異なるものです。

そして通常のニューラルネットワークの実装をしているとバッチサイズを指定する場面があるかと思います。

このバッチサイズ、指定した後は原則変えません。
つまり、10にしたらめちゃくちゃ短いデータを入力してもバッチサイズは10だし、めちゃくちゃ長いデータを入力してもバッチサイズは10になります。

そして「メモリいっぱい」= 「めちゃくちゃ長いデータがギリギリ入るバッチサイズを指定」している、というわけです。

そうです。データが短い時にド余りしてます。

いくつかこの問題の解決法がありますが、ここで使いたいのはバッチサイズを固定、ではなく、どれぐらいの時間フレームを入力するかを指定できるようにしたい、という話です。

こうすればめちゃくちゃ短い時はバッチサイズが100になり長い時は10になるということが可能になりメモリ効率が上がり、学習も速くなる、というわけです。

想定環境

  • PyTorch
  • Python

実装

さて、こちらもいろいろ解決策があることかと思いますが、今回はtorch.utils.data.Dataset, torch.utils.data.Samplerの枠組みをベースに作りましょう。

Datasetは非常に単純なものですが、Samplerはどのように一バッチあたりのデータをサンプルするか、を記述するクラスになります。

どのように長さを取ってくるか、というのはあらかじめ長さを記述したファイルを持っておくなどありますが、今回は一回データをdataloaderで読んで長さを取得しています。

import pandas as pd
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from tqdm import tqdm

class TrainDatasets(Dataset):                                                   
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data_input = self.data_list[idx]
        
        sample = {'data_input': data_input}
        return sample


class LengthsBatchSampler(Sampler):
    """
    LengthsBatchSampler - Sampler for variable batch size.

    Args:
        dataset (torch.nn.dataset)
    """
    def __init__(self, dataset, n_lengths, shuffle=True):
        loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=1)
        lengths_list = []
        # extract lengths 
        for d in loader:
            data_input = d[0]
            lengths_list.append(data_input.shape[0])
        self.lengths_np = np.array(lengths_list)
        
        self.n_lengths = n_lengths
        self.all_indices = self._batch_indices()
        self.shuffle = shuffle

    def _batch_indices(self):
        self.count = 0
        all_indices = []
        while self.count + 1 < len(self.lengths_np):
            indices = []
            max_len = 0
            while self.count < len(self.lengths_np):
                curr_len = self.lengths_np[self.count]
                batch_lengths = max(max_len, curr_len) * (len(indices) + 1)
                if batch_lengths > self.n_lengths or (self.count + 1) > len(self.lengths_np):
                    break
                max_len = max(max_len, curr_len)
                indices.extend([self.count])
                self.count += 1
            all_indices.append(indices)
       
        return all_indices

    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.all_indices)

        for indices in self.all_indices:
            yield indices

    def __len__(self):
        return len(self.all_indices)


def collate_fn(batch):
    data_input = [item['data_input'] for item in batch]
    data_input = torch.nn.utils.rnn.pad_sequence(data_input, batch_first=True, padding_value=0)

    return torch.FloatTensor(data_input)

def make_random_data():
    data_input = []
    for _ in range(100):
        data_len = torch.rand(np.random.randint(1, 100))
        data_input.append(data_len)
        data_sort = sorted(data_input, key=lambda x: len(x), reverse=True)
    return data_sort

if __name__ == '__main__':
    data_sort = make_random_data()

    datasets = TrainDatasets(data_sort)
    sampler = LengthsBatchSampler(datasets, 300, shuffle=True)
    dataloader = DataLoader(datasets, batch_sampler=sampler, num_workers=1, collate_fn=collate_fn)
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for d in dataloader:
        data_input = d
        print(data_input.shape)

結果

$python dataset_test.py 
torch.Size([3, 79])
torch.Size([3, 83])
torch.Size([3, 77])
torch.Size([3, 80])
torch.Size([4, 75])
torch.Size([4, 71])
torch.Size([3, 94])
torch.Size([12, 24])
torch.Size([3, 98])
torch.Size([5, 60])
torch.Size([17, 15])
torch.Size([5, 53])
torch.Size([9, 32])
torch.Size([3, 85])
torch.Size([8, 36])
torch.Size([5, 52])
torch.Size([7, 40])
torch.Size([3, 87])

0次元目がバッチサイズ、1次元目がそのバッチの長さになっています。
今回は300を指定したため大体合計が長さ300でバッチがまとめられてるはずです。
そして短いフレームはバッチサイズが大きく、長いものはバッチサイズが小さくなっていることが分かります。

問題点

このコードはトイコードなので、エラー処理等が入っていない、自分のデータ用に改善しないといけないなどありますが、実際に使う際には注意しないといけないところが2点あります。

  • あらかじめ長さ順でソートしないといけない。トイコード内で言えば data_sort = sorted(data_input, key=lambda x: len(x), reverse=True)部分
  • LengthsBatchSampler(datasets, 300, shuffle=True)の長さ (300) 部分をデータの一番長いものより短くするとwhileが抜けられない。
  • 1epoch全体の総データ数が大きければ大きいほど長さの取得に時間がかかる

2番目はともかく、1番目は前処理が増えるので注意してください。
3番目は地味に痛い問題で一応あらかじめ長さを取得してそのデータを読み取る形で対応していますが、私も悩んでいるので、もし解決策あるなら教えて (他力本願)

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?