PyTorch HubなどHuggingFace以外から取得したビデオエンコーダとHuggingFaceのデコーダを使って,ビデオそのものを入力とするキャプショニングモデルの実装方法を紹介します.
実装方法のみを知りたい方は実装の節を参照してください.
やりたいこと
- ビデオを入力してテキストを生成するビデオキャプショニングモデルを実装したい
- video encoderでビデオを埋め込み
- 埋め込みからdecoderでテキスト生成
- video encoderには,HuggingFaceには無いモデルを使用したい
- decoderにはHuggingFaceのモデルを使用したい
-
generate()
でキャプション生成したい
-
そもそも
ビデオを入力とするキャプショニングモデルを実装には,HuggingFaceのVisionEncoderDecoder
モデルを使うと簡単です.以下のように,encoder
とdecoder
に使用するモデルを(文字列で)指定するだけでインスタンスが生成できます.
from transformers import VisionEncoderDecoderModel
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder='MCG-NJU/videomae-base-finetuned-kinetics',
decoder='gpt2'
)
VisionEncoderDecoderModel
の詳細はこちら.
- https://huggingface.co/docs/transformers/main/en/model_doc/vision-encoder-decoder#transformers.VisionEncoderDecoderModel
- https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder
しかし,encoder
とdecoder
は,どちらもHuggingFaceに存在するモデルしか用いることができません.そのため,torchhub
などから取ってきた他のvideo encoder(たとえば3D ResNet)を使いたい場合は,このクラスを使うことはできません.
generate
関数
HuggingFaceのencoder-decoderやdecoderモデルはテキスト生成のためのgenerate
関数を持っています.
このgenerate
関数は「モデル自体のforward(順伝播) + テキスト生成」を行っており(実際にはテキスト生成ではなく,単語トークンid列の生成),その実体はGenerationMixin
クラスに実装されていて,各decoderの親クラスがこれを継承しています.
VisionEncoderDecoderModel
もgenerate
関数を持っているため,これでテキスト生成ができます.しかし今回のように,encoderに任意のモデルを使う場合には,そもそもVisionEncoderDecoderModel
のencoder
に指定できないため,使うことができません.
以下では,decoderの関数の一つをオーバーロードすることでgenerate
に埋め込みを渡せるようにする方法を紹介します.
generate
関数のドキュメントはこちら(テキスト生成における探索方法の設定など,色々変えることができます).
ビデオキャプショニングモデルの概要
train
- ビデオをvideo encoderに入力し,video埋め込みを得る
- decoderへの入力には,このvideo埋め込みと,右に1単語シフトした正解テキストを与える(シフト後に先頭に
<BOS>
トークンが入る) - 次単語予測(シフトしていない正解テキストを予測)を行い,正解テキストと交差エントロピー損失を求める
1がencoderのforward
関数で,2と3がdecoderのfoward
関数で行われます.
val
- ビデオをvideo encoderに入力し,video埋め込みを得る
- decoderにvideo埋め込みと,BOSトークンを入力し,テキストを生成する
- 正解テキストと比較する
1がencoderのforward
関数で,2がdecoderのgenerate
関数で行われます
(generate
関数内部で,decoderのforward
関数が呼び出されています).
実装
設定
- ビデオエンコーダ : PyTorchHubのX3D_m
- テキストデコーダ : HuggingFaceのGPT2LMHeadModel
1. Encoderの実装
デコーダであるGPT2が受け付ける入力の次元は768なので,エンコーダの最後の全結合層で次元を768に変換します.
class MyX3D(nn.Module):
def __init__(self) -> None:
super().__init__()
self.model = torch.hub.load(
'facebookresearch/pytorchvideo', "x3d_m",
pretrained=True,
head_activation=None,
)
in_features = self.model.blocks[5].proj.in_features
self.model.blocks[5].proj = nn.Linear(
in_features, 768)
def forward(self, pixel_values):
output = self.model(pixel_values)
return output
X3Dの詳細はこちら(head_activation
などはcreate_x3d
の引数に渡されます).
2. decoderの実装(GPT2LMHeadModelを継承)
GPT2LMHeadModel
を継承して自作のデコーダを作成します.
継承する目的は,prepare_inputs_for_generation()
をオーバーライドして変更するためです.
エンコーダの出力が,デコーダのforward()
のencoder_hidden_states
に引数として渡される(ことが仮定されている)ので,これがgenerate()
に入力されても内部でforward()
にも渡されるようになればOKのはずです(ビデオの特徴を反映したテキストが生成される).
このprepare_inputs_for_generation()
はgenerate()
内部で呼び出される関数であり,forward()
に渡す引数を選択して用意する役割を持っています.しかしGPT2LMHeadModel
の実装はそうはなっていないため,encoder_hidden_states
はforward()
に渡されず,このままではencoderの出力は利用されることなく捨てられてしまします.
そのため,次のようにして自作クラスを継承して,オーバーライドすることで解決します.
from transformers import (
GPT2LMHeadModel
)
class MyGPT2LMHeadModel(GPT2LMHeadModel):
def __init__(self, config):
super().__init__(config)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
model_inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values, inputs_embeds, **kwargs
)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
model_inputs.update(
{
"encoder_hidden_states": encoder_hidden_states,
}
)
return model_inputs
3. VideoCaptionModelの実装
最後にキャプションモデル全体の実装します.
from transformers import (
GPT2Tokenizer,
GPT2Config,
)
import torch.nn as nn
import torch
class VideoCaptionModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.video_encoder = MyX3D()
config = GPT2Config(
is_encoder_decoder=False,
add_cross_attention=True,
)
self.decoder = MyGPT2LMHeadModel.from_pretrained("gpt2", config=config)
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
self.tokenizer.pad_token = self.tokenizer.eos_token
self.decoder.config.pad_token_id = self.tokenizer.pad_token_id
self.decoder.config.bos_token_id = self.tokenizer.bos_token_id
self.decoder.config.eos_token_id = self.tokenizer.eos_token_id
def forward(self, pixel_values, labels, decoder_attention_mask=None):
encoder_hidden_states = torch.unsqueeze(
self.video_encoder(pixel_values=pixel_values), 1
) # bs * 1 * 768
decoder_input_ids = shift_tokens(
labels, self.decoder.config.pad_token_id, self.decoder.config.bos_token_id
)
decoder_output = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
)
logits = decoder_output.logits
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)
)
decoder_output.loss = loss
return decoder_output
def generate(self, pixel_values, max_length=20):
encoder_hidden_states = torch.unsqueeze(
self.video_encoder(pixel_values=pixel_values), 1
)
input_ids = torch.LongTensor(
[[self.decoder.config.bos_token_id] for _ in range(pixel_values.size()[0])]
).to(pixel_values.device)
generated_ids = self.decoder.generate(
input_ids,
encoder_hidden_states=encoder_hidden_states,
max_new_tokens=max_length,
)
return generated_ids
トークン化されたテキストを右シフトする関数は,VisionEncoderDecoder
のshift_tokens_right
を参考に作成します.
def shift_tokens(token_ids, pad_token_id, bos_token_id):
shifted_ids = input_ids.new_zeros(token_ids.shape)
shifted_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_ids[:, 0] = bos_token_id
return shifted_ids
4. trainとval
以下は学習と評価の1バッチ分のコードの例です.video
はpytorchvideoなどの既存のビデオローダが使えます.評価指標はrouge
などで計算します.
model = VideoCaptionModel()
rouge = evaluate.load("rouge")
def train(batch, model):
video = batch[0] # (B, T, C, H, W): video clips with T frames
text = batch[1] # (B, L): texts of length L
with torch.no_grad():
token = model.tokenizer(text, return_tensors="pt", padding=True)
output = model(
pixel_values=video,
labels=token.input_ids.to(video.device),
decoder_attention_mask=token.attention_mask.to(video.device),
)
loss = output.loss
return loss
def val(batch, model):
video = batch[0]
text = batch[1]
generated_ids = model.generate(video, max_length=20)
generated_text = model.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True
)
metrics = rouge.compute(predictions=generated_text, references=text)
return metrics