はじめに
whisperという音声認識モデルを使っていて、出力に対するスコアを得る方法がわからなかったので調べました。
※本記事は現在更新中です。
whisperとは
whisperとはOpenAIが公開している大規模音声認識モデルです。
参考までに、whisperの日本語まとめページを載せておきます。
スコアとは
ここでは「スコア」を、音声認識モデル$M$が出力する、音声$x$のテキスト$t$に対する損失関数値$L(M,x,t)$の$-1$倍と定義します。例えば、CTCモデルでは音声$M$が音声$x$をテキスト$t$と認識する確率$P(M,x,t)$に対して、
Score(M,x,t)=-L(M,x,t)=\log P(M,x,t)
と定義します。スコアとは、言うなれば「予測にどのくらい自身があるか」を数値化したものです。
ここで、whisperモデルでスコアを求めようとすると、1つ壁にぶち当たります。whisperモデルはCTC型ではなく自己回帰型であり、任意の$t$に対してスコアを求めようとすると一工夫必要になる、ということです。例えば、whisperが「こんにちは」と認識する音声に対して「おはようございます」という予測に対するスコアを計算するためにはコードをいじる必要があります。(文字列の長さが違うため)
スコアを求める方法(ハンズオン)
音声はJSUTコーパス501「綾が完璧なドイツ語を話すのは少しも不思議でない。」を使用します。(ファイルパス ./data/BASIC5000_0501.wav)
whisperの設計上、モデルの予測テキスト$M(x)$に対するスコア$Score(M, x, M(x))$は容易に求められますが、一般のテキスト$t$に対するスコア$Score(M,x,t)$を求めるのは容易ではありません。
注) 本記事ではwhisperのインストール方法は掲載していません。他のページでご確認下さい。
モデルの予測テキストM(x)に対するスコアScore(M, x, M(x))の求め方
logitの取得
import whisper
from whisper.decoding import GreedyDecoder
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
seq_logits = []
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
seq_logits.append(logits)
if temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
completed = (tokens[:, -1] == self.eot).all()
return tokens, completed
GreedyDecoder.update = update
model = whisper.load_model("large")
result = model.transcribe("./data/BASIC5000_0501.wav")
実行後、seq_logitsに各トークンのロジットが格納されています。
print(result["text"])
print(seq_logits)
"""
綾が完璧なドイツ語を話すのは少しも不思議でない。
[tensor([[-inf, -inf, -inf, ..., -inf, -inf, -inf]], device='cuda:0'), tensor([[4.5117, -inf, -inf, ..., -inf, -inf, -inf]],
device='cuda:0'), tensor([[-0.2603, -inf, -inf, ..., 0.0032, 0.1667, -1.6484]],
device='cuda:0'), tensor([[ 4.5898, -inf, -inf, ..., 1.0732, 1.8193, -0.7383]],
device='cuda:0'), tensor([[11.7500, -inf, -inf, ..., 6.3633, 5.8086, 4.7891]],
device='cuda:0'), tensor([[9.2812, -inf, -inf, ..., 5.6094, 6.0938, 5.0898]],
device='cuda:0'), tensor([[ 3.9414, -inf, -inf, ..., -0.2452, 0.4495, -4.1992]],
device='cuda:0'), tensor([[17.2656, -inf, -inf, ..., 11.5781, 11.5312, 10.3359]],
device='cuda:0'), tensor([[8.8281, -inf, -inf, ..., 4.1797, 4.8711, 2.9160]],
device='cuda:0'), tensor([[11.7422, -inf, -inf, ..., 6.7734, 7.0664, 5.5391]],
device='cuda:0'), tensor([[9.8203, -inf, -inf, ..., 6.4727, 6.9141, 3.8008]],
device='cuda:0'), tensor([[12.5000, -inf, -inf, ..., 8.5234, 7.8867, 7.5273]],
device='cuda:0'), tensor([[15.3594, -inf, -inf, ..., 10.5703, 10.2891, 10.3750]],
device='cuda:0'), tensor([[5.5664, -inf, -inf, ..., 1.9092, 1.8076, 1.6914]],
device='cuda:0'), tensor([[10.6250, -inf, -inf, ..., 8.0859, 7.5547, 7.8711]],
device='cuda:0'), tensor([[11.7109, -inf, -inf, ..., 8.7344, 8.4922, 8.1875]],
device='cuda:0'), tensor([[ 5.7656, -inf, -inf, ..., 1.7773, 2.5078, -0.5127]],
device='cuda:0'), tensor([[7.0781, -inf, -inf, ..., 5.0469, 4.4102, 4.3359]],
device='cuda:0'), tensor([[6.1992, -inf, -inf, ..., 4.1367, 4.4727, 3.5293]],
device='cuda:0'), tensor([[7.4062, -inf, -inf, ..., 2.6973, 3.5879, 1.2852]],
device='cuda:0'), tensor([[3.4043, -inf, -inf, ..., 1.1250, 2.0078, 1.3047]],
device='cuda:0'), tensor([[10.9453, -inf, -inf, ..., 8.7188, 9.0938, 8.2344]],
device='cuda:0'), tensor([[ 3.3457, -inf, -inf, ..., 0.6387, 0.7017, -1.9385]],
device='cuda:0'), tensor([[3.3867, -inf, -inf, ..., 1.2852, 1.1562, 1.1045]],
device='cuda:0'), tensor([[ 8.1250, -inf, -inf, ..., -1.2490, -0.4404, -2.5156]],
device='cuda:0'), tensor([[ -inf, -inf, -inf, ..., -1.4873, -0.7993, -2.6172]],
device='cuda:0'), tensor([[ -inf, -inf, -inf, ..., 11.1172, 11.5078, 11.4297]],
device='cuda:0')]
"""
logitからスコアへの変換
標準シグモイド関数を適用すればlogitをスコア(確率)に変換できます。
import numpy as np
def std_sigmoid(logit):
return 1/(1+np.exp(-logit))
seq_probs = []
for i in range(len(seq_logits)):
print("idx: {0}, logit: {1:.2f}, prob: {2:.2f}".format(
torch.argmax(seq_logits[i]).item(),
seq_logits[i].max().item(),
std_sigmoid(seq_logits[i].max().item())
))
seq_probs.append(std_sigmoid(seq_logits[i].max().item()))
"""
idx: 50364, logit: 20.109375, prob: 0.9999999981523939
idx: 9261, logit: 12.015625, prob: 0.9999939510814195
idx: 122, logit: 18.6875, prob: 0.9999999923418853
idx: 5142, logit: 14.234375, prob: 0.9999993422070742
idx: 14128, logit: 23.453125, prob: 0.9999999999347715
idx: 40063, logit: 24.53125, prob: 0.9999999999778071
idx: 100, logit: 24.484375, prob: 0.9999999999767422
idx: 3203, logit: 31.453125, prob: 0.999999999999978
idx: 11195, logit: 20.859375, prob: 0.9999999991272526
idx: 8040, logit: 25.875, prob: 0.9999999999942106
idx: 39406, logit: 26.328125, prob: 0.99999999999632
idx: 31348, logit: 28.078125, prob: 0.9999999999993605
idx: 5998, logit: 31.171875, prob: 0.9999999999999709
idx: 11103, logit: 19.484375, prob: 0.9999999965482171
idx: 2659, logit: 28.875, prob: 0.9999999999997118
idx: 35662, logit: 27.578125, prob: 0.9999999999989457
idx: 15686, logit: 17.875, prob: 0.9999999827421723
idx: 2849, logit: 25.140625, prob: 0.9999999999879339
idx: 4801, logit: 22.203125, prob: 0.9999999997723303
idx: 1960, logit: 20.1875, prob: 0.9999999982912435
idx: 8870, logit: 20.65625, prob: 0.9999999989306876
idx: 24686, logit: 31.34375, prob: 0.9999999999999756
idx: 2474, logit: 18.25, prob: 0.99999998813888
idx: 9311, logit: 19.625, prob: 0.9999999970010391
idx: 1543, logit: 13.234375, prob: 0.9999982119354642
idx: 50564, logit: 10.8515625, prob: 0.9999806260634916
idx: 50257, logit: 27.9375, prob: 0.9999999999992639
"""
実行後、seq_probsに各トークンのスコア(確率)の最大値が格納されています。
これをすべて掛けた値が、モデルが音声に対してテキストを出力する確率といえます。
np.prod(np.array(seq_probs)) # 0.9999720823804599
結論として、whisper_largeモデルは音声BASIC5000_0501.wavに対して「綾が完璧なドイツ語を話すのは少しも不思議でない。」と出力し、そのスコアは0.9999720823804599と、ほぼ1.0であることがわかります。
一般のテキストtに対するスコアScore(M, x, t)の求め方
ライブラリのインポート
import whisper
from whisper.decoding import GreedyDecoder
from whisper.decoding import DecodingTask
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
import scipy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
モジュールの書き換え
GreedyDecoderクラスのupdateメソッドを更新します。
from typing import Tuple
seq_logits = []
seq_tokens = []
seq_current_logprobs = []
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
if temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
completed = (tokens[:, -1] == self.eot).all()
# 追加
seq_logits.append(logits)
seq_tokens.append(tokens)
seq_current_logprobs.append(current_logprobs)
# ここまで
return tokens, completed
GreedyDecoder.update = update
続いて、_main_loopメソッドを更新します。
from whisper.tokenizer import get_tokenizer
import numpy as np
tokenizer = get_tokenizer(multilingual=True, language="ja", task="transcribe")
# text = "綾が完璧なドイツ語を話すのは少しも不思議でない。"
text = "アラン君は、運よく税理士試験に合格しました。"
token_ids = tokenizer.encode(text)
token_ids = [50258, 50266, 50359, 50364] + token_ids + [50564]
token_ids_list = [[[50258, 50266, 50359] + [j for j in token_ids[3:i]]] for i in range(3, len(token_ids)+2)]
token_ids_list = [torch.tensor(token_idx).to(device) for token_idx in token_ids_list]
seq_tokens2 = []
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch
try:
seq_tokens2.append(tokens)
for i in range(
# self.sample_len
len(token_ids_list)
):
# この部分でtokensを書き換える
tokens = token_ids_list[i]
# ここまで
logits = self.inference.logits(tokens, audio_features)
if (
i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
# apply the logit filters, e.g. for suppressing or applying penalty to
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)
# expand the tokens tensor with the selected next tokens
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
seq_tokens2.append(tokens)
if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
self.inference.cleanup_caching()
return tokens, sum_logprobs, no_speech_probs
DecodingTask._main_loop = _main_loop
textを一般のテキストにしています。
logits = self.inference.logits(tokens, audio_features) の部分で推論が行われています。audio_featuresは常に同じ入力です。tokensはサイズ(1,n=3以上)となります。
- tokens[0][0],...,tokens[0][2]:$SpecialToken$
- token[0][3]:$SOS$
- token[0][4],...,token[0][n-2]:テキスト
- token[0][n-1]:$EOS$
スコアの計算
model = whisper.load_model("large", device)
audio_path = './output/adv_BASA7BA.wav'
sr, audio = scipy.io.wavfile.read(audio_path)
result = model.transcribe(audio, language="ja")
# デコード
from whisper.tokenizer import get_tokenizer
pred = torch.argmax(torch.stack(seq_logits).squeeze(), dim=-1)
tokenizer = get_tokenizer(multilingual=True, language="ja", task="transcribe")
token_ids = pred
decoded_word = tokenizer.decode(token_ids[:-1])
decoded_word
# スコアの準備
text = "アラン君は、運よく税理士試験に合格しました。"
token_ids = tokenizer.encode(text)
token_ids = [50258, 50266, 50359, 50364] + token_ids + [50564]
token_ids = torch.tensor(token_ids)
seq_logprobs = torch.stack(seq_logits).squeeze().log_softmax(dim=-1).max(dim=-1).values.squeeze()
# スコア
torch.sum(seq_logprobs[:-1])
確率は各トークンの予測$P(t_0|f, SOS), P(t_1|f, SOS, t_0), P(t_2|f, SOS, t_0, t_1), ..., P(EOS|f, SOS, t_0, ..., t_n)$の総積であらわされます。スコアは確率に対数をとったものなので$-\log P(t_0|f, SOS)P(t_1|f, SOS, t_0)P(t_2|f, SOS, t_0, t_1)P(EOS|f, SOS, t_0, ..., t_n)=\sum_{i=0}^{n} P(EOS|f, SOS, t_0, ..., t_i)$とあらわすことができます。
ここで、終了文字を入力したときの予測をスコアに影響させないためにseq_logprobの末尾を捨てています。
Appendix. seq_logitsを用いたデコード
tokenizerを用いることでデコードできます。
from whisper.tokenizer import get_tokenizer, Tokenizer
tokenizer = get_tokenizer(multilingual=True, language="ja", task="transcribe")
k = 1
topk_token_ids = torch.stack([torch.argsort(logits)[0][-k:] for logits in seq_logits]).cpu().T.numpy()
topk_decoded_words = [tokenizer.decode(token_ids) for token_ids in topk_token_ids[::-1]]
topk_decoded_words # ['綾が完璧なドイツ語を話すのは少しも不思議でない。<|endoftext|>']
実行後、topk_decode_wordsに出力テキストが格納されています。確かに、出力は正解テキストと一致しています。
終わりに
いくら探してもwhisperには直接logitやスコアを得るための関数が見つからなかったので、自力で構築しました。もしそのような関数があれば、お知らせいただけますと幸いです。
本記事の作成にあたり、以下の記事を参考にしました。