LoginSignup
2
4

任意のビデオエンコーダとHuggingFaceのデコーダモデルをつなげて自作のビデオキャプショニングモデルを作る

Last updated at Posted at 2023-05-20

PyTorch HubなどHuggingFace以外から取得したビデオエンコーダとHuggingFaceのデコーダを使って,ビデオそのものを入力とするキャプショニングモデルの実装方法を紹介します.

実装方法のみを知りたい方は実装の節を参照してください.

やりたいこと

  • ビデオを入力してテキストを生成するビデオキャプショニングモデルを実装したい
    • video encoderでビデオを埋め込み
    • 埋め込みからdecoderでテキスト生成
  • video encoderには,HuggingFaceには無いモデルを使用したい
  • decoderにはHuggingFaceのモデルを使用したい
    • generate()でキャプション生成したい

そもそも

ビデオを入力とするキャプショニングモデルを実装には,HuggingFaceのVisionEncoderDecoderモデルを使うと簡単です.以下のように,encoderdecoderに使用するモデルを(文字列で)指定するだけでインスタンスが生成できます.

from transformers import VisionEncoderDecoderModel


model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            encoder='MCG-NJU/videomae-base-finetuned-kinetics',
            decoder='gpt2'
        )

VisionEncoderDecoderModelの詳細はこちら.

しかし,encoderdecoderは,どちらもHuggingFaceに存在するモデルしか用いることができません.そのため,torchhubなどから取ってきた他のvideo encoder(たとえば3D ResNet)を使いたい場合は,このクラスを使うことはできません.

generate関数

HuggingFaceのencoder-decoderやdecoderモデルはテキスト生成のためのgenerate関数を持っています.

このgenerate関数は「モデル自体のforward(順伝播) + テキスト生成」を行っており(実際にはテキスト生成ではなく,単語トークンid列の生成),その実体はGenerationMixinクラスに実装されていて,各decoderの親クラスがこれを継承しています.

VisionEncoderDecoderModelgenerate関数を持っているため,これでテキスト生成ができます.しかし今回のように,encoderに任意のモデルを使う場合には,そもそもVisionEncoderDecoderModelencoderに指定できないため,使うことができません.

以下では,decoderの関数の一つをオーバーロードすることでgenerateに埋め込みを渡せるようにする方法を紹介します.

generate関数のドキュメントはこちら(テキスト生成における探索方法の設定など,色々変えることができます).

ビデオキャプショニングモデルの概要

train

  1. ビデオをvideo encoderに入力し,video埋め込みを得る
  2. decoderへの入力には,このvideo埋め込みと,右に1単語シフトした正解テキストを与える(シフト後に先頭に<BOS>トークンが入る)
  3. 次単語予測(シフトしていない正解テキストを予測)を行い,正解テキストと交差エントロピー損失を求める

1がencoderのforward関数で,2と3がdecoderのfoward関数で行われます.

val

  1. ビデオをvideo encoderに入力し,video埋め込みを得る
  2. decoderにvideo埋め込みと,BOSトークンを入力し,テキストを生成する
  3. 正解テキストと比較する

1がencoderのforward関数で,2がdecoderのgenerate関数で行われます
(generate関数内部で,decoderのforward関数が呼び出されています).

実装

設定

1. Encoderの実装

デコーダであるGPT2が受け付ける入力の次元は768なので,エンコーダの最後の全結合層で次元を768に変換します.

class MyX3D(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.hub.load(
            'facebookresearch/pytorchvideo', "x3d_m",
            pretrained=True,
            head_activation=None,
        )
        in_features = self.model.blocks[5].proj.in_features
        self.model.blocks[5].proj = nn.Linear(
            in_features, 768)


    def forward(self, pixel_values):
        output = self.model(pixel_values)
        return output

X3Dの詳細はこちら(head_activationなどはcreate_x3dの引数に渡されます).

2. decoderの実装(GPT2LMHeadModelを継承)

GPT2LMHeadModelを継承して自作のデコーダを作成します.

継承する目的は,prepare_inputs_for_generation()をオーバーライドして変更するためです.

エンコーダの出力が,デコーダのforward()encoder_hidden_statesに引数として渡される(ことが仮定されている)ので,これがgenerate()に入力されても内部でforward()にも渡されるようになればOKのはずです(ビデオの特徴を反映したテキストが生成される).

このprepare_inputs_for_generation()generate()内部で呼び出される関数であり,forward()に渡す引数を選択して用意する役割を持っています.しかしGPT2LMHeadModelの実装はそうはなっていないため,encoder_hidden_statesforward()に渡されず,このままではencoderの出力は利用されることなく捨てられてしまします.

そのため,次のようにして自作クラスを継承して,オーバーライドすることで解決します.

from transformers import (
    GPT2LMHeadModel
)


class MyGPT2LMHeadModel(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)


    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values, inputs_embeds, **kwargs
        )
        encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
        model_inputs.update(
            {
                "encoder_hidden_states": encoder_hidden_states,
            }
        )
        return model_inputs

3. VideoCaptionModelの実装

最後にキャプションモデル全体の実装します.

from transformers import (
    GPT2Tokenizer,
    GPT2Config,
)
import torch.nn as nn
import torch



class VideoCaptionModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.video_encoder = MyX3D()
        config = GPT2Config(
            is_encoder_decoder=False,
            add_cross_attention=True,
        )
        self.decoder = MyGPT2LMHeadModel.from_pretrained("gpt2", config=config)
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.decoder.config.pad_token_id = self.tokenizer.pad_token_id
        self.decoder.config.bos_token_id = self.tokenizer.bos_token_id
        self.decoder.config.eos_token_id = self.tokenizer.eos_token_id


    def forward(self, pixel_values, labels, decoder_attention_mask=None):
        encoder_hidden_states = torch.unsqueeze(
            self.video_encoder(pixel_values=pixel_values), 1
        )  # bs * 1 * 768
        decoder_input_ids = shift_tokens(
            labels, self.decoder.config.pad_token_id, self.decoder.config.bos_token_id
        )
        decoder_output = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
        )


        logits = decoder_output.logits
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)
        )
        decoder_output.loss = loss
        return decoder_output


    def generate(self, pixel_values, max_length=20):
        encoder_hidden_states = torch.unsqueeze(
            self.video_encoder(pixel_values=pixel_values), 1
        )
        input_ids = torch.LongTensor(
            [[self.decoder.config.bos_token_id] for _ in range(pixel_values.size()[0])]
        ).to(pixel_values.device)
        generated_ids = self.decoder.generate(
            input_ids,
            encoder_hidden_states=encoder_hidden_states,
            max_new_tokens=max_length,
        )
        return generated_ids

トークン化されたテキストを右シフトする関数は,VisionEncoderDecodershift_tokens_rightを参考に作成します.

def shift_tokens(token_ids, pad_token_id, bos_token_id):
    shifted_ids = input_ids.new_zeros(token_ids.shape)
    shifted_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_ids[:, 0] = bos_token_id
    return shifted_ids

4. trainとval

以下は学習と評価の1バッチ分のコードの例です.videopytorchvideoなどの既存のビデオローダが使えます.評価指標はrougeなどで計算します.

model = VideoCaptionModel()
rouge = evaluate.load("rouge")



def train(batch, model):


    video = batch[0]  # (B, T, C, H, W): video clips with T frames
    text = batch[1]  # (B, L): texts of length L


    with torch.no_grad():
        token = model.tokenizer(text, return_tensors="pt", padding=True)


    output = model(
        pixel_values=video,
        labels=token.input_ids.to(video.device),
        decoder_attention_mask=token.attention_mask.to(video.device),
    )
    loss = output.loss
    return loss



def val(batch, model):
    video = batch[0]
    text = batch[1]


    generated_ids = model.generate(video, max_length=20)
    generated_text = model.tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True
    )
    metrics = rouge.compute(predictions=generated_text, references=text)
    return metrics
2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4