LoginSignup
2
2

N番煎じでLLaVA v1.5をDatabricksで試す

Posted at

風邪で寝込んだりしていて出遅れてしまいましたが、試してみます。

導入

画像と言語のマルチモーダルモデルであるLLaVAのバージョン1.5が出ました。(ベースモデルがVicuna v1.5になったのかな?)
以下のように、既にいろんな方が試されていますが、N番煎じでやってみます。

確認環境はいつものようにDatabricksを利用します。

Step1. モデルダウンロード

Huggingfaceに公開されていますで、こちらからモデルをダウンロード。
今回は7Bのモデルを利用します。

import os
from huggingface_hub import snapshot_download

UC_VOLUME = "/Volumes/モデル保存先のボリューム"

model = "liuhaotian/llava-v1.5-7b"
local_dir = f"/tmp/{model}"
uc_dir = "/models--liuhaotian--llava-v1.5-7b"

snapshot_location = snapshot_download(
    repo_id=model,
    local_dir=local_dir,
    local_dir_use_symlinks=False,
)

dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}", recurse=True)

Step2. 環境セットアップ

githubから必要なコードセットをcloneして必要なモジュールをインストールします。

# 最新のtransformersにしておく
%pip install -U -qq transformers accelerate

# LLaVAのコードセットを取得&必要なモジュールのインストール
%sh mkdir /tmp/llava && cd /tmp/llava && git clone https://github.com/haotian-liu/LLaVA.git
%pip install -e /tmp/llava/LLaVA

dbutils.library.restartPython()

clone先にパスを通して準備完了。

import sys

sys.path.append("/tmp/llava/LLaVA")

model_path = "/Volumes/モデル保存先のボリューム/models--liuhaotian--llava-v1.5-7b"

Step3. 推論準備

CLI上で推論を行ってもいいのですが、せっかくなのでpython上で推論できるように準備します。
コードはこちらを参考にしました。

まずは、必要なモデルのロード。

from llava.utils import disable_torch_init
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

# Model
disable_torch_init()

model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, model_name, load_8bit=True
)

画像の読み込みや推論処理用の関数を定義。
上記で取得したmodel等を関数内で利用しています。

import argparse
import torch

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import (
    tokenizer_image_token,
    KeywordsStoppingCriteria,
)

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer


def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def generate_text(inp: str, image: Image = None):

    conv_mode = "llava_v1"

    conv = conv_templates[conv_mode].copy()
    roles = conv.roles

    image_tensor = None

    if image is not None:
        image_tensor = (
            image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
            .half()
            .cuda()
        )

        # first message
        if model.config.mm_use_im_start_end:
            inp = (
                DEFAULT_IM_START_TOKEN
                + DEFAULT_IMAGE_TOKEN
                + DEFAULT_IM_END_TOKEN
                + "\n"
                + inp
            )
        else:
            inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
        conv.append_message(conv.roles[0], inp)
    else:
        # later messages
        conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=256,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
    conv.messages[-1][-1] = outputs

Step4. 推論

ようやく準備が整ったので、推論を実行しましょう。

今回はいらすとやさんの画像を入力し、説明させてみます。
ちなみに選んだ画像は適当です。

まずは、こちらのプロポーズしようとしている人のイラスト

url = "https://2.bp.blogspot.com/-8og8ypK6lJg/XAY_Gj-3jYI/AAAAAAABQlE/08TO5edQ1lcEiepb9eZ9Z3hck0AWqMYLwCLcBGAs/s800/wedding_propose_surprise_suit.png"
image = load_image(url)

generate_text("何の画像ですか?", image)
出力
この画像は、2人の人物が描かれており、その中には手をつかんでいる人物がいます。その人物は、手のひらにコンタクトレスを持っていることが分かります。また、2人の人物は、手をつないでいるように描かれており、それぞれが異なる色の服を着ています。

男性の持ってるものが意味不明ですね。
もう少し突っ込んでみましょう。

generate_text("男性が手に持っているのはダイヤの指輪のようです。それを踏まえて、これはどういったシチュエーションの画像でしょうか?", image)
出力
この画像は、二人の人物が会話をするシチュエーションです。男性は手にダイヤの指輪を持っていますが、それは彼らの会話の中であまり重要な役割を果たしていません。彼らの会話は、主に表情や態度で伝えられているようです。このシチュエーションでは、ダイヤの指輪はあくまでも背景にあるものであり、会話の主要なテーマは二人の関係性やコミュニケーションにあると考えられます。

さすがにそこまで空気は読めないか・・・!

違う画像にいってみます。

こちらの優勝パレードのイラスト(スーツ)です。

url = "https://3.bp.blogspot.com/-pwAVfrr360g/W9Rcl7R42ZI/AAAAAAABPr4/AIjYWGuaaZ8QbrpTu3YFqYB14tTA2u40wCLcBGAs/s800/parade_bus_suit_man.png"
image = load_image(url)

generate_text("何の画像ですか?", image)
出力
この画像は、5人の男性が乗っているバスのイラストレーションです。彼らはすべて統一された服装で、細かな手柄を振り上げています。バスは大きな赤いバスで、5人の人物がそれに乗っています。

日本語が少し微妙ですが、かなり正しく認識されていると思います。

まとめ

LLMを動かすのと同程度には簡単に動かすことができました。
検証は十分ではないかもしれませんが、かなり正しく画像認識されるのではないかと思います。
(おそらく13Bモデルを使えば、より正確に)

設計書の画像からコードを作り出すような試みもあるようですし、GPT4-Vの登場もあって、この分野どんどん脚光が浴びていく気がします。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2