LoginSignup
0
0
記事投稿キャンペーン 「AI、機械学習」

【実装】SimVLM(Simple Visual Language Model)をスクラッチ開発してみた

Last updated at Posted at 2023-11-13

はじめに

この記事では、SimVLMのPytorchを用いたスクラッチ開発について説明します。

「SimVLMって何?」という方は、以下の作者の記事を参考にしてください。

【論文読み】SimVLMの論文を要約してみた

注意
これは個人で作成したコードです。誤りなどの可能性がありますが、ご容赦ください。

1. モデル構造

image.png

SimVLMの構造は主に Transofomerを基にしたEncoder-Decoder(PrefixLM) モデルになります。

Vision-TransformerCoAtNet から連想し、画像を直接Encoderに入力している。

具体的には以下の手順で処理される。

1.画像にはN×Nパッチ化を施し、ResNetの最初の3層を用いて畳み込み処理を行いEncoderにおける次元数に適用させ、パッチ数×次元数に変形し仮トークンとして扱う。

2.テキストには基本的なトークン化を行い、埋め込み層に入力し、トークン数×次元数に変形する。

3.位置情報を取り込むために、画像の特徴量とテキストの特徴量それぞれに対して、絶対位置埋め込みを行う。

4.エンコーダに入力後、画像に対応するトークンに対して、相対的な位置情報を取り込むため、相対位置埋め込みを行う。

5.エンコーダの出力をデコーダの KeyValueに渡し、デコーダの入力を queryとしてPrefixLMを用いて適切な画像と文脈間の特徴を学習する。

6.デコーダにヘッドを追加し、対応するタスクを解く。

比較的理解が難しい PrefixLM相対位置埋め込み の仕組みについて知りたい方は、

作者の記事(【実装】T5(Text-to-Text Transfer Transformer)をスクラッチ開発してみた)を参考にしてください。

3. コード

SimVLM
import torch
import torch.nn as nn

class SimVLM(nn.Module):
    def __init__(self, batch_size, num_patch, num_embedding, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        super().__init__()

        # モデルのパラメータを初期化
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.num_layer = num_layer
        self.num_heads = num_heads
        self.ffn_hidden_size = ffn_hidden_size
        self.num_patch = num_patch

        # エンコーダーおよびデコーダーの埋め込み層を定義
        self.encoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)
        self.decoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)

        # 位置エンベディングの初期化
        self._setupPositionalEmbedding(num_patch, seq_length, hidden_size)

        # ResNetモデルの初期化
        self.resnet = ResNet(num_patch, hidden_size)

        # エンコーダーとデコーダーの初期化
        self.encoder = Encoder(batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
        self.decoder = Decoder(batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)

    def _setupPositionalEmbedding(self, num_patch, seq_length, hidden_size):
        # 位置エンベディングを初期化するメソッド
        image_positional_embedding_module = nn.Embedding(num_patch, hidden_size)
        encoder_positional_embedding_module = nn.Embedding(seq_length[0], hidden_size)
        decoder_positional_embedding_module = nn.Embedding(seq_length[1], hidden_size)

        image_position_ids = torch.tensor(list(range(num_patch))).expand(self.batch_size, -1)
        encoder_position_ids = torch.tensor(list(range(seq_length[0]))).expand(self.batch_size, -1)
        decoder_position_ids = torch.tensor(list(range(seq_length[1]))).expand(self.batch_size, -1)

        # 画像、エンコーダー、デコーダーそれぞれの位置エンベディングを初期化
        self.image_positional_embs = image_positional_embedding_module(image_position_ids)
        self.encoder_positional_embs = encoder_positional_embedding_module(encoder_position_ids)
        self.decoder_positional_embs = decoder_positional_embedding_module(decoder_position_ids)

    def forward(self, images, encoder_input_ids, decoder_input_ids):
        # モデルのフォワードパスを定義

        # 画像からエンコーダーへの処理
        encoder_image_output = self.resnet(images) + self.image_positional_embs

        # エンコーダーのトークン埋め込みと位置エンベディングの結合
        encoder_tokens_embedding = self.encoder_embedding(encoder_input_ids) + self.encoder_positional_embs
        encoder_concat_tokens = torch.cat([encoder_image_output, encoder_tokens_embedding], dim=1)

        # エンコーダーの処理
        encoder_output = self.encoder(encoder_concat_tokens)

        # デコーダーのトークン埋め込みの取得
        decoder_embedding = self.decoder_embedding(decoder_input_ids)

        # デコーダーの処理
        decoder_output = self.decoder(decoder_embedding, encoder_output)

        return decoder_output
ResNet
import torch
import torch.nn as nn
from torchvision.models import resnet18

class ResNet(nn.Module):
    def __init__(self, num_patch, output_size):
        super().__init__()
        self.num_patch = num_patch
        self._initModel(output_size)

    def _initModel(self, output_size):
        # 事前学習済みのResNet-18モデルをロード
        model = resnet18(pretrained=True)
        in_features = 1

        # 入力チャネル数を調整
        conv1_in_channels = self.num_patch
        conv1_out_channels = model.conv1.out_channels
        model.conv1 = nn.Conv2d(conv1_in_channels, conv1_out_channels, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # 逆畳み込み層の入力チャネル数を調整
        deconv_in_channels = model.layer3[1].conv2.out_channels
        deconv = nn.Conv2d(deconv_in_channels, self.num_patch, kernel_size=(1, 1), stride=(1, 1), bias=False)

        # 線形層を設定
        self.fc = nn.Linear(in_features, output_size)

        # モデルの構築
        my_model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2, model.layer3, deconv, model.avgpool)
        self.model = my_model

    def patchImages(self, images):
        # 画像をトークンに変換する関数
        batch_size, num_channel, width, height = images.shape
        patch_size = int(width / ((self.num_patch / num_channel) ** 0.5))
        patch_window = torch.ones((patch_size, patch_size), dtype=torch.long)

        # パッチウィンドウを画像のチャネルに拡張
        patch_window = patch_window.unsqueeze(0).expand(num_channel, patch_size, patch_size) \
            .unsqueeze(0).expand(batch_size, num_channel, patch_size, patch_size)

        token_list = []

        # 画像をパッチに分割
        for row_idx in range(0, width, patch_size):
            for col_idx in range(0, height, patch_size):
                patch = images[:, :, row_idx: row_idx + patch_size, col_idx: col_idx + patch_size]
                token_list.append(patch)

        # パッチをスタックし、形状を整える
        patched_images = torch.stack(token_list, dim=0).transpose(0, 1).reshape(batch_size, self.num_patch, patch_size, patch_size)

        return patched_images

画像をトークン化するため、論文でも使用されたResNetを使用しました。論文では、3層のブロックを使用することで経験的に最適な結果が得られるとのことでした。

Encoderlayer
import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, batch_size, num_patch, num_heads, seq_length, hidden_size, ffn_hidden_size):
        super().__init__()

        # Multi-Head Attention レイヤー
        self.multi_head_attention = MultiHeadAttention(batch_size, num_patch, num_heads, seq_length, hidden_size, check_positional_embedding=True, check_mask=False)

        # Add & Norm レイヤー1
        self.add_norm1 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=True)

        # FeedForward レイヤー
        self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)

        # Add & Norm レイヤー2
        self.add_norm2 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=True)

    def forward(self, tokens):
        # 入力トークンを保持しておく(Skip Connection用)
        skip1 = tokens

        # Multi-Head Attention レイヤーの処理
        multi_head_attention = self.multi_head_attention(tokens, tokens, tokens)

        # Add & Norm レイヤー1の処理
        add_norm1 = self.add_norm1(multi_head_attention, skip1)

        # Skip Connectionを保持しておく
        skip2 = add_norm1

        # FeedForward レイヤーの処理
        feed_forward = self.feed_forward(add_norm1)

        # Add & Norm レイヤー2の処理
        add_norm2 = self.add_norm2(feed_forward, skip2)

        # 処理結果を返す
        tokens = add_norm2

        return tokens
Encoder
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        super().__init__()

        # 複数のエンコーダーレイヤーを構築
        self._setupEncoderLayer(batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)

    def _setupEncoderLayer(self, batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        # エンコーダーレイヤーをリストとして保持するためのモジュール
        encoder_layer_list = []

        # 指定された数だけエンコーダーレイヤーを構築
        for _ in range(num_layer):
            encoder_layer = EncoderLayer(batch_size, num_patch, num_heads, seq_length, hidden_size, ffn_hidden_size)
            encoder_layer_list.append(encoder_layer)

        # モジュールリストとしてエンコーダーレイヤーを保持
        self.encoder_module = nn.ModuleList(encoder_layer_list)

    def forward(self, encoder_embedding):
        # 入力トークンを保持しておく
        tokens = encoder_embedding

        # 各エンコーダーレイヤーを順次適用
        for encoder_layer in self.encoder_module:
            tokens = encoder_layer(tokens)

        # エンコーダーの出力を返す
        encoder_output = tokens

        return encoder_output

あとにも記述しますが、Encoder_layerにおける画像の特徴には絶対位置埋め込みに加え、相対位置埋め込みを適応しています。

Decoderlayer
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, batch_size, num_patch, num_heads, seq_length, hidden_size, ffn_hidden_size):
        super().__init__()

        # マスク付きのMulti-Head Attention レイヤー
        self.masked_multi_head_attention = MultiHeadAttention(batch_size, num_patch, num_heads, seq_length, hidden_size, check_positional_embedding=False, check_mask=True)

        # Add & Norm レイヤー1
        self.add_norm1 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=False)

        # Cross-Attention レイヤー
        self.cross_multi_head_attention = MultiHeadAttention(batch_size, num_patch, num_heads, seq_length, hidden_size, check_positional_embedding=False, check_mask=False)

        # Add & Norm レイヤー2
        self.add_norm2 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=False)

        # FeedForward レイヤー
        self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)

        # Add & Norm レイヤー3
        self.add_norm3 = AddNorm(batch_size, num_patch, seq_length, hidden_size, check_encoder=False)

    def forward(self, tokens, output_encoder):
        # 入力トークンを保持しておく
        skip1 = tokens

        # マスク付きのMulti-Head Attention レイヤーの処理
        masked_multi_head_attention = self.masked_multi_head_attention(tokens, tokens, tokens)

        # Add & Norm レイヤー1の処理
        add_norm1 = self.add_norm1(masked_multi_head_attention, skip1)

        # Skip Connectionを保持しておく
        skip2 = add_norm1

        # Cross-Attention レイヤーの処理
        cross_multi_head_attention = self.cross_multi_head_attention(tokens, output_encoder, output_encoder)

        # Add & Norm レイヤー2の処理
        add_norm2 = self.add_norm2(cross_multi_head_attention, skip2)

        # Skip Connectionを保持しておく
        skip3 = add_norm2

        # FeedForward レイヤーの処理
        feed_forward = self.feed_forward(tokens)

        # Add & Norm レイヤー3の処理
        add_norm3 = self.add_norm3(feed_forward, skip3)

        # 処理結果を返す
        tokens = add_norm3

        return tokens
Decoder
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        super().__init__()

        # 複数のデコーダーレイヤーを構築
        self._setupDecoderLayer(batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)

    def _setupDecoderLayer(self, batch_size, num_patch, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        # デコーダーレイヤーをリストとして保持するためのモジュール
        decoder_layer_list = []

        # 指定された数だけデコーダーレイヤーを構築
        for _ in range(num_layer):
            decoder_layer = DecoderLayer(batch_size, num_patch, num_heads, seq_length, hidden_size, ffn_hidden_size)
            decoder_layer_list.append(decoder_layer)

        # モジュールリストとしてデコーダーレイヤーを保持
        self.decoder_module = nn.ModuleList(decoder_layer_list)

    def forward(self, decoder_embedding, encoder_output):
        # 入力トークンを保持しておく
        tokens = decoder_embedding

        # 各デコーダーレイヤーを順次適用
        for decoder_layer in self.decoder_module:
            tokens = decoder_layer(tokens, encoder_output)

        # デコーダーの出力を返す
        decoder_output = tokens

        return decoder_output

Decoder_layerではT5と比較しても、特別な処理は行っていません。Masked-MultiHeadAttentionとEncoderの出力ベクトルとのMulti-HeadCrossAttentionを行っています。

MultiHeadAttention
import torch
import torch.nn as nn
import numpy as np

class MultiHeadAttention(nn.Module):
    def __init__(self, batch_size, num_patch, num_heads, seq_length, hidden_size, check_positional_embedding, check_mask):
        super().__init__()

        # Query、Key、Value用の線形変換モジュールを構築
        self._setupHeadQKV(num_heads, hidden_size)

        # モジュールのパラメータや設定を保持
        self.batch_size = batch_size
        self.num_patch = num_patch
        self.num_heads = num_heads
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.check_positional_embedding = check_positional_embedding
        self.check_mask = check_mask

        # Softmax関数(Attentionの重み計算用)
        self.softmax = nn.Softmax(dim=-1)

    def _setupHeadQKV(self, num_heads, hidden_size):
        # Query、Key、Value用のモジュールをリストとして保持するためのモジュール
        query_module = []
        key_module = []
        value_module = []

        # ヘッドごとの隠れ層サイズ
        head_hidden_size = int(hidden_size / num_heads)

        # 指定されたヘッド数だけQuery、Key、Value用のモジュールを構築
        for _ in range(num_heads):
            query_module.append(nn.Linear(hidden_size, head_hidden_size))
            key_module.append(nn.Linear(hidden_size, head_hidden_size))
            value_module.append(nn.Linear(hidden_size, head_hidden_size))

        # モジュールリストとして保持
        self.query_module = nn.ModuleList(query_module)
        self.key_module = nn.ModuleList(key_module)
        self.value_module = nn.ModuleList(value_module)

    def _outputRelativePositionalEmbeddingScalar(self, query, batch_size, num_patch, seq_length, hidden_size, num_heads):
        # 相対位置エンベディングを計算する関数

        # 入力のシーケンス長を取得
        seq_length = seq_length[0]

        # Embeddingモジュールを保持するためのリスト
        embed_Module = []

        # ヘッドごとの隠れ層サイズ
        head_hidden_size = int(hidden_size / num_heads)

        # 位置情報のIDを作成
        position_ids = torch.tensor(list(range(num_patch + seq_length)), dtype=torch.long).reshape(1, num_patch + seq_length).expand(batch_size, num_patch + seq_length)

        # 各ヘッドごとにEmbeddingモジュールを構築
        for id in range(num_heads):
            embed_Module.append(nn.Embedding(num_patch + seq_length, head_hidden_size))

        # モジュールリストとしてEmbeddingモジュールを保持
        self.embed_module = nn.ModuleList(embed_Module)

        # 各ヘッドごとに相対位置エンベディングを計算
        for id in range(num_heads):
            head_query = self.query_module[id](query)
            tmp_relative_position_embedding_scalar = (head_query@(self.embed_module[id](position_ids).transpose(1, 2)))\
                .reshape(1, batch_size, num_patch + seq_length, num_patch + seq_length)[:, :, :num_patch, :num_patch]
            
            # 列のパディングを追加
            col_pad = torch.zeros((1, batch_size, num_patch, seq_length), dtype=torch.float)
            tmp_relative_position_embedding_scalar = torch.cat([tmp_relative_position_embedding_scalar, col_pad], dim=3)
            
            # 行のパディングを追加
            row_pad = torch.zeros((1, batch_size, seq_length, num_patch + seq_length))
            tmp_relative_position_embedding_scalar = torch.cat([tmp_relative_position_embedding_scalar, row_pad], dim=2)
            
            # 初めて計算するヘッドの場合は、相対位置エンベディングをそのまま保持
            if id == 0:
                relative_position_embedding_scalar = tmp_relative_position_embedding_scalar
            else:
                # すでに計算済みのヘッドがある場合は、テンソルを連結
                relative_position_embedding_scalar = torch.cat([relative_position_embedding_scalar, tmp_relative_position_embedding_scalar], dim=0)

        return relative_position_embedding_scalar
    
    def _outputAttention(self, query, key, value, batch_size, num_patch, seq_length, hidden_size, num_heads, check_positional_embedding, check_mask):
        # Attentionスコアを計算する関数

        # マスクや位置エンベディングの有無によって、処理を分岐
        if check_positional_embedding:
            seq_length1 = seq_length2 = num_patch + self.seq_length[0]
        else:
            if check_mask:
                seq_length1 = seq_length2 = self.seq_length[1]
            else:
                seq_length1 = self.seq_length[1]
                seq_length2 = num_patch + self.seq_length[0]

        # ヘッドごとの隠れ層サイズ
        head_hidden_size = int(hidden_size / num_heads)

        # マスクマップを作成
        mask_map = torch.tensor(np.tril(np.ones((seq_length1, seq_length2))), dtype=torch.long)

        # 位置エンベディングが指定されている場合、相対位置エンベディングを計算
        if check_positional_embedding:
            relative_position_embedding_scalar = self._outputRelativePositionalEmbeddingScalar(query, batch_size, num_patch, seq_length, hidden_size, num_heads)
        else:
            relative_position_embedding_scalar = None

        # 各ヘッドごとにAttentionスコアを計算
        for id in range(num_heads):
            head_query = self.query_module[id](query)
            head_key = self.key_module[id](key)
            head_value = self.value_module[id](value)

            if check_positional_embedding:
                # 位置エンベディングが指定されている場合、Attentionスコアに相対位置エンベディングを加算
                tmp_head_attention = self.softmax(((head_query@head_key.transpose(1, 2)) / (head_hidden_size) + relative_position_embedding_scalar[id]))@head_value
            else:
                # 位置エンベディングが指定されていない場合
                if check_mask:
                    # マスクが指定されている場合、Attentionスコアにマスクを適用
                    tmp_head_attention = self.softmax((mask_map * (head_query@head_key.transpose(1, 2)) / (head_hidden_size)))@head_value
                else:
                    # マスクが指定されておらず、位置エンベディングもない場合、通常のAttentionスコア計算
                    tmp_head_attention = self.softmax((head_query@head_key.transpose(1, 2)) / (head_hidden_size))@head_value

            # はじめて計算するヘッドの場合は、Attentionスコアをそのまま保持
            if id == 0:
                head_attention = tmp_head_attention
            else:
                # すでに計算済みのヘッドがある場合は、テンソルを連結
                head_attention = torch.cat([head_attention, tmp_head_attention], dim=-1)

        # 出力のAttentionスコアを返す
        output_attention = head_attention

        return output_attention

    def forward(self, query, key, value):
        # フォワード関数

        # Attentionスコアを計算
        output_attention = self._outputAttention(query, key, value, self.batch_size, self.num_patch, self.seq_length, self.hidden_size, self.num_heads,
                                                  self.check_positional_embedding, self.check_mask)

        return output_attention

このクラスでは、少し面倒な処理を行っています。T5のコードをを流用しているため複雑ですが、EncoderのAttentionには相対位置埋め込みを行い、DecoderにはMask化、またはCrossAttentionを行えるような分岐処理を施しています(普通に考えてそれぞれのAttentionごとに分けた方が可読性は高いです…)。

AddNorm
import torch.nn as nn

class AddNorm(nn.Module):
    def __init__(self, batch_size, num_patch, seq_length, hidden_size, check_encoder):
        super().__init__()

        # AddNormモジュールを構築
        self._setupAddNormModule(batch_size, num_patch, seq_length, hidden_size, check_encoder)

    def _setupAddNormModule(self, batch_size, num_patch, seq_length, hidden_size, check_encoder):
        # エンコーダーの場合、シーケンス長はエンコーダーのシーケンス長とパッチ数の合計
        if check_encoder:
            seq_length = seq_length[0] + num_patch
        else:
            # デコーダーの場合、シーケンス長はデコーダーのシーケンス長
            seq_length = seq_length[1]

        # LayerNormモジュールを構築
        self.layer_norm = nn.LayerNorm((batch_size, seq_length, hidden_size))

    def forward(self, tokens, skipped_tokens):
        # 入力トークンにスキップしたトークンを加算
        tokens += skipped_tokens

        # LayerNormを適用して出力トークンを生成
        tokens = self.layer_norm(tokens)

        return tokens
FeedForward
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()

        # FeedForwardモジュールを構築
        self._setupFeedForwardModule(hidden_size, ffn_hidden_size)

    def _setupFeedForwardModule(self, hidden_size, ffn_hidden_size):
        # 1つ目の全結合層とReLU活性化関数
        dense1 = nn.Linear(hidden_size, ffn_hidden_size)
        relu1 = nn.ReLU()

        # 2つ目の全結合層とReLU活性化関数
        dense2 = nn.Linear(ffn_hidden_size, hidden_size)
        relu2 = nn.ReLU()

        # モジュールリストとして保持
        self.feed_foward_module = nn.ModuleList([dense1, relu1, dense2, relu2])

    def forward(self, tokens):
        # フォワード関数

        # モジュールリスト内の各モジュールを順に適用
        for module in self.feed_foward_module:
            tokens = module(tokens)

        # 出力トークンを返す
        return tokens

4. 実行結果

実行
from transformers import AutoTokenizer
import numpy as np

# BERTトークナイザーの読み込み
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") # T5Tokenizerでは文頭トークン<s>が表示されなかったため、bertで代用

# 入力となる英語のトークン化
enocoder_input_tokens = ["Translate English to German : Two brown and white dogs", 
                         "Translate English to German : The man plaing soccoer",
                         "Translate English to German : The birds is flying"]
encoder_input_tokenize = tokenizer(enocoder_input_tokens, padding=True, return_tensors="pt", return_length=True)
encoder_input_ids = encoder_input_tokenize["input_ids"][:, 1:]  # 先頭のトークン<s>を除外
encoder_attention_mask = encoder_input_tokenize["attention_mask"][:, 1:]
encoder_attention_mask = torch.where(encoder_input_ids == 102, 0, encoder_attention_mask)  # 文末トークンを除外
encoder_input_ids = torch.where(encoder_input_ids == 102, 0, encoder_input_ids)  # 文末トークンを除外

# 出力となるドイツ語のトークン化
decoder_input_tokens = ["Zwei braune und weiße Hunde",
                       "Der Mann spielt Fußball",
                       "Die Vögel fliegen"]
decoder_input_tokenize = tokenizer(decoder_input_tokens, padding=True, return_tensors="pt", return_length=True)
decoder_input_ids = decoder_input_tokenize["input_ids"]
decoder_attention_mask = decoder_input_tokenize["attention_mask"]

encoder_max_length = encoder_input_tokenize["length"][0].item() - 1
decoder_max_length = decoder_input_tokenize["length"][0].item()

# 画像の生成と前処理
images = torch.randn(3, 3, 256, 256, dtype=torch.float)
patch_size = 16
num_patch = int((images.shape[2] / patch_size) ** 2) * images.shape[1]
num_embedding = torch.max(torch.concat([encoder_input_ids, decoder_input_ids], dim=-1)) + 1

# SimVLMモデルの初期化と実行
kwargs = {
    "batch_size": 3,
    "num_patch": num_patch,
    "num_embedding": num_embedding,
    "seq_length": (encoder_max_length, decoder_max_length),
    "hidden_size": 512,
    "num_layer": 12,
    "num_heads": 8,
    "ffn_hidden_size": 3072
}
simvlm = SimVLM(**kwargs)

# モデルの出力の形状を表示
print(encoder_max_length, decoder_max_length)
print(simvlm(images, encoder_input_ids, decoder_input_ids).shape)

論文ではEncoderへの入力では特殊トークンを付与すると記載されていなかったため、今回は除外しています。またDecoderには分類タスク等で必要なので付与しています。他設定はT5における基本的設定と同様となっています。patch_sizeは自由に決めてもらって構いません。

結果
14 12
torch.Size([3, 12, 512])

無事、正しくバッチサイズ × Decoderのトークン数 × 次元数 で出力されていますね。

まとめ

ここまで読んでくださり、ありがとうございます。

今回の記事では、SimVLMのスクラッチ開発をしてみました。T5のコードを大部分を流用できたので実装は比較的簡単でしたが、EncoderにおけるAttentionのヘッドごとの画像トークン部分にのみ相対位置埋め込みを適用するという処理に時間がかかりました。

しかし、やはり自分の手で実装することで処理の流れをより深いレベルで理解できますね。次は音声認識と自然言語処理に関する論文の実装を考えています(音声認識への理解はCTC程度で止まっていますが…)。

前回: 【論文読み】SimVLMの論文を要約してみた

参考文献

SimVLM: Simple Visual Language Model Pretraining with Weak Supervision

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0