風邪で寝込んだりしていて出遅れてしまいましたが、試してみます。
導入
画像と言語のマルチモーダルモデルであるLLaVAのバージョン1.5が出ました。(ベースモデルがVicuna v1.5になったのかな?)
以下のように、既にいろんな方が試されていますが、N番煎じでやってみます。
確認環境はいつものようにDatabricksを利用します。
Step1. モデルダウンロード
Huggingfaceに公開されていますで、こちらからモデルをダウンロード。
今回は7Bのモデルを利用します。
import os
from huggingface_hub import snapshot_download
UC_VOLUME = "/Volumes/モデル保存先のボリューム"
model = "liuhaotian/llava-v1.5-7b"
local_dir = f"/tmp/{model}"
uc_dir = "/models--liuhaotian--llava-v1.5-7b"
snapshot_location = snapshot_download(
repo_id=model,
local_dir=local_dir,
local_dir_use_symlinks=False,
)
dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}", recurse=True)
Step2. 環境セットアップ
githubから必要なコードセットをcloneして必要なモジュールをインストールします。
# 最新のtransformersにしておく
%pip install -U -qq transformers accelerate
# LLaVAのコードセットを取得&必要なモジュールのインストール
%sh mkdir /tmp/llava && cd /tmp/llava && git clone https://github.com/haotian-liu/LLaVA.git
%pip install -e /tmp/llava/LLaVA
dbutils.library.restartPython()
clone先にパスを通して準備完了。
import sys
sys.path.append("/tmp/llava/LLaVA")
model_path = "/Volumes/モデル保存先のボリューム/models--liuhaotian--llava-v1.5-7b"
Step3. 推論準備
CLI上で推論を行ってもいいのですが、せっかくなのでpython上で推論できるように準備します。
コードはこちらを参考にしました。
まずは、必要なモデルのロード。
from llava.utils import disable_torch_init
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
# Model
disable_torch_init()
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path, None, model_name, load_8bit=True
)
画像の読み込みや推論処理用の関数を定義。
上記で取得したmodel等を関数内で利用しています。
import argparse
import torch
from llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import (
tokenizer_image_token,
KeywordsStoppingCriteria,
)
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
def load_image(image_file):
if image_file.startswith("http") or image_file.startswith("https"):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert("RGB")
else:
image = Image.open(image_file).convert("RGB")
return image
def generate_text(inp: str, image: Image = None):
conv_mode = "llava_v1"
conv = conv_templates[conv_mode].copy()
roles = conv.roles
image_tensor = None
if image is not None:
image_tensor = (
image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
.half()
.cuda()
)
# first message
if model.config.mm_use_im_start_end:
inp = (
DEFAULT_IM_START_TOKEN
+ DEFAULT_IMAGE_TOKEN
+ DEFAULT_IM_END_TOKEN
+ "\n"
+ inp
)
else:
inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
conv.append_message(conv.roles[0], inp)
else:
# later messages
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = (
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
.unsqueeze(0)
.cuda()
)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
do_sample=True,
temperature=0.2,
max_new_tokens=256,
streamer=streamer,
use_cache=True,
stopping_criteria=[stopping_criteria],
)
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
conv.messages[-1][-1] = outputs
Step4. 推論
ようやく準備が整ったので、推論を実行しましょう。
今回はいらすとやさんの画像を入力し、説明させてみます。
ちなみに選んだ画像は適当です。
まずは、こちらのプロポーズしようとしている人のイラスト。
url = "https://2.bp.blogspot.com/-8og8ypK6lJg/XAY_Gj-3jYI/AAAAAAABQlE/08TO5edQ1lcEiepb9eZ9Z3hck0AWqMYLwCLcBGAs/s800/wedding_propose_surprise_suit.png"
image = load_image(url)
generate_text("何の画像ですか?", image)
この画像は、2人の人物が描かれており、その中には手をつかんでいる人物がいます。その人物は、手のひらにコンタクトレスを持っていることが分かります。また、2人の人物は、手をつないでいるように描かれており、それぞれが異なる色の服を着ています。
男性の持ってるものが意味不明ですね。
もう少し突っ込んでみましょう。
generate_text("男性が手に持っているのはダイヤの指輪のようです。それを踏まえて、これはどういったシチュエーションの画像でしょうか?", image)
この画像は、二人の人物が会話をするシチュエーションです。男性は手にダイヤの指輪を持っていますが、それは彼らの会話の中であまり重要な役割を果たしていません。彼らの会話は、主に表情や態度で伝えられているようです。このシチュエーションでは、ダイヤの指輪はあくまでも背景にあるものであり、会話の主要なテーマは二人の関係性やコミュニケーションにあると考えられます。
さすがにそこまで空気は読めないか・・・!
違う画像にいってみます。
こちらの優勝パレードのイラスト(スーツ)です。
url = "https://3.bp.blogspot.com/-pwAVfrr360g/W9Rcl7R42ZI/AAAAAAABPr4/AIjYWGuaaZ8QbrpTu3YFqYB14tTA2u40wCLcBGAs/s800/parade_bus_suit_man.png"
image = load_image(url)
generate_text("何の画像ですか?", image)
この画像は、5人の男性が乗っているバスのイラストレーションです。彼らはすべて統一された服装で、細かな手柄を振り上げています。バスは大きな赤いバスで、5人の人物がそれに乗っています。
日本語が少し微妙ですが、かなり正しく認識されていると思います。
まとめ
LLMを動かすのと同程度には簡単に動かすことができました。
検証は十分ではないかもしれませんが、かなり正しく画像認識されるのではないかと思います。
(おそらく13Bモデルを使えば、より正確に)
設計書の画像からコードを作り出すような試みもあるようですし、GPT4-Vの登場もあって、この分野どんどん脚光が浴びていく気がします。