はじめに
はじめまして、株式会社IGSAでAIエンジニアをしているzeronishiです。
AIモデルを学習する際には、元のデータを訓練データ、検証データ、テストデータの、3種類のデータに分けて使用します。流れとしては、訓練データでモデルの学習を行い、学習後のモデルと検証データを使ってハイパーパラメータを調整します。最後にテストデータを使って調整後のAIモデルの汎化性能を評価します。
通常このような流れでAIモデルを開発しますが、もしリークが起きていると、想定している性能と実際の性能が異なるという問題が発生します。この記事では自分への戒めも兼ねて、ありがちなリークを何点かご紹介します。
基本的なリーク
最も基本的なリークは、学習データとテストデータの両方に同一のデータが含まれているものです。モデルは学習を通してそのデータへの性能を高めるので、当然テストデータを使った評価時には、実運用時よりも高い性能を示します。
予測AIで未来の情報を使ってしまう
予測AIである時点の、何等かの値を予測する際には、その時点より前の過去情報の全てを利用することができます。逆に言えば、その時点より先の未来情報は、実運用時には知りえない情報であるため利用できません。
Attention Decoderを用いた生成モデルの学習などでは、下三角行列のマスクをかけることで未来情報のリークを防いだりしています。
話者のリーク
音声系のタスクを想定して「話者」としていますが、ある元データを切り分けて利用するような場合に共通する注意すべきリークだと思います。
音声タスクによっては、前処理として数分から数時間の長時間元音声を、数秒のセグメントに切り分けます。これにより1人の話者につき複数の音声セグメントが得られるわけですが、これが学習データとテストデータに分かれてしまうと、話者性のリークが起きてしまいます。特に音声分類タスクでは、発話内容などではなく、話者性から判別してしまい、不当に高い精度の結果が得られるという問題が起こります。
使用する全データをシャッフルしてデータを分けるのではなく、データ間に共通するものが無いか注意する必要があります。
JVSコーパスという100人の話者の音声が収録さているコーパスを使って性別判定AIを学習させ、簡単な話者リークの例を説明します。
以下がJVSコーパスを読み取るDatasetクラスです。jvs001話者の音声20件をテストデータとして使用します。引数leakがFalseの場合、学習データにjvs001話者のテストデータ以外の音声は使用しませんが、Trueの場合は使用します。これらを比較することで、話者リークによりどの程度性能に差が出るのかを検証します。
特徴量にはxvectorを使用します。
import os
import numpy as np
import pandas as pd
import tqdm
import librosa
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, random_split
from torch.utils.data.dataset import Dataset
from scipy.io import wavfile
from torchaudio.compliance import kaldi
class JVSDataset(Dataset):
def __init__(self, partition="train", leak=False):
test_data = [f"espnet/egs2/jvs/tts1/downloads/jvs_ver1/jvs001/parallel100/wav24kHz16bit/VOICEACTRESS100_{str(spk_num).zfill(3)}.wav" for spk_num in range(1, 40)]
self.meta_df = pd.read_table(f"espnet/egs2/jvs/tts1/downloads/jvs_ver1/gender_f0range.txt", sep=' ')
self.model = torch.hub.load("sarulab-speech/xvector_jtubespeech", "xvector", trust_repo=True)
self.wav_paths = []
self.labels = []
if partition=="train":
if leak:
for spk_num in range(1, 101):
spk_num = str(spk_num).zfill(3)
dir_name = f"espnet/egs2/jvs/tts1/downloads/jvs_ver1/jvs{spk_num}/parallel100/wav24kHz16bit"
for path in os.listdir(dir_name):
path = os.path.join(dir_name, path)
if path not in test_data:
self.wav_paths.append(path)
self.labels.append(self.meta_df[self.meta_df["speaker"]==f"jvs{spk_num}"]["Male_or_Female"])
else:
for spk_num in range(2, 101):
dir_name = f"espnet/egs2/jvs/tts1/downloads/jvs_ver1/jvs{str(spk_num).zfill(3)}/parallel100/wav24kHz16bit"
for path in os.listdir(dir_name):
path = os.path.join(dir_name, path)
self.wav_paths.append(path)
self.labels.append(self.meta_df[self.meta_df["speaker"]==f"jvs{spk_num}"]["Male_or_Female"])
else:
self.wav_paths = test_data
self.labels = [self.meta_df[self.meta_df["speaker"]=="jvs001"]["Male_or_Female"] for _ in range(1, 40)]
self.labels = [1 if str(gender.values[0])=="M" else 0 for gender in self.labels]
def __len__(self):
return len(self.labels)
def extract_features(self, x):
# extract mfcc
wav = torch.from_numpy(x.astype(np.float32)).unsqueeze(0)
mfcc = kaldi.mfcc(wav, num_ceps=24, num_mel_bins=24) # [1, T, 24]
mfcc = mfcc.unsqueeze(0)
# extract xvector
xvector = self.model.vectorize(mfcc) # (1, 512)
xvector = xvector.to("cpu").detach().numpy().copy()[0]
return xvector
def __getitem__(self, idx):
_, wav = wavfile.read(self.wav_paths[idx])
xvector = self.extract_features(wav)
return xvector, self.labels[idx]
分類器には2層の全結合層を使用します。
class LinearClassification(nn.Module):
def __init__(self):
super().__init__()
self.classification_net = nn.Sequential(nn.Linear(512, 256), nn.Linear(256, 2))
def forward(self, x):
x = self.classification_net(x)
return x
リークあり、なしの両方で2値分類モデルを学習させます。バッチサイズは8とし、5epoch回しました。
for leak in [True, False]:
batch_size = 8
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(JVSDataset("train", leak), [0.8, 0.2], generator=generator)
test_set = JVSDataset("test", leak)
train_dataloader = DataLoader(
dataset = train_set,
batch_size = batch_size,
)
val_dataloader = DataLoader(
dataset = val_set,
batch_size = batch_size,
)
test_dataloader = DataLoader(
dataset = test_set,
batch_size = batch_size,
)
model = LinearClassification()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(5):
for X, y in tqdm.tqdm(train_dataloader):
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.eval()
loss, correct = 0, 0
size = len(test_set)
for X, y in test_dataloader:
pred = model(X)
pred = nn.functional.softmax(pred, dim=0)
loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
print(f"leak {leak}, epoch {epoch}, Test Accuracy: {(100*correct/size):>0.1f}%")
こちらのコードを実行した結果、私の環境ではリークありの場合はAccuracy=59%、リークなしの場合はAccuracy=41%という結果になりました。話者リークにより18ptもの精度の差が出てしまうことがわかりました。
事前学習のリーク
こちらも音声で例を挙げますが、音声認識タスクの学習には、音声と対応する書き起こしテキストのペアデータが必要となります。しかし書き起こしテキストを用意するコストは膨大であるという問題があります。そこで、BERTのような穴埋めタスクを使った事前学習により、大量の音声データのみで有意な音声表現を抽出できる自己教師あり学習モデルを利用します。この事前学習には膨大な計算リソースが必要となるため、一般にオープンに公開されているモデルを利用するのですが、whisperなど、中には事前学習に用いた音声データの詳細が公開されていないモデルもあります。
ニッチな例になりますが、whisperの追加事前学習によるドメイン適用の有効性の検証などでは、whisperの事前学習データにそのドメインの音声が含まれている可能性があるため、より適切なモデルに変更すべき場合などがあります。
まとめ
AIモデル開発時では常に注意が必要なリークについてご紹介しました。
基本的なリークはよく知られていると思いますが、タスクによっては上記以外のリークもあるかもしれません。想定よりも明らかに高い性能が出た場合などはまずリークを疑うことが重要だと思います。
参考文献
機械学習におけるリークとは
機械学習の落とし穴 リーク問題について
Whisperが登場
sarulab-speech/xvector_jtubespeech
JVS (Japanese versatile speech) corpus
IGSAについて
IGSAは、社会を温かく柔らかく持続的に支えるAIシステムにより、持続可能な幸せを目指す、東京大学松尾・岩澤研究室発のAIカンパニーです。
脳の健康管理アプリ「はなしてね」や、中古品の画像解析SaaS「スグトリ」などのAIプロダクト提供に加え、潜在的な課題に対し柔軟な開発支援を行うパートナー事業を展開。センシングAI技術を活用した状態の定量化と分析により、人の意思決定をサポートしています。
