2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

MobileVLM V2の学習を日本語データで試してみる

Posted at

はじめに

LLMをデコーダとして使用した、Vision-Language Model(以下VLM)がMiniGPT-4LLaVAをきっかけに様々なモデルが発表されています。

また、最近ではより小さなLLMを使用した、TinyGPT-VMobileVLM等も発表されています。

色々な手法が提案されている中、2024年2月上旬に発表されたMobileVLM V2がProjectorに工夫を入れることで推論速度や学習効率を改善していて個人的に面白いなと感じたのでモデルのアーキテクチャについて紹介していきます。

紹介するだけでは面白くないので、本記事では日本語による学習を試してみて簡単にLLaVA-1.5のアーキテクチャとの性能比較を行います。

MobileVLM V2とは?

MobileVLM V2とは前身であるMobileVLMから学習データ、学習戦略、モデルのアーキテクチャを改良することで、小さなパラメータのモデル(1.7Bや3B)であるにも関わらず7Bほどのパラメータ数を持つモデル(LLaVA1.5 7B)と同等の性能を達成したという手法です。

モデルのアーキテクチャ

モデルのアーキテクチャはVision EncoderやLLMの部分は、LLMに独自学習を行ったMobileLLaMAを使用していること以外はLLaVAと違いはありません。

異なるのはProjectorの部分です。

image.png

MobileVLM V2: Faster and Stronger Baseline for Vision Language Model, Chu, X. et al. (2024)

MobileVLM V2のProjectorにはLightweight Downsample Projector v2(LDPv2)という名前がつけられています。

名前の通りAvgPoolを使用することで24x24から12x12へダウンサンプルを行っています。

また通常のConvではなくPoint-wise Conv → Depth-wise Convを使用すること(実装としては線形層→AvgPool→Depth-wise Conv)で軽量な実装となっています。
(Depth-wise ConvはDepthwise Separable Convolutionについて分かりやすく解説!に分かりやすい解説があります)

工夫ってそれだけと思うかもしれませんが、Projectorに線形層のみを使用した場合と比べて、性能を落とさずにデコーダへの入力トークン数を減らすことができるため、学習と推論が効率化できているのがすごいなと思いました。

また、シンプルで実装しやすいというのも開発者側からしたらありがたいです。

日本語で学習させてみる

実装はLLaVA-JPのコードをベースに行います(自分のリポジトリの宣伝…)。

上記で紹介したMobileVLM V2も実装が公開されており、Projector部分はMobileVLM/mobilevlm/model/vision_projector.pyのLDPNetV2Projectorクラスに実装されているため、関連している部分をLLaVA-JPのLLaVA-JP/llava/model/vision_projector.pyにコピペして、LDPNetV2Projectorクラスを呼び出せるようにするだけで実装終わりです。

vision_projector.py

~~~~~~~~~~~~~~~~~~前略~~~~~~~~~~~~~~~~~~

class FeatureIRLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class TokenDownLayer(nn.Module):
    def __init__(self, shape) -> None:
        super().__init__()
        self.dwn = nn.Sequential(
            nn.AdaptiveAvgPool2d(shape)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        assert h * h == num_tokens
        x = x.permute(0, 2, 1).reshape(b, -1, h, h)
        x = self.dwn(x)
        x = x.flatten(2).transpose(1, 2)
        return x
    

class PosInjectLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, stride: int = 1) -> None:
        super().__init__()
        self.peg = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 3, stride, 1, bias=True, groups=out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        assert h * h == num_tokens
        cnn_feat = x.transpose(1, 2).view(b, c, h, h)
        x = self.peg(cnn_feat) + cnn_feat
        x = x.flatten(2).transpose(1, 2)
        return x


class LDPNetV2Projector(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.mlp = FeatureIRLayer(inc, ouc)
        self.dwn = TokenDownLayer((12, 12))
        self.peg = PosInjectLayer(ouc, ouc, stride=1)

    def forward(self, x):
        x = self.mlp(x)
        x = self.dwn(x)
        x = self.peg(x)
        return x
    

def get_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    elif projector_type == 'identity':
        return IdentityMap()
    elif projector_type == 'ldpnetv2':
        return LDPNetV2Projector(config)

    ~~~~~~~~~~~~~~~~~~後略~~~~~~~~~~~~~~~~~~

評価

LLaVA-1.5とMobileVLM V2のアーキテクチャを使用し、日本語LLMで学習させたモデルを用いて簡易的な評価を行います。

以下の3点を評価対象とします。

  • 学習速度
  • 推論速度
  • 性能

評価用の環境は以下のとおりです。

項目 内容
OS Ubuntu22.04
CPU Ryzen 9 7900X
GPU RTX4090 24GB

モデルの構造はProjector部をLLaVAとMobileVLM V2の実装に合わせて変更し、その他のVision EncoderとLLMは以下を使用しています。

項目 モデル
Vision Encoder openai/clip-vit-large-patch14-336
LLM llm-jp/llm-jp-1.3b-v1.0

学習や評価用のデータセットとしては以下のものを使用しました。

項目 学習フェーズ データセット
学習 Pretarin LLaVA-CC3M-Pretrain-595K-JA
学習 Fine-Tuning STAIR Captionsのtrainデータ
評価 - STAIR Captionsのvalidationデータ

また、LLaVA-1.5に習ってPretrain時はProjectorだけを学習。Fine-Tuning時はProjectorとLLMを学習させています。
(MobileVLM V2の論文はPretrainでもLLMを学習させると性能が上がると言われていますが比較のために上記のようにしています)

学習速度

学習はPretrainとFine-Tuningともに1epochずつ行っており、DeepSpeedのような学習を高速化させるライブラリは使用していません。

Figure_1.png

Figure_2.png

PretrainではMobileVLM V2のアーキテクチャはLLaVA-1.5の約0.33倍の速さで学習が完了しています。

また、Fine-TuningでもPretrainのときほどの差はありませんが0.64倍の速さで学習が終わりました。

推論速度

GPUでは1000回、CPUでは100回キャプションを生成を行い推論速度を計算しています。推論速度を向上させるようなライブラリや設定は利用せず、重みはBF16でロードしています。

Figure_3.png

Figure_4.png

推論速度もMobile VLM V2のアーキテクチャの方がCPUとGPUともに少し早いという結果になりました。

性能

性能はStair Captionsで微調整しているためBLEU、ROUGE、METEOR、CIDArというキャプション生成でよく用いられる評価指標を使用して評価します。

アーキテクチャ BLEU1 BLEU2 BLEU3 BLEU4 ROUGE METEOR CIDAR
LLaVA-1.5 0.806 0.660 0.529 0.425 0.579 0.338 1.120
MobileVLM V2 0.800 0.652 0.519 0.415 0.574 0.335 1.093

キャプション生成による評価では全ての評価指標でLLaVA-1.5のアーキテクチャを用いたモデルの方が性能が高かったです。

ただ、そこまで大きな差があるわけではなく個人的には許容範囲かなと思いました。

また速度が少し落ちてもいいと思うならAveragePoolを12x12ではなく、16x16のようにサイズを少し大きくすることで性能差は減らせる可能性がありそうだと感じました。

まとめ

MobileVLM V2で提案された、LDPv2を使用することで以下のようになることがわかりました。

  • 学習速度は大幅に向上
  • 推論速度はわずかに向上
  • キャプション生成の性能はわずかに低下

個人で学習させている身としては大量のGPUを使用できないため、学習速度が大幅に向上するというのはかなり嬉しいなと感じました。

また、性能はキャプション生成のタスクで下回ったのを確認しただけのため、他のタスクも込みで学習させると結果は変わるかもしれません。

VLM関係は最近かなり盛り上がっているので、今後さらにいいアーキテクチャのモデルが提案されることに期待です。

2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?