こちらの記事のノートブックをDatabricksでウォークスルーします。
ノートブックはこちらで公開されています。
LLaVAによるビジョンチャットアシスタントの作成
依存関係のインストールとインポート
!pip install git+https://github.com/haotian-liu/LLaVA.git@786aa6a19ea10edc6f574ad2e16276974e9aaa3a
dbutils.library.restartPython()
from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaLlamaForCausalLM
from llava.utils import disable_torch_init
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
from llava.conversation import conv_templates, SeparatorStyle
import torch
from PIL import Image
import requests
from io import BytesIO
チャットbotクラスの定義
class LLaVAChatBot:
def __init__(self,
model_path: str = 'liuhaotian/llava-v1.5-7b',
device_map: str = 'auto',
load_in_8_bit: bool = True,
**quant_kwargs) -> None:
self.model = None
self.tokenizer = None
self.image_processor = None
self.conv = None
self.conv_img = None
self.img_tensor = None
self.roles = None
self.stop_key = None
self.load_models(model_path,
device_map=device_map,
load_in_8_bit=load_in_8_bit,
**quant_kwargs)
def load_models(self, model_path: str,
device_map: str,
load_in_8_bit: bool,
**quant_kwargs) -> None:
"""Load the model, processor and tokenizer."""
quant_cfg = BitsAndBytesConfig(**quant_kwargs)
self.model = LlavaLlamaForCausalLM.from_pretrained(model_path,
low_cpu_mem_usage=True,
device_map=device_map,
load_in_8bit=load_in_8_bit,
quantization_config=quant_cfg)
self.tokenizer = AutoTokenizer.from_pretrained(model_path,
use_fast=False)
vision_tower = self.model.get_vision_tower()
vision_tower.load_model()
vision_tower.to(device='cuda')
self.image_processor = vision_tower.image_processor
disable_torch_init()
def setup_image(self, img_path: str) -> None:
"""画像のロードと処理"""
if img_path.startswith('http') or img_path.startswith('https'):
response = requests.get(img_path)
self.conv_img = Image.open(BytesIO(response.content)).convert('RGB')
else:
self.conv_img = Image.open(img_path).convert('RGB')
self.img_tensor = self.image_processor.preprocess(self.conv_img,
return_tensors='pt'
)['pixel_values'].half().cuda()
def generate_answer(self, **kwargs) -> str:
"""現在の会話から回答を生成"""
raw_prompt = self.conv.get_prompt()
input_ids = tokenizer_image_token(raw_prompt,
self.tokenizer,
IMAGE_TOKEN_INDEX,
return_tensors='pt').unsqueeze(0).cuda()
stopping = KeywordsStoppingCriteria([self.stop_key],
self.tokenizer,
input_ids)
with torch.inference_mode():
output_ids = self.model.generate(input_ids,
images=self.img_tensor,
stopping_criteria=[stopping],
**kwargs)
outputs = self.tokenizer.decode(
output_ids[0, input_ids.shape[1]:]
).strip()
self.conv.messages[-1][-1] = outputs
return outputs.rsplit('</s>', 1)[0]
def get_conv_text(self) -> str:
"""完全な会話のテキストを返却"""
return self.conv.get_prompt()
def start_new_chat(self,
img_path: str,
prompt: str,
do_sample=True,
temperature=0.2,
max_new_tokens=1024,
use_cache=True,
**kwargs) -> str:
"""新たな画像で新規チャットを開始"""
conv_mode = "v1"
self.setup_image(img_path)
self.conv = conv_templates[conv_mode].copy()
self.roles = self.conv.roles
first_input = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
DEFAULT_IM_END_TOKEN + '\n' + prompt)
self.conv.append_message(self.roles[0], first_input)
self.conv.append_message(self.roles[1], None)
if self.conv.sep_style == SeparatorStyle.TWO:
self.stop_key = self.conv.sep2
else:
self.stop_key = self.conv.sep
answer = self.generate_answer(do_sample=do_sample,
temperature=temperature,
max_new_tokens=max_new_tokens,
use_cache=use_cache,
**kwargs)
return answer
def continue_chat(self,
prompt: str,
do_sample=True,
temperature=0.2,
max_new_tokens=1024,
use_cache=True,
**kwargs) -> str:
"""既存のチャットの継続"""
if self.conv is None:
raise RuntimeError("No existing conversation found. Start a new"
"conversation using the `start_new_chat` method.")
self.conv.append_message(self.roles[0], prompt)
self.conv.append_message(self.roles[1], None)
answer = self.generate_answer(do_sample=do_sample,
temperature=temperature,
max_new_tokens=max_new_tokens,
use_cache=use_cache,
**kwargs)
return answer
モデルのセットアップ
モデルをダウンロードし、チャットボットをセットアップします。これには数分要します。
chatbot = LLaVAChatBot(load_in_8bit=True,
bnb_8bit_compute_dtype=torch.float16,
bnb_8bit_use_double_quant=True,
bnb_8bit_quant_type='nf8')
モデルとのチャット
LLaVAは日本語でも結構動作します。
ans = chatbot.start_new_chat(img_path="https://images.unsplash.com/photo-1686577353812-6cbc7fce384b?q=80&w=1887&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
prompt="画像にある魚が水に投げ込まれたら浮きますか?ステップバイステップで考えてください。")
display(chatbot.conv_img)
print(ans)
確かに、魚が水に投げ込まれると、その体重によって水面に浮き上がります。しかし、この図像では、魚が床に倒れているため、水に浮き上がることはありません。魚が水に投げ込まれると、その体重によって水面に浮き上がりますが、この図像ではそのような状況が起こらないため、答えは「いいえ」です。
若干?ですが、回答が得られています。
ans = chatbot.continue_chat("水に浮く魚の彫刻を作るにはどのような素材を使ったらいいですか?")
水に浮く魚の彫刻を作るには、軽量で水に抗する性質がある素材が必要です。一般的に、軽量で水に抗する素材として、プラスチックやポリエステル製の素材が使われます。これらの素材は、軽量でありながらも強度があり、水に浮くことができます。また、これらの素材は易丈夫で製作が容易であるため、製作コストも低く抑えられます。
こちらは正しいですね。
すべての会話の取得
print(chatbot.get_conv_text())
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <im_start><image><im_end>
画像にある魚が水に投げ込まれたら浮きますか?ステップバイステップで考えてください。 ASSISTANT: 確かに、魚が水に投げ込まれると、その体重によって水面に浮き上がります。しかし、この図像では、魚が床に倒れているため、水に浮き上がることはありません。魚が水に投げ込まれると、その体重によって水面に浮き上がりますが、この図像ではそのような状況が起こらないため、答えは「いいえ」です。</s></s>USER: 水に浮く魚の彫刻を作るにはどのような素材を使ったらいいですか? ASSISTANT: 水に浮く魚の彫刻を作るには、軽量で水に抗する性質がある素材が必要です。一般的に、軽量で水に抗する素材として、プラスチックやポリエステル製の素材が使われます。これらの素材は、軽量でありながらも強度があり、水に浮くことができます。また、これらの素材は易丈夫で製作が容易であるため、製作コストも低く抑えられます。</s></s>