1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

whisperモデルの出力に対するスコア(確率)を求める方法

Last updated at Posted at 2023-09-12

はじめに

whisperという音声認識モデルを使っていて、出力に対するスコアを得る方法がわからなかったので調べました。
※本記事は現在更新中です。

whisperとは

whisperとはOpenAIが公開している大規模音声認識モデルです。

参考までに、whisperの日本語まとめページを載せておきます。

スコアとは

 ここでは「スコア」を、音声認識モデル$M$が出力する、音声$x$のテキスト$t$に対する損失関数値$L(M,x,t)$の$-1$倍と定義します。例えば、CTCモデルでは音声$M$が音声$x$をテキスト$t$と認識する確率$P(M,x,t)$に対して、

Score(M,x,t)=-L(M,x,t)=\log P(M,x,t)

と定義します。スコアとは、言うなれば「予測にどのくらい自身があるか」を数値化したものです。
 ここで、whisperモデルでスコアを求めようとすると、1つ壁にぶち当たります。whisperモデルはCTC型ではなく自己回帰型であり、任意の$t$に対してスコアを求めようとすると一工夫必要になる、ということです。例えば、whisperが「こんにちは」と認識する音声に対して「おはようございます」という予測に対するスコアを計算するためにはコードをいじる必要があります。(文字列の長さが違うため)

スコアを求める方法(ハンズオン)

音声はJSUTコーパス501「綾が完璧なドイツ語を話すのは少しも不思議でない。」を使用します。(ファイルパス ./data/BASIC5000_0501.wav)

whisperの設計上、モデルの予測テキスト$M(x)$に対するスコア$Score(M, x, M(x))$は容易に求められますが、一般のテキスト$t$に対するスコア$Score(M,x,t)$を求めるのは容易ではありません。
注) 本記事ではwhisperのインストール方法は掲載していません。他のページでご確認下さい。

モデルの予測テキストM(x)に対するスコアScore(M, x, M(x))の求め方

logitの取得

import whisper
from whisper.decoding import GreedyDecoder
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

seq_logits = []
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
    temperature = self.temperature
    seq_logits.append(logits)
    if temperature == 0:
        next_tokens = logits.argmax(dim=-1)
    else:
        next_tokens = Categorical(logits=logits / temperature).sample()

    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
    sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

    next_tokens[tokens[:, -1] == self.eot] = self.eot
    tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

    completed = (tokens[:, -1] == self.eot).all()
    return tokens, completed

GreedyDecoder.update = update
model = whisper.load_model("large")
result = model.transcribe("./data/BASIC5000_0501.wav")

実行後、seq_logitsに各トークンのロジットが格納されています。

print(result["text"])
print(seq_logits)
"""
綾が完璧なドイツ語を話すのは少しも不思議でない。
[tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf]], device='cuda:0'), tensor([[4.5117,   -inf,   -inf,  ...,   -inf,   -inf,   -inf]],
       device='cuda:0'), tensor([[-0.2603,    -inf,    -inf,  ...,  0.0032,  0.1667, -1.6484]],
       device='cuda:0'), tensor([[ 4.5898,    -inf,    -inf,  ...,  1.0732,  1.8193, -0.7383]],
       device='cuda:0'), tensor([[11.7500,    -inf,    -inf,  ...,  6.3633,  5.8086,  4.7891]],
       device='cuda:0'), tensor([[9.2812,   -inf,   -inf,  ..., 5.6094, 6.0938, 5.0898]],
       device='cuda:0'), tensor([[ 3.9414,    -inf,    -inf,  ..., -0.2452,  0.4495, -4.1992]],
       device='cuda:0'), tensor([[17.2656,    -inf,    -inf,  ..., 11.5781, 11.5312, 10.3359]],
       device='cuda:0'), tensor([[8.8281,   -inf,   -inf,  ..., 4.1797, 4.8711, 2.9160]],
       device='cuda:0'), tensor([[11.7422,    -inf,    -inf,  ...,  6.7734,  7.0664,  5.5391]],
       device='cuda:0'), tensor([[9.8203,   -inf,   -inf,  ..., 6.4727, 6.9141, 3.8008]],
       device='cuda:0'), tensor([[12.5000,    -inf,    -inf,  ...,  8.5234,  7.8867,  7.5273]],
       device='cuda:0'), tensor([[15.3594,    -inf,    -inf,  ..., 10.5703, 10.2891, 10.3750]],
       device='cuda:0'), tensor([[5.5664,   -inf,   -inf,  ..., 1.9092, 1.8076, 1.6914]],
       device='cuda:0'), tensor([[10.6250,    -inf,    -inf,  ...,  8.0859,  7.5547,  7.8711]],
       device='cuda:0'), tensor([[11.7109,    -inf,    -inf,  ...,  8.7344,  8.4922,  8.1875]],
       device='cuda:0'), tensor([[ 5.7656,    -inf,    -inf,  ...,  1.7773,  2.5078, -0.5127]],
       device='cuda:0'), tensor([[7.0781,   -inf,   -inf,  ..., 5.0469, 4.4102, 4.3359]],
       device='cuda:0'), tensor([[6.1992,   -inf,   -inf,  ..., 4.1367, 4.4727, 3.5293]],
       device='cuda:0'), tensor([[7.4062,   -inf,   -inf,  ..., 2.6973, 3.5879, 1.2852]],
       device='cuda:0'), tensor([[3.4043,   -inf,   -inf,  ..., 1.1250, 2.0078, 1.3047]],
       device='cuda:0'), tensor([[10.9453,    -inf,    -inf,  ...,  8.7188,  9.0938,  8.2344]],
       device='cuda:0'), tensor([[ 3.3457,    -inf,    -inf,  ...,  0.6387,  0.7017, -1.9385]],
       device='cuda:0'), tensor([[3.3867,   -inf,   -inf,  ..., 1.2852, 1.1562, 1.1045]],
       device='cuda:0'), tensor([[ 8.1250,    -inf,    -inf,  ..., -1.2490, -0.4404, -2.5156]],
       device='cuda:0'), tensor([[   -inf,    -inf,    -inf,  ..., -1.4873, -0.7993, -2.6172]],
       device='cuda:0'), tensor([[   -inf,    -inf,    -inf,  ..., 11.1172, 11.5078, 11.4297]],
       device='cuda:0')]
"""

logitからスコアへの変換

標準シグモイド関数を適用すればlogitをスコア(確率)に変換できます。

import numpy as np

def std_sigmoid(logit):
    return 1/(1+np.exp(-logit))

seq_probs = []
for i in range(len(seq_logits)):
    print("idx: {0}, logit: {1:.2f}, prob: {2:.2f}".format(
        torch.argmax(seq_logits[i]).item(), 
        seq_logits[i].max().item(), 
        std_sigmoid(seq_logits[i].max().item())
    ))
    seq_probs.append(std_sigmoid(seq_logits[i].max().item()))
"""
idx: 50364, logit: 20.109375, prob: 0.9999999981523939
idx: 9261, logit: 12.015625, prob: 0.9999939510814195
idx: 122, logit: 18.6875, prob: 0.9999999923418853
idx: 5142, logit: 14.234375, prob: 0.9999993422070742
idx: 14128, logit: 23.453125, prob: 0.9999999999347715
idx: 40063, logit: 24.53125, prob: 0.9999999999778071
idx: 100, logit: 24.484375, prob: 0.9999999999767422
idx: 3203, logit: 31.453125, prob: 0.999999999999978
idx: 11195, logit: 20.859375, prob: 0.9999999991272526
idx: 8040, logit: 25.875, prob: 0.9999999999942106
idx: 39406, logit: 26.328125, prob: 0.99999999999632
idx: 31348, logit: 28.078125, prob: 0.9999999999993605
idx: 5998, logit: 31.171875, prob: 0.9999999999999709
idx: 11103, logit: 19.484375, prob: 0.9999999965482171
idx: 2659, logit: 28.875, prob: 0.9999999999997118
idx: 35662, logit: 27.578125, prob: 0.9999999999989457
idx: 15686, logit: 17.875, prob: 0.9999999827421723
idx: 2849, logit: 25.140625, prob: 0.9999999999879339
idx: 4801, logit: 22.203125, prob: 0.9999999997723303
idx: 1960, logit: 20.1875, prob: 0.9999999982912435
idx: 8870, logit: 20.65625, prob: 0.9999999989306876
idx: 24686, logit: 31.34375, prob: 0.9999999999999756
idx: 2474, logit: 18.25, prob: 0.99999998813888
idx: 9311, logit: 19.625, prob: 0.9999999970010391
idx: 1543, logit: 13.234375, prob: 0.9999982119354642
idx: 50564, logit: 10.8515625, prob: 0.9999806260634916
idx: 50257, logit: 27.9375, prob: 0.9999999999992639
"""

実行後、seq_probsに各トークンのスコア(確率)の最大値が格納されています。
これをすべて掛けた値が、モデルが音声に対してテキストを出力する確率といえます。

np.prod(np.array(seq_probs)) # 0.9999720823804599

結論として、whisper_largeモデルは音声BASIC5000_0501.wavに対して「綾が完璧なドイツ語を話すのは少しも不思議でない。」と出力し、そのスコアは0.9999720823804599と、ほぼ1.0であることがわかります。

一般のテキストtに対するスコアScore(M, x, t)の求め方

ライブラリのインポート

import whisper
from whisper.decoding import GreedyDecoder
from whisper.decoding import DecodingTask
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
import scipy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

モジュールの書き換え

GreedyDecoderクラスのupdateメソッドを更新します。

from typing import Tuple

seq_logits = []
seq_tokens = []
seq_current_logprobs = []
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
    temperature = self.temperature
    if temperature == 0:
        next_tokens = logits.argmax(dim=-1)
    else:
        next_tokens = Categorical(logits=logits / temperature).sample()

    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
    sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

    next_tokens[tokens[:, -1] == self.eot] = self.eot
    tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

    completed = (tokens[:, -1] == self.eot).all()

    # 追加
    seq_logits.append(logits)
    seq_tokens.append(tokens)
    seq_current_logprobs.append(current_logprobs)
    # ここまで

    return tokens, completed

GreedyDecoder.update = update

続いて、_main_loopメソッドを更新します。

from whisper.tokenizer import get_tokenizer
import numpy as np

tokenizer = get_tokenizer(multilingual=True, language="ja", task="transcribe")
# text = "綾が完璧なドイツ語を話すのは少しも不思議でない。"
text = "アラン君は、運よく税理士試験に合格しました。"
token_ids = tokenizer.encode(text)
token_ids = [50258, 50266, 50359, 50364] + token_ids + [50564]
token_ids_list = [[[50258, 50266, 50359] + [j for j in token_ids[3:i]]] for i in range(3, len(token_ids)+2)]
token_ids_list = [torch.tensor(token_idx).to(device) for token_idx in token_ids_list]

seq_tokens2 = []

def _main_loop(self, audio_features: Tensor, tokens: Tensor):
    n_batch = tokens.shape[0]
    sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
    no_speech_probs = [np.nan] * n_batch

    try:
        seq_tokens2.append(tokens)
        for i in range(
            # self.sample_len
            len(token_ids_list)
        ):
            # この部分でtokensを書き換える
            tokens = token_ids_list[i]
            # ここまで

            logits = self.inference.logits(tokens, audio_features)

            if (
                i == 0 and self.tokenizer.no_speech is not None
            ):  # save no_speech_probs
                probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

            # now we need to consider the logits at the last token only
            logits = logits[:, -1]

            # apply the logit filters, e.g. for suppressing or applying penalty to
            for logit_filter in self.logit_filters:
                logit_filter.apply(logits, tokens)

            # expand the tokens tensor with the selected next tokens
            tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
            seq_tokens2.append(tokens)

            if completed or tokens.shape[-1] > self.n_ctx:
                break
    finally:
        self.inference.cleanup_caching()

    return tokens, sum_logprobs, no_speech_probs

DecodingTask._main_loop = _main_loop

textを一般のテキストにしています。
logits = self.inference.logits(tokens, audio_features) の部分で推論が行われています。audio_featuresは常に同じ入力です。tokensはサイズ(1,n=3以上)となります。

  • tokens[0][0],...,tokens[0][2]:$SpecialToken$
  • token[0][3]:$SOS$
  • token[0][4],...,token[0][n-2]:テキスト
  • token[0][n-1]:$EOS$

スコアの計算

model = whisper.load_model("large", device)
audio_path = './output/adv_BASA7BA.wav'
sr, audio = scipy.io.wavfile.read(audio_path)
result = model.transcribe(audio, language="ja")
# デコード
from whisper.tokenizer import get_tokenizer

pred = torch.argmax(torch.stack(seq_logits).squeeze(), dim=-1)
tokenizer = get_tokenizer(multilingual=True, language="ja", task="transcribe")
token_ids = pred
decoded_word = tokenizer.decode(token_ids[:-1])
decoded_word
# スコアの準備

text = "アラン君は、運よく税理士試験に合格しました。"
token_ids = tokenizer.encode(text)
token_ids = [50258, 50266, 50359, 50364] + token_ids + [50564]
token_ids = torch.tensor(token_ids)

seq_logprobs = torch.stack(seq_logits).squeeze().log_softmax(dim=-1).max(dim=-1).values.squeeze()
# スコア
torch.sum(seq_logprobs[:-1])

 確率は各トークンの予測$P(t_0|f, SOS), P(t_1|f, SOS, t_0), P(t_2|f, SOS, t_0, t_1), ..., P(EOS|f, SOS, t_0, ..., t_n)$の総積であらわされます。スコアは確率に対数をとったものなので$-\log P(t_0|f, SOS)P(t_1|f, SOS, t_0)P(t_2|f, SOS, t_0, t_1)P(EOS|f, SOS, t_0, ..., t_n)=\sum_{i=0}^{n} P(EOS|f, SOS, t_0, ..., t_i)$とあらわすことができます。
 ここで、終了文字を入力したときの予測をスコアに影響させないためにseq_logprobの末尾を捨てています。

Appendix. seq_logitsを用いたデコード

tokenizerを用いることでデコードできます。

from whisper.tokenizer import get_tokenizer, Tokenizer
tokenizer = get_tokenizer(multilingual=True, language="ja", task="transcribe")

k = 1
topk_token_ids = torch.stack([torch.argsort(logits)[0][-k:] for logits in seq_logits]).cpu().T.numpy()

topk_decoded_words = [tokenizer.decode(token_ids) for token_ids in topk_token_ids[::-1]]
topk_decoded_words # ['綾が完璧なドイツ語を話すのは少しも不思議でない。<|endoftext|>']

実行後、topk_decode_wordsに出力テキストが格納されています。確かに、出力は正解テキストと一致しています。

終わりに

いくら探してもwhisperには直接logitやスコアを得るための関数が見つからなかったので、自力で構築しました。もしそのような関数があれば、お知らせいただけますと幸いです。
本記事の作成にあたり、以下の記事を参考にしました。

1
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?