今回はWhsiperについてのコードの斜め読みのメモ.
ソースコード
OpenAI/Whisper
https://github.com/openai/whisper
参考文献
本家解説
本家論文
日本語解説
ThothChildrenチャンネルの動画
コード読んでく
上記から
transcribe.py
引数のparseなどが終わると、trascribe関数が呼び出される.
cpuやgpuの環境を確認したのちに下記のdecodeの部分に来る.
ここでは、音声データの最初の30秒をメルスペクトラム変換したのちに、
model,detect_launguageにて各単語の可能性probsを得ている。probsを最大にするlaunguageを取得している。
if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
detect_launguageの内部を追っていくと、decoding.pyのdetect_launguage関数が呼ばれている。
def detect_language(
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
# まず必要があればEncoderで特徴量を得る
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
# startoftranscriptのトークンIDだけを与えて、logits関数の中でDecoder処理
n_audio = mel.shape[0]
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits = model.logits(x, mel)[:, 0]
# 言語以外含めて推定しているが、言語以外を全てMaskする処理
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
# Maskするときにはマイナス無限大する
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
# 各言語の確率(評価値)を取得する
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
}
for i in range(n_audio)
]
つまりは一度Whisperでdecode処理を実行して、出力のうち言語のスコアでもっとも高い言語を推定という流れ。
transcribe.pyに戻り続きを読む.
30秒の切り出し箇所を処理したのちに実際にその部分のデータを切り出して、decode_with_fallbackへ与えている
while clip_idx < len(seek_clips):
# 切り出しの銭湯や末尾、切り出し量を計算
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
#今回の30秒の箇所を切り出し.
mel_segment = mel[:, seek : seek + segment_size]
#使用する入力音声データを30秒になるように調整(30秒より短い場合)
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
# ここで実際の文字起こしモデル処理を実行
result: DecodingResult = decode_with_fallback(mel_segment)
#推定したtokenの列
tokens = torch.tensor(result.tokens)
では実際に変換処理をしているdecode_with_fallbackを追う
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
#ビームサーチで複数単語候補を追うか、greedySearchにしつつTemperaturでランダム性を与えるかを設定
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
# ここでモデルのdecode処理を実行
decode_result = model.decode(segment, options)
return decode_result
このmodelに飛ぶとmodel.pyのdecodeに飛ぶが、実態は
decode = decode_function
で、decoding.pyのdecodeがセットされている
decode.py
この中では、DecodingTaskのrunにて、実際に入力されたモデルが実行されている.
def decode(
model: "Whisper",
mel: Tensor,
options: DecodingOptions = DecodingOptions(),
**kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)
if kwargs:
options = replace(options, **kwargs)
result = DecodingTask(model, options).run(mel)
return result[0] if single else result
run関数をみると、下記のようにEncoderの処理(_get_audio_features)とDecoderの処理(main_loop)に到達したことがわかる.
def run(self, mel: Tensor) -> List[DecodingResult]:
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]
#この中でEncoder処理を実行している.出てきた特徴量をaudio_featuresに格納.
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# ビームサーチをしている場合はここで複数候補を持つようにする.
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# ここでDecoderのメイン処理. tokensに予測した文字列が入り、これまでのlogprobsも含まれる.
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
...
return [
DecodingResult(
audio_features=features,
.....
)
...
]
AutioEncoder
_get_audio_featuresは、self.model.encoderを内部で呼び出しており、下記のように宣言されているEncoder.(すごいシンプル)
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
#位置エンコーディングの足し合わせ
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
#指定された回数のEncoderブロック
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
# このサイズは(1,1500,512)
return x
TextDecoder
こちらのDecoderもシンプル.位置エンコーディングを足して、複数回のResidualAttentionBlockを繰り返し実行しているだけ.
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
#位置エンコーディングを加算
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
#xaはaudio_features. blockがcrossAttentionの構成となっている
Attentionブロック.
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
#これが各単語の予測スコア.51865次元の配列.
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
Decoderのメインループ
get_audio_featuresはただ単にencodeのモデルを実行しているのみなので、スキップ。
_main_loopは
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch
try:
#最大指定されている224回(sample_len)の実行. Decode処理が繰り返し実行.
for i in range(self.sample_len):
#ここでDecoder実行.各単語の推定.
logits = self.inference.logits(tokens, audio_features)
#no_speechの可能性を検証.
if (
i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
# logitsのスコアを-infにするなどの調整.絶対にありえないtokenのスコアを下げる.
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)
# BeamSearhやGreedySearchの実行.
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
# 文章の切れ目であることがわかった場合などは抜ける.
if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
self.inference.cleanup_caching()
return tokens, sum_logprobs, no_speech_probs
おわりに
Whisperはかなりシンプルなモデル構成...
色々工夫はあるもののTransformerをほぼそのまま活用しており、これであれだけの文字起こし性能がでるのはちょっと意外...