31
27

Whisperのコードを斜め読みする

Last updated at Posted at 2024-07-06

今回はWhsiperについてのコードの斜め読みのメモ.

ソースコード

OpenAI/Whisper
https://github.com/openai/whisper

参考文献

本家解説

本家論文

日本語解説

ThothChildrenチャンネルの動画

コード読んでく

上記から

transcribe.py

引数のparseなどが終わると、trascribe関数が呼び出される.

cpuやgpuの環境を確認したのちに下記のdecodeの部分に来る.
ここでは、音声データの最初の30秒をメルスペクトラム変換したのちに、
model,detect_launguageにて各単語の可能性probsを得ている。probsを最大にするlaunguageを取得している。

    if decode_options.get("language", None) is None:
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(mel_segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

detect_launguageの内部を追っていくと、decoding.pyのdetect_launguage関数が呼ばれている。

def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
    # まず必要があればEncoderで特徴量を得る
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # startoftranscriptのトークンIDだけを与えて、logits関数の中でDecoder処理
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # 言語以外含めて推定しているが、言語以外を全てMaskする処理
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    # Maskするときにはマイナス無限大する
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    # 各言語の確率(評価値)を取得する
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

つまりは一度Whisperでdecode処理を実行して、出力のうち言語のスコアでもっとも高い言語を推定という流れ。

transcribe.pyに戻り続きを読む.
30秒の切り出し箇所を処理したのちに実際にその部分のデータを切り出して、decode_with_fallbackへ与えている

        while clip_idx < len(seek_clips):
            # 切り出しの銭湯や末尾、切り出し量を計算
            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            
            segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
            #今回の30秒の箇所を切り出し.
            mel_segment = mel[:, seek : seek + segment_size]

            #使用する入力音声データを30秒になるように調整(30秒より短い場合)
            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

            # ここで実際の文字起こしモデル処理を実行
            result: DecodingResult = decode_with_fallback(mel_segment)
            #推定したtokenの列
            tokens = torch.tensor(result.tokens)

では実際に変換処理をしているdecode_with_fallbackを追う

     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            #ビームサーチで複数単語候補を追うか、greedySearchにしつつTemperaturでランダム性を与えるかを設定
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            # ここでモデルのdecode処理を実行
            decode_result = model.decode(segment, options)

        return decode_result

このmodelに飛ぶとmodel.pyのdecodeに飛ぶが、実態は

decode = decode_function

で、decoding.pyのdecodeがセットされている

decode.py

この中では、DecodingTaskのrunにて、実際に入力されたモデルが実行されている.

def decode(
    model: "Whisper",
    mel: Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    
    if single := mel.ndim == 2:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = DecodingTask(model, options).run(mel)

    return result[0] if single else result

run関数をみると、下記のようにEncoderの処理(_get_audio_features)とDecoderの処理(main_loop)に到達したことがわかる.

    def run(self, mel: Tensor) -> List[DecodingResult]:
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        #この中でEncoder処理を実行している.出てきた特徴量をaudio_featuresに格納.
        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # ビームサーチをしている場合はここで複数候補を持つようにする.
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # ここでDecoderのメイン処理. tokensに予測した文字列が入り、これまでのlogprobsも含まれる.
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        ...
        
        return [
            DecodingResult(
                audio_features=features,
               .....
            )
            ...
        ]

AutioEncoder

_get_audio_featuresは、self.model.encoderを内部で呼び出しており、下記のように宣言されているEncoder.(すごいシンプル)

class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        #位置エンコーディングの足し合わせ
        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        #指定された回数のEncoderブロック
        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        # このサイズは(1,1500,512)
        return x

TextDecoder

こちらのDecoderもシンプル.位置エンコーディングを足して、複数回のResidualAttentionBlockを繰り返し実行しているだけ.


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        #位置エンコーディングを加算
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        #xaはaudio_features. blockがcrossAttentionの構成となっている
        Attentionブロック.
        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        #これが各単語の予測スコア.51865次元の配列.
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits

Decoderのメインループ

get_audio_featuresはただ単にencodeのモデルを実行しているのみなので、スキップ。
_main_loopは

def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            #最大指定されている224回(sample_len)の実行. Decode処理が繰り返し実行.
            for i in range(self.sample_len):
                #ここでDecoder実行.各単語の推定.
                logits = self.inference.logits(tokens, audio_features)

                #no_speechの可能性を検証.
                if (
                    i == 0 and self.tokenizer.no_speech is not None
                ):  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # logitsのスコアを-infにするなどの調整.絶対にありえないtokenのスコアを下げる.
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # BeamSearhやGreedySearchの実行.
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                # 文章の切れ目であることがわかった場合などは抜ける.
                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

おわりに

Whisperはかなりシンプルなモデル構成...
色々工夫はあるもののTransformerをほぼそのまま活用しており、これであれだけの文字起こし性能がでるのはちょっと意外...

31
27
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
31
27