67
50

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LLMAdvent Calendar 2023

Day 2

日本語LLMでLLaVAの学習を行ってみた

Last updated at Posted at 2023-12-01

はじめに

本記事はLLM Advent Calendar 2023 2日目の記事になります。

最近、様々なLLMが発表されたことによりローカルLLM界隈では自作データセットを作成して自分好みのLLMを作成するなど日本語LLM界隈は盛り上がりを見せています。

一方、マルチモーダルなLLMとして画像を組み合わせたものに関してはTuring、Stability AI、Rinnaなどの企業ではモデルを公開していますが、個人で行われている方は少ないという印象があります。

そこで今回はLLaVAと同じ方法で日本語LLMを学習させて、個人でマルチモーダルなLLMの学習を行ってみました。個人で学習できる範疇ということで学習はRTX4090 1台で行っています。

学習に使用したコードは以下で公開しています。

モデルは以下で公開しています。

事前学習モデル:

ファインチューニングモデル:

事前学習に使用した日本語翻訳データセットは以下で公開しています。

LLaVAとは?

Large Language and Vision Assistant、LLaVAは2023年4月に発表された、LLMを使用したVision and languageモデルです。

その後10月に発表されたLLaVA-1.5では11個のベンチマークでSoTAを達成したと報告されています。

LLaVAのモデル構造

LLaVAのモデル構造は以下の図の通りシンプルでVision Encoder + Vision Projector + LLMから構成されています。

llava_arch.png

Visual Instruction Tuning, Liu, H. et al. (2023)

Vision EncoderにはOpenAIのCLIP VIT-L-patch14-336が使用されています。

Vision ProjectorはVision Encoderの出力をLLMの入力に変換する役割を持っており線形層が使用されています。LLaVA1.5では一層の線形層より性能が向上するということで二層の線形層が使用されています。また、活性化関数にはGELUが使用されています。

LLMにはVicuba-13b-1.5が使用されおり、本記事ではこの部分を日本語データで事前学習されたLLMに置き換えます。

学習方法

学習は二段階のステップで行われています。

一段階目の事前学習ではVision EncoderとLLMを凍結し、Vision Projectorのみを学習させています。

二段階目のファインチューニングではVision Encoderのみを凍結し、Vision ProjectorとLLMを学習させています。

事前学習データ

LLaVAの大きな特徴は他のモデルと比べて事前学習データの数が少ないところです。

Instruct BLIPでは事前学習には約129Mのデータが使用されていますがLLaVAでは約600Kのデータしか使用していません(発表当初の論文ではCC3Mから抽出した595Kのデータセット、LLaVA-1.5ではLAION/CC/SBUから抽出した558Kのデータセット)。

LLaVA_traindata.png

Improved Baselines with Visual Instruction Tuning, Liu, H. et al. (2023)

ファインチューニングデータ

ファインチューニングに使用するデータとしてGPT4を使用してマルチターンの会話を行う58Kのデータと画像について詳細な説明を行う23Kのデータと複雑な推論を行う77Kの計158Kのデータセットを作成しています。

こちらを日本語に翻訳したデータセットはTuringによって公開されています。

また、LLaVA1.5では従来の学習データのフォーマットでは単語一つで回答してほしい場合に制御ができなかったため"Answer the question using a single word or phrase."のようなプロンプトを使用して回答を制御するような工夫がなされています。

LLaVA_Instruct.png

Improved Baselines with Visual Instruction Tuning, Liu, H. et al. (2023)

学習データの入力方法

第一段階ではLLMへの入力として画像特徴量をVision Projectorで投影させたもののみをLLMに入力しています。

第二段階では発表当初の論文ではImageトークンをインストラクションの前後につけていたみたいですが、実装を見ていると必ず先頭につけるように変更されていました。

同じ疑問を持った方がいたみたいで以下のISSUEで学習にあまり意味がないから変更したと著者の方から回答がありました。

LLaVAのコードを日本語モデル向けに改修する

今回はLLaVA-1.5のコードをベースに日本語向けに改造していきます。改修後のコードは公開していますので、主要な部分のみ説明していきます。

LLaVA-1.5ではLLMがVicuna-13b-v1.5に対応しているためyouri-7b等のLlama2ベースのLLMに対してはそのまま学習を行うことも可能です。

ただLlama2ベースのモデルは7B以上のサイズのものばかりであるため個人が保有するGPUで学習するのは困難です。そこで、今回はパラメータ数が1bほどのモデルが多いGPT2ベースのモデルで学習できるようにコードを改修します。

llava/model/language_model/llava_llama.pyのLLavaConfig、LlavaLlamaModel、LlavaLlamaForCausalLMクラスのLlama系のクラスを継承している部分をGPT2系のクラスを継承するように変更します。

from transformers import GPT2LMHeadModel, GPT2Config

class LlavaConfig(GPT2Config):
    model_type = "llava"


class LlavaGpt2Model(LlavaMetaModel, GPT2LMHeadModel):
    config_class = LlavaConfig

    def __init__(self, config: GPT2Config):
        super(LlavaGpt2Model, self).__init__(config)

class LlavaGpt2ForCausalLM(GPT2LMHeadModel, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):
        super(LlavaGpt2ForCausalLM, self).__init__(config)
        self.model = LlavaGpt2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
後略
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

次にllava/model/llava_arch.pyのprepare_inputs_labels_for_multimodalメソッドを変更します。変更点はEmbedding処理の呼び出し名が、LlavaLlamaModelを継承している場合とGPT2LMHeadModelを継承している場合では異なるためその部分を修正していきます。

具体的には

self.get_model().embed_tokens(

の部分を

self.get_model().transformer.wte(

に置き換えるだけです。

class LlavaMetaForCausalLM(ABC):

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
前略
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images
    ):

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
中略
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().transformer.wte(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().transformer.wte(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
後略
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

モデルの改修は以上となります。

GPTNeoXベースのモデルで学習させたい場合もGPT2ベース用に回収した部分をGPTNeoX系のモデルに合わせるだけで改修可能です。

LLaVA-JPの学習について

開発環境

今回学習に使用した環境は以下のとおりです。

項目 内容
OS Ubuntu22.04
CPU Ryzen 9 7900X
GPU RTX4090 24GB

使用する言語モデル

使用する言語モデルはRTX4090(24GB)に乗せることができるモデルということでパラメータ数が1.3Bのモデルからllm-jp/llm-jp-1.3b-v1.0を選択しました。

当初は様々なパラメータのモデルが提供されているということでrinna/japanese-gpt-1bを使用していたのですが、Lossの値が下がらず学習後の出力結果もあまり良くなかったため途中から使用するモデルをllm-jp-1.3b-v1.0に変更しました。

また、以下のISSUEにあるようにFP16で学習させると勾配がオーバーフローしてLossがnanになるという問題が発生したためBF16で学習を行っています。

データセット

事前学習データ

事前学習データには以下の2つを使用しています。

LLaVA-CC3M-Pretrain-595K-JAはLLaVA-CC3M-Pretrain-595Kを日本語訳したデータになります。翻訳にはcyberagent/calm2-7b-chatを使用しています。

翻訳にcalm2-7b-chatを採用した理由はyouriなどと違ってLlamaを継続学習しているわけではないため、出力データを学習に使用できるという理由で選択しました。翻訳時のプロンプトには以下のものを使用しています。

# inputが英語の入力
f"""USER: 下記の英語を日本語に翻訳してください。
{input}
ASSISTANT: """

また、LLaVA-CC3M-Pretrain-595Kにはhumanの入力として"Create a compact narrative representing the image presented."のような指示+前後にimage tokenがつくというものが20パターンありますが、以下の表のとおり変換しています(事前学習にはこの指示を使っていないので変換する必要はありませんが…)。

Source 日本語訳
Create a compact narrative representing the image presented. 与えられた画像を表す簡潔な文を作成してください。
Describe the image concisely. 画像について簡単に説明してください。
Provide a brief description of the given image. 与えられた画像について簡単に説明してください。
Offer a succinct explanation of the picture presented. 入力された写真について簡単に説明してください。
Summarize the visual content of the image. 画像の内容を教えてください。
Give a short and clear explanation of the subsequent image. 次の画像を短く分かりやすく説明してください。
Share a concise interpretation of the image provided. 与えられた画像について教えてください。
Give a short and clear explanation of the subsequent image. 写真の特徴を手短に教えてください。
Relay a brief, clear account of the picture shown. この画像について短い言葉で説明してください。
Render a clear and concise summary of the photo. 写真の概要を簡潔かつ分かりやすく伝えてください。
Write a terse but informative summary of the picture. 写真について簡単に説明してください。

ファインチューニングデータ

ファインチューニングデータには以下の2つを使用しています。

言語モデルと学習データの組み合わせ

上記の言語モデルと学習データを組み合わせて以下の2パターンを試しています。モデル名は後の比較で分かりやすいように仮でつけています。

モデル名 言語モデル STAIR Captions LLaVA-CC3M-Pretrain-595K-JA Japanese Visual Genome VQA dataset LLaVA-Instruct-150K-JA
llava-jp-only-stair llm-jp-1.3b-v1.0    
llava-jp-full llm-jp-1.3b-v1.0

他のパターンも試したかったのですが時間が足りず試すことができませんでした。

また、第二段階のファインチューニングをLoRAで学習させるコードも元の実装にありGPT2向けに改修はしたのですが、学習まで行うことができていないという状態です。

学習結果

事前学習

pretrain_loss1.png

Lossの値はStair Captionだけを事前学習データとして使用したほうが少し下がるという結果でした。これはLLaVA-CC3M-Pretrain-595K-JAのCaptionに翻訳ミスがあるのが原因だと思います。

ファインチューニング

finetuning_loss1.png

ファインチューニング後は両者ともLossが1.0前後まで下がりました。

LLaVAの結果と比べると少し高いですが、これはモデルのパラメータ数や言語モデル自身の事前学習データの差なのかなと感じています。

ちなみにjapanese-gpt-1bで試したときはLossが2.7ほどまでしか下がりませんでした。

出力確認

最後に学習したモデルの出力を確認していきます。今回自身で学習した2つモデルの他にturing-motors/heron-chat-blip-ja-stablelm-base-7b-v0stabilityai/japanese-stable-vlmと比較していきたいと思います。

指示は同じにしますが、プロンプトはモデルに合わせて変更しています。

またパラメータに関してはtemperature=0.01、top_p=1.0で合わせます。

サンプル1

最初に以下の画像で試してみます。
sample2.jpg

入力:猫は何をしていますか?

モデル名 出力
llava-jp-only-stair
llava-jp-full 寝ている
heron-chat-blip ノートパソコンの上で寝ている
stable-vlm 猫は何をしていますか?

入力:猫は何色ですか?

モデル名 出力
llava-jp-only-stair 黒茶色色
llava-jp-full 茶色
heron-chat-blip 茶色
stable-vlm 猫は何色ですか?

入力:猫の隣には何がありますか?

モデル名 出力
llava-jp-only-stair ノート
llava-jp-full ノートパソコン
heron-chat-blip 猫はノートパソコンの上に座っている。
stable-vlm ノートパソコン

Visual Genome VQAに含まれているような単純な質問に関してはllava-jp-only-stairもllava-jp-fullもそこそこいい感じに答えてくれています。

llava-jp-fullに関してはheron-chat-blipやstable-vlmと比べてもパラメータ数が少ないにも関わらずいい回答をしてくれている気がします。

サンプル2

次に定番の以下の画像で試してみます。
sample1.jpg

入力:この画像の男性は何をしていますか?

モデル名 出力
llava-jp-only-stair サーフィン
llava-jp-full サーフィン
heron-chat-blip 画像の男性は、黄色いトラックの荷台で洗濯物を干している。
stable-vlm 車の上で洗濯をしている男性

この入力ではllava-jp-only-stairとllava-jp-fullはアイロン台をサーフボードと勘違いしたのかサーフィンと答えています。

一方、heron-chat-blipとstable-vlmではアイロンがけではなく洗濯と答えていますが、学習させたものより近い答えを出している気がします。

入力:この画像の面白い点を教えてください?

モデル名 出力
llava-jp-only-stair
llava-jp-full 画像の中で、黄色いシャツを着た男性が、車の荷台に座って洗濯機を使っている。このシーンは、男性が日常生活の中で洗濯機を使っていることを示唆している。この男性は、おそらくは都市部で、おそらくは公共交通機関を利用して、洗濯機を使って服を洗濯しているのだろう。このシーンは、日常生活の中で洗濯機を使うことの重要性を強調している。
heron-chat-blip 画像では、黄色いトラックの荷台で洗濯物を干している男性が描かれている。彼はトラックに寄りかかり、洗濯物を取り込んでいる。このシーンは、男性がトラックで仕事をしているか、トラックを運転していることを示唆している。トラックは、このシーンの中心的な焦点であり、男性の作業スペースを提供しているように見える。背景には、道路を走る複数の車が見え、おそらくこの地域で交通渋滞が発生している可能性がある。
stable-vlm 男は車の上で洗濯をしている

llava-jp-only-stairに関してはSTAIR Captionsには存在しないような質問ということでうまく答えられていません。

llava-jp-fullとheron-chat-blipに関しては似たような答えを出力しているような気がします。先ほどはllava-jp-fullはサーフィンと答えましたが今度は洗濯というふうに認識しているみたいです。

stable-vlmは学習データにLLaVA-Instruct-150K-JAがないためか簡潔な回答になっています。

どのモデルも洗濯だと認識しているのが面白いですね。

まとめ

本記事ではLLM-jp-1.3b-v1.0をLLaVAの手法で学習させてマルチモーダルなLLMを作成しました。

きちんとしたベンチマークで比較したわけではなく、簡単な比較しかできていませんが正直想像以上の性能でした。
(記事を書きながら比較していたため、本当はパラメータ数足らないから学習失敗したみたいな結論にする予定が急遽変更になりました…)

学習させてみてLLaVA-Instruct-150K-JAのようなデータって大切だなと感じました。このデータはGPT4で作成されたということもあり商用利用ができないため、マルチモーダルな日本語LLMを盛り上げるなら商用可能な似たようなデータの作成が必須だなと感じました。

また、今回LLaVAの学習手法が良かったというのもありますが、それと同じぐらいLLM-jp-1.3b-v1.0の性能が良かったというのも学習がうまく行ったポイントだと思います。

最近、パラメータ数が7B以上のモデルは色々と発表されていますが、パラメータ数の小さなモデルの発表数は少なくなってきているので個人的にはこういった小さなモデルが増えたらいいなと思います。

最初にも述べたとおり学習コードは以下に公開しています。

オリジナルのLLaVAの実装と比べるとDeepSpeedを使用しないようにしていたり、不要なコードは削除するなど少し簡易化しています。

RTX4090よりいいGPUを持っている方はよりパラメータ数の大きいモデルでの学習などを試してみてください。

最後まで読んでいただきありがとうございました。記事に誤り等ございましたら指摘していただければ幸いです。

TODO

  • LLaVA-CC3M-Pretrain-595K-JAの公開
  • llava-jp-fullの公開(モデル名はllava-jp-1.3b-v1.0にしました)
  • LoRAを使ったファインチューニングの実施
67
50
5

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
67
50

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?