はじめに
この記事では、SimVLMのPytorchを用いたスクラッチ開発について説明します。
「SimVLMって何?」という方は、以下の作者の記事を参考にしてください。
注意
これは個人で作成したコードです。誤りなどの可能性がありますが、ご容赦ください。
1. モデル構造
SimVLMの構造は主に Transofomerを基にしたEncoder-Decoder(PrefixLM) モデルになります。
Vision-Transformer やCoAtNet から連想し、画像を直接Encoderに入力している。
具体的には以下の手順で処理される。
1.画像にはN×Nパッチ化を施し、ResNetの最初の3層を用いて畳み込み処理を行いEncoderにおける次元数に適用させ、パッチ数×次元数に変形し仮トークンとして扱う。
2.テキストには基本的なトークン化を行い、埋め込み層に入力し、トークン数×次元数に変形する。
3.位置情報を取り込むために、画像の特徴量とテキストの特徴量それぞれに対して、絶対位置埋め込みを行う。
4.エンコーダに入力後、画像に対応するトークンに対して、相対的な位置情報を取り込むため、相対位置埋め込みを行う。
5.エンコーダの出力をデコーダの Key と Valueに渡し、デコーダの入力を queryとしてPrefixLMを用いて適切な画像と文脈間の特徴を学習する。
6.デコーダにヘッドを追加し、対応するタスクを解く。
比較的理解が難しい PrefixLM や 相対位置埋め込み の仕組みについて知りたい方は、
作者の記事(【実装】T5(Text-to-Text Transfer Transformer)をスクラッチ開発してみた)を参考にしてください。
3. コード
import torch
import torch.nn as nn
class SimVLM(nn.Module):
def __init__(self, batch_size, num_patch, num_embedding, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
super().__init__()
# モデルのパラメータを初期化
self.batch_size = batch_size
self.seq_length = seq_length
self.hidden_size = hidden_size
self.num_layer = num_layer
self.num_heads = num_heads
self.ffn_hidden_size = ffn_hidden_size
self.num_patch = num_patch
# エンコーダーおよびデコーダーの埋め込み層を定義
self.encoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)
self.decoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)
# 位置エンベディングの初期化
self._setupPositionalEmbedding(num_patch, seq_length, hidden_size)
# ResNetモデルの初期化
self.resnet = ResNet(num_patch, hidden_size)
# エンコーダーとデコーダーの初期化
self.encoder = Encoder(batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
self.decoder = Decoder(batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
def _setupPositionalEmbedding(self, num_patch, seq_length, hidden_size):
# 位置エンベディングを初期化するメソッド
image_positional_embedding_module = nn.Embedding(num_patch, hidden_size)
encoder_positional_embedding_module = nn.Embedding(seq_length[0], hidden_size)
decoder_positional_embedding_module = nn.Embedding(seq_length[1], hidden_size)
image_position_ids = torch.tensor(list(range(num_patch))).expand(self.batch_size, -1)
encoder_position_ids = torch.tensor(list(range(seq_length[0]))).expand(self.batch_size, -1)
decoder_position_ids = torch.tensor(list(range(seq_length[1]))).expand(self.batch_size, -1)
# 画像、エンコーダー、デコーダーそれぞれの位置エンベディングを初期化
self.image_positional_embs = image_positional_embedding_module(image_position_ids)
self.encoder_positional_embs = encoder_positional_embedding_module(encoder_position_ids)
self.decoder_positional_embs = decoder_positional_embedding_module(decoder_position_ids)
def forward(self, images, encoder_input_ids, decoder_input_ids):
# モデルのフォワードパスを定義
# 画像からエンコーダーへの処理
encoder_image_output = self.resnet(images) + self.image_positional_embs
# エンコーダーのトークン埋め込みと位置エンベディングの結合
encoder_tokens_embedding = self.encoder_embedding(encoder_input_ids) + self.encoder_positional_embs
encoder_concat_tokens = torch.cat([encoder_image_output, encoder_tokens_embedding], dim=1)
# エンコーダーの処理
encoder_output = self.encoder(encoder_concat_tokens)
# デコーダーのトークン埋め込みの取得
decoder_embedding = self.decoder_embedding(decoder_input_ids)
# デコーダーの処理
decoder_output = self.decoder(decoder_embedding, encoder_output)
return decoder_output
import torch
import torch.nn as nn
from torchvision.models import resnet18
class ResNet(nn.Module):
def __init__(self, num_patch, output_size):
super().__init__()
self.num_patch = num_patch
self._initModel(output_size)
def _initModel(self, output_size):
# 事前学習済みのResNet-18モデルをロード
model = resnet18(pretrained=True)
in_features = 1
# 入力チャネル数を調整
conv1_in_channels = self.num_patch
conv1_out_channels = model.conv1.out_channels
model.conv1 = nn.Conv2d(conv1_in_channels, conv1_out_channels, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# 逆畳み込み層の入力チャネル数を調整
deconv_in_channels = model.layer3[1].conv2.out_channels
deconv = nn.Conv2d(deconv_in_channels, self.num_patch, kernel_size=(1, 1), stride=(1, 1), bias=False)
# 線形層を設定
self.fc = nn.Linear(in_features, output_size)
# モデルの構築
my_model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2, model.layer3, deconv, model.avgpool)
self.model = my_model
def patchImages(self, images):
# 画像をトークンに変換する関数
batch_size, num_channel, width, height = images.shape
patch_size = int(width / ((self.num_patch / num_channel) ** 0.5))
patch_window = torch.ones((patch_size, patch_size), dtype=torch.long)
# パッチウィンドウを画像のチャネルに拡張
patch_window = patch_window.unsqueeze(0).expand(num_channel, patch_size, patch_size) \
.unsqueeze(0).expand(batch_size, num_channel, patch_size, patch_size)
token_list = []
# 画像をパッチに分割
for row_idx in range(0, width, patch_size):
for col_idx in range(0, height, patch_size):
patch = images[:, :, row_idx: row_idx + patch_size, col_idx: col_idx + patch_size]
token_list.append(patch)
# パッチをスタックし、形状を整える
patched_images = torch.stack(token_list, dim=0).transpose(0, 1).reshape(batch_size, self.num_patch, patch_size, patch_size)
return patched_images
画像をトークン化するため、論文でも使用されたResNetを使用しました。論文では、3層のブロックを使用することで経験的に最適な結果が得られるとのことでした。
import torch.nn as nn
class EncoderLayer(nn.Module):
def __init__(self, batch_size, num_patch, num_heads, seq_length, hidden_size, ffn_hidden_size):
super().__init__()
# Multi-Head Attention レイヤー
self.multi_head_attention = MultiHeadAttention(batch_size, num_patch, num_heads, seq_length, hidden_size, check_positional_embedding=True, check_mask=False)
# Add & Norm レイヤー1
self.add_norm1 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=True)
# FeedForward レイヤー
self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)
# Add & Norm レイヤー2
self.add_norm2 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=True)
def forward(self, tokens):
# 入力トークンを保持しておく(Skip Connection用)
skip1 = tokens
# Multi-Head Attention レイヤーの処理
multi_head_attention = self.multi_head_attention(tokens, tokens, tokens)
# Add & Norm レイヤー1の処理
add_norm1 = self.add_norm1(multi_head_attention, skip1)
# Skip Connectionを保持しておく
skip2 = add_norm1
# FeedForward レイヤーの処理
feed_forward = self.feed_forward(add_norm1)
# Add & Norm レイヤー2の処理
add_norm2 = self.add_norm2(feed_forward, skip2)
# 処理結果を返す
tokens = add_norm2
return tokens
import torch.nn as nn
class Encoder(nn.Module):
def __init__(self, batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
super().__init__()
# 複数のエンコーダーレイヤーを構築
self._setupEncoderLayer(batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
def _setupEncoderLayer(self, batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
# エンコーダーレイヤーをリストとして保持するためのモジュール
encoder_layer_list = []
# 指定された数だけエンコーダーレイヤーを構築
for _ in range(num_layer):
encoder_layer = EncoderLayer(batch_size, num_patch, num_heads, seq_length, hidden_size, ffn_hidden_size)
encoder_layer_list.append(encoder_layer)
# モジュールリストとしてエンコーダーレイヤーを保持
self.encoder_module = nn.ModuleList(encoder_layer_list)
def forward(self, encoder_embedding):
# 入力トークンを保持しておく
tokens = encoder_embedding
# 各エンコーダーレイヤーを順次適用
for encoder_layer in self.encoder_module:
tokens = encoder_layer(tokens)
# エンコーダーの出力を返す
encoder_output = tokens
return encoder_output
あとにも記述しますが、Encoder_layerにおける画像の特徴には絶対位置埋め込みに加え、相対位置埋め込みを適応しています。
import torch.nn as nn
class DecoderLayer(nn.Module):
def __init__(self, batch_size, num_patch, num_heads, seq_length, hidden_size, ffn_hidden_size):
super().__init__()
# マスク付きのMulti-Head Attention レイヤー
self.masked_multi_head_attention = MultiHeadAttention(batch_size, num_patch, num_heads, seq_length, hidden_size, check_positional_embedding=False, check_mask=True)
# Add & Norm レイヤー1
self.add_norm1 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=False)
# Cross-Attention レイヤー
self.cross_multi_head_attention = MultiHeadAttention(batch_size, num_patch, num_heads, seq_length, hidden_size, check_positional_embedding=False, check_mask=False)
# Add & Norm レイヤー2
self.add_norm2 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=False)
# FeedForward レイヤー
self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)
# Add & Norm レイヤー3
self.add_norm3 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=False)
def forward(self, tokens, output_encoder):
# 入力トークンを保持しておく
skip1 = tokens
# マスク付きのMulti-Head Attention レイヤーの処理
masked_multi_head_attention = self.masked_multi_head_attention(tokens, tokens, tokens)
# Add & Norm レイヤー1の処理
add_norm1 = self.add_norm1(masked_multi_head_attention, skip1)
# Skip Connectionを保持しておく
skip2 = add_norm1
# Cross-Attention レイヤーの処理
cross_multi_head_attention = self.cross_multi_head_attention(tokens, output_encoder, output_encoder)
# Add & Norm レイヤー2の処理
add_norm2 = self.add_norm2(cross_multi_head_attention, skip2)
# Skip Connectionを保持しておく
skip3 = add_norm2
# FeedForward レイヤーの処理
feed_forward = self.feed_forward(tokens)
# Add & Norm レイヤー3の処理
add_norm3 = self.add_norm3(feed_forward, skip3)
# 処理結果を返す
tokens = add_norm3
return tokens
import torch.nn as nn
class Decoder(nn.Module):
def __init__(self, batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
super().__init__()
# 複数のデコーダーレイヤーを構築
self._setupDecoderLayer(batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
def _setupDecoderLayer(self, batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
# デコーダーレイヤーをリストとして保持するためのモジュール
decoder_layer_list = []
# 指定された数だけデコーダーレイヤーを構築
for _ in range(num_layer):
decoder_layer = DecoderLayer(batch_size, num_patch, num_heads, seq_length, hidden_size, ffn_hidden_size)
decoder_layer_list.append(decoder_layer)
# モジュールリストとしてデコーダーレイヤーを保持
self.decoder_module = nn.ModuleList(decoder_layer_list)
def forward(self, decoder_embedding, encoder_output):
# 入力トークンを保持しておく
tokens = decoder_embedding
# 各デコーダーレイヤーを順次適用
for decoder_layer in self.decoder_module:
tokens = decoder_layer(tokens, encoder_output)
# デコーダーの出力を返す
decoder_output = tokens
return decoder_output
Decoder_layerではT5と比較しても、特別な処理は行っていません。Masked-MultiHeadAttentionとEncoderの出力ベクトルとのMulti-HeadCrossAttentionを行っています。
import torch
import torch.nn as nn
import numpy as np
class MultiHeadAttention(nn.Module):
def __init__(self, batch_size, num_patch, num_heads, seq_length, hidden_size, check_positional_embedding, check_mask):
super().__init__()
# Query、Key、Value用の線形変換モジュールを構築
self._setupHeadQKV(num_heads, hidden_size)
# モジュールのパラメータや設定を保持
self.batch_size = batch_size
self.num_patch = num_patch
self.num_heads = num_heads
self.seq_length = seq_length
self.hidden_size = hidden_size
self.check_positional_embedding = check_positional_embedding
self.check_mask = check_mask
# Softmax関数(Attentionの重み計算用)
self.softmax = nn.Softmax(dim=-1)
def _setupHeadQKV(self, num_heads, hidden_size):
# Query、Key、Value用のモジュールをリストとして保持するためのモジュール
query_module = []
key_module = []
value_module = []
# ヘッドごとの隠れ層サイズ
head_hidden_size = int(hidden_size / num_heads)
# 指定されたヘッド数だけQuery、Key、Value用のモジュールを構築
for _ in range(num_heads):
query_module.append(nn.Linear(hidden_size, head_hidden_size))
key_module.append(nn.Linear(hidden_size, head_hidden_size))
value_module.append(nn.Linear(hidden_size, head_hidden_size))
# モジュールリストとして保持
self.query_module = nn.ModuleList(query_module)
self.key_module = nn.ModuleList(key_module)
self.value_module = nn.ModuleList(value_module)
def _outputRelativePositionalEmbeddingScalar(self, query, batch_size, num_patch, seq_length, hidden_size, num_heads):
# 相対位置エンベディングを計算する関数
# 入力のシーケンス長を取得
seq_length = seq_length[0]
# Embeddingモジュールを保持するためのリスト
embed_Module = []
# ヘッドごとの隠れ層サイズ
head_hidden_size = int(hidden_size / num_heads)
# 位置情報のIDを作成
position_ids = torch.tensor(list(range(num_patch + seq_length)), dtype=torch.long).reshape(1, num_patch + seq_length).expand(batch_size, num_patch + seq_length)
# 各ヘッドごとにEmbeddingモジュールを構築
for id in range(num_heads):
embed_Module.append(nn.Embedding(num_patch + seq_length, head_hidden_size))
# モジュールリストとしてEmbeddingモジュールを保持
self.embed_module = nn.ModuleList(embed_Module)
# 各ヘッドごとに相対位置エンベディングを計算
for id in range(num_heads):
head_query = self.query_module[id](query)
tmp_relative_position_embedding_scalar = (head_query@(self.embed_module[id](position_ids).transpose(1, 2)))\
.reshape(1, batch_size, num_patch + seq_length, num_patch + seq_length)[:, :, :num_patch, :num_patch]
# 列のパディングを追加
col_pad = torch.zeros((1, batch_size, num_patch, seq_length), dtype=torch.float)
tmp_relative_position_embedding_scalar = torch.cat([tmp_relative_position_embedding_scalar, col_pad], dim=3)
# 行のパディングを追加
row_pad = torch.zeros((1, batch_size, seq_length, num_patch + seq_length))
tmp_relative_position_embedding_scalar = torch.cat([tmp_relative_position_embedding_scalar, row_pad], dim=2)
# 初めて計算するヘッドの場合は、相対位置エンベディングをそのまま保持
if id == 0:
relative_position_embedding_scalar = tmp_relative_position_embedding_scalar
else:
# すでに計算済みのヘッドがある場合は、テンソルを連結
relative_position_embedding_scalar = torch.cat([relative_position_embedding_scalar, tmp_relative_position_embedding_scalar], dim=0)
return relative_position_embedding_scalar
def _outputAttention(self, query, key, value, batch_size, num_patch, seq_length, hidden_size, num_heads, check_positional_embedding, check_mask):
# Attentionスコアを計算する関数
# マスクや位置エンベディングの有無によって、処理を分岐
if check_positional_embedding:
seq_length1 = seq_length2 = num_patch + self.seq_length[0]
else:
if check_mask:
seq_length1 = seq_length2 = self.seq_length[1]
else:
seq_length1 = self.seq_length[1]
seq_length2 = num_patch + self.seq_length[0]
# ヘッドごとの隠れ層サイズ
head_hidden_size = int(hidden_size / num_heads)
# マスクマップを作成
mask_map = torch.tensor(np.tril(np.ones((seq_length1, seq_length2))), dtype=torch.long)
# 位置エンベディングが指定されている場合、相対位置エンベディングを計算
if check_positional_embedding:
relative_position_embedding_scalar = self._outputRelativePositionalEmbeddingScalar(query, batch_size, num_patch, seq_length, hidden_size, num_heads)
else:
relative_position_embedding_scalar = None
# 各ヘッドごとにAttentionスコアを計算
for id in range(num_heads):
head_query = self.query_module[id](query)
head_key = self.key_module[id](key)
head_value = self.value_module[id](value)
if check_positional_embedding:
# 位置エンベディングが指定されている場合、Attentionスコアに相対位置エンベディングを加算
tmp_head_attention = self.softmax(((head_query@head_key.transpose(1, 2)) / (head_hidden_size) + relative_position_embedding_scalar[id]))@head_value
else:
# 位置エンベディングが指定されていない場合
if check_mask:
# マスクが指定されている場合、Attentionスコアにマスクを適用
tmp_head_attention = self.softmax((mask_map * (head_query@head_key.transpose(1, 2)) / (head_hidden_size)))@head_value
else:
# マスクが指定されておらず、位置エンベディングもない場合、通常のAttentionスコア計算
tmp_head_attention = self.softmax((head_query@head_key.transpose(1, 2)) / (head_hidden_size))@head_value
# はじめて計算するヘッドの場合は、Attentionスコアをそのまま保持
if id == 0:
head_attention = tmp_head_attention
else:
# すでに計算済みのヘッドがある場合は、テンソルを連結
head_attention = torch.cat([head_attention, tmp_head_attention], dim=-1)
# 出力のAttentionスコアを返す
output_attention = head_attention
return output_attention
def forward(self, query, key, value):
# フォワード関数
# Attentionスコアを計算
output_attention = self._outputAttention(query, key, value, self.batch_size, self.num_patch, self.seq_length, self.hidden_size, self.num_heads,
self.check_positional_embedding, self.check_mask)
return output_attention
このクラスでは、少し面倒な処理を行っています。T5のコードをを流用しているため複雑ですが、EncoderのAttentionには相対位置埋め込みを行い、DecoderにはMask化、またはCrossAttentionを行えるような分岐処理を施しています(普通に考えてそれぞれのAttentionごとに分けた方が可読性は高いです…)。
import torch.nn as nn
class AddNorm(nn.Module):
def __init__(self, batch_size, num_patch, seq_length, hidden_size, check_encoder):
super().__init__()
# AddNormモジュールを構築
self._setupAddNormModule(batch_size, num_patch, seq_length, hidden_size, check_encoder)
def _setupAddNormModule(self, batch_size, num_patch, seq_length, hidden_size, check_encoder):
# エンコーダーの場合、シーケンス長はエンコーダーのシーケンス長とパッチ数の合計
if check_encoder:
seq_length = seq_length[0] + num_patch
else:
# デコーダーの場合、シーケンス長はデコーダーのシーケンス長
seq_length = seq_length[1]
# LayerNormモジュールを構築
self.layer_norm = nn.LayerNorm((batch_size, seq_length, hidden_size))
def forward(self, tokens, skipped_tokens):
# 入力トークンにスキップしたトークンを加算
tokens += skipped_tokens
# LayerNormを適用して出力トークンを生成
tokens = self.layer_norm(tokens)
return tokens
import torch.nn as nn
class FeedForward(nn.Module):
def __init__(self, hidden_size, ffn_hidden_size):
super().__init__()
# FeedForwardモジュールを構築
self._setupFeedForwardModule(hidden_size, ffn_hidden_size)
def _setupFeedForwardModule(self, hidden_size, ffn_hidden_size):
# 1つ目の全結合層とReLU活性化関数
dense1 = nn.Linear(hidden_size, ffn_hidden_size)
relu1 = nn.ReLU()
# 2つ目の全結合層とReLU活性化関数
dense2 = nn.Linear(ffn_hidden_size, hidden_size)
relu2 = nn.ReLU()
# モジュールリストとして保持
self.feed_foward_module = nn.ModuleList([dense1, relu1, dense2, relu2])
def forward(self, tokens):
# フォワード関数
# モジュールリスト内の各モジュールを順に適用
for module in self.feed_foward_module:
tokens = module(tokens)
# 出力トークンを返す
return tokens
4. 実行結果
from transformers import AutoTokenizer
import numpy as np
# BERTトークナイザーの読み込み
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") # T5Tokenizerでは文頭トークン<s>が表示されなかったため、bertで代用
# 入力となる英語のトークン化
enocoder_input_tokens = ["Translate English to German : Two brown and white dogs",
"Translate English to German : The man plaing soccoer",
"Translate English to German : The birds is flying"]
encoder_input_tokenize = tokenizer(enocoder_input_tokens, padding=True, return_tensors="pt", return_length=True)
encoder_input_ids = encoder_input_tokenize["input_ids"][:, 1:] # 先頭のトークン<s>を除外
encoder_attention_mask = encoder_input_tokenize["attention_mask"][:, 1:]
encoder_attention_mask = torch.where(encoder_input_ids == 102, 0, encoder_attention_mask) # 文末トークンを除外
encoder_input_ids = torch.where(encoder_input_ids == 102, 0, encoder_input_ids) # 文末トークンを除外
# 出力となるドイツ語のトークン化
decoder_input_tokens = ["Zwei braune und weiße Hunde",
"Der Mann spielt Fußball",
"Die Vögel fliegen"]
decoder_input_tokenize = tokenizer(decoder_input_tokens, padding=True, return_tensors="pt", return_length=True)
decoder_input_ids = decoder_input_tokenize["input_ids"]
decoder_attention_mask = decoder_input_tokenize["attention_mask"]
encoder_max_length = encoder_input_tokenize["length"][0].item() - 1
decoder_max_length = decoder_input_tokenize["length"][0].item()
# 画像の生成と前処理
images = torch.randn(3, 3, 256, 256, dtype=torch.float)
patch_size = 16
num_patch = int((images.shape[2] / patch_size) ** 2) * images.shape[1]
num_embedding = torch.max(torch.concat([encoder_input_ids, decoder_input_ids], dim=-1)) + 1
# SimVLMモデルの初期化と実行
kwargs = {
"batch_size": 3,
"num_patch": num_patch,
"num_embedding": num_embedding,
"seq_length": (encoder_max_length, decoder_max_length),
"hidden_size": 512,
"num_layer": 12,
"num_heads": 8,
"ffn_hidden_size": 3072
}
simvlm = SimVLM(**kwargs)
# モデルの出力の形状を表示
print(encoder_max_length, decoder_max_length)
print(simvlm(images, encoder_input_ids, decoder_input_ids).shape)
論文ではEncoderへの入力では特殊トークンを付与すると記載されていなかったため、今回は除外しています。またDecoderには分類タスク等で必要なので付与しています。他設定はT5における基本的設定と同様となっています。patch_sizeは自由に決めてもらって構いません。
14 12
torch.Size([3, 12, 512])
無事、正しくバッチサイズ × Decoderのトークン数 × 次元数
で出力されていますね。
まとめ
ここまで読んでくださり、ありがとうございます。
今回の記事では、SimVLMのスクラッチ開発をしてみました。T5のコードを大部分を流用できたので実装は比較的簡単でしたが、EncoderにおけるAttentionのヘッドごとの画像トークン部分にのみ相対位置埋め込みを適用するという処理に時間がかかりました。
しかし、やはり自分の手で実装することで処理の流れをより深いレベルで理解できますね。次は音声認識と自然言語処理に関する論文の実装を考えています(音声認識への理解はCTC程度で止まっていますが…)。
参考文献
SimVLM: Simple Visual Language Model Pretraining with Weak Supervision