Multimodal models are impressive.
LLaVA: Large Language and Vision Assistant
Trying it out on Databricks
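I am standing on the shoulders of those who came before me.
I provisioned a fairly large GPU cluster. With a smaller one, I got an error because the model did not fit in GPU memory.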
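As a rough sanity check on cluster sizing, the memory needed just to hold the weights can be estimated from the parameter count. The sketch below is back-of-the-envelope arithmetic, not a measurement. The helper name is hypothetical, the 7 billion figure is read off the model name llava-v1.5-7b used later in this post, and 2 bytes per parameter (16-bit weights) is an assumption about how the weights are loaded. It ignores the vision encoder, activations, and the KV cache, so actual usage will be higher.

```python
def estimate_weight_memory_gib(num_params: float, bytes_per_param: int = 2) -> float:
    """Return the approximate memory (GiB) needed to hold the weights alone."""
    return num_params * bytes_per_param / 2**30

# Roughly 7 billion parameters (assumed from the "7b" in the model name).
fp16_gib = estimate_weight_memory_gib(7e9, bytes_per_param=2)
fp32_gib = estimate_weight_memory_gib(7e9, bytes_per_param=4)
print(f"16-bit weights: ~{fp16_gib:.1f} GiB, 32-bit weights: ~{fp32_gib:.1f} GiB")
```

Comparing these numbers with the memory of a single GPU on the cluster gives a first indication of whether the model has any chance of fitting.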
ERROR: Project xxxx has a 'pyproject.toml' and its build backend is missing the 'build_editable' hook. Since it does not have a 'setup.py' nor a 'setup.cfg', it cannot be installed in editable mode. Consider using a build backend that supports PEP 660.
To work around the error above, I run touch setup.cfg before installing. I used this as a reference.
%sh
git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA
touch setup.cfg
pip install -U -qq transformers accelerate
pip install -e .
Restart the Python kernel.
dbutils.library.restartPython()
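After the restart, it can be worth confirming that the freshly installed packages are visible to the new Python process before loading anything heavy. The following is a minimal sketch using only the standard library. The helper name is hypothetical, and the package names are the ones implied by the install commands and imports in this post.

```python
import importlib.util

def is_importable(module_name: str) -> bool:
    """Return True if Python can locate the top-level module without importing it."""
    return importlib.util.find_spec(module_name) is not None

# "llava" is the package name implied by the imports used in this post.
for name in ("llava", "transformers", "accelerate"):
    print(f"{name}: {'found' if is_importable(name) else 'NOT found'}")
```

If llava is reported as not found, the editable install most likely did not succeed, and the output of the install cell is the place to look.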
Download the model.
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model
model_path = "liuhaotian/llava-v1.5-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path)
)
Run inference.
model_path = "liuhaotian/llava-v1.5-7b"
prompt = "What are the things I should be cautious about when I visit here?"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
args = type('Args', (), {
    "model_path": model_path,
    "model_base": None,
    "model_name": get_model_name_from_path(model_path),
    "query": prompt,
    "conv_mode": None,
    "image_file": image_file,
    "sep": ",",
    "temperature": 0.2,
    "top_p": 0.8,
    "num_beams": 1,
    "min_new_tokens": None,
    "max_new_tokens": None
})()
displayHTML(f"<img src=\"{image_file}\">")
eval_model(args)
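The type('Args', (), {...})() expression above creates a throwaway class whose attributes are the dictionary entries and then instantiates it, so the settings can be read as args.query, args.image_file, and so on. types.SimpleNamespace from the standard library gives the same attribute-style access in a more readable form. The sketch below is one possible equivalent, not part of LLaVA: build_args and DEFAULT_ARGS are hypothetical names, and it assumes eval_model only reads plain attributes from the object it receives.

```python
from types import SimpleNamespace

# Defaults copied from the cell above; override per call as needed.
DEFAULT_ARGS = {
    "model_base": None,
    "conv_mode": None,
    "sep": ",",
    "temperature": 0.2,
    "top_p": 0.8,
    "num_beams": 1,
    "min_new_tokens": None,
    "max_new_tokens": None,
}

def build_args(model_path: str, model_name: str, query: str, image_file: str, **overrides):
    """Hypothetical helper: bundle the settings into an attribute-style object."""
    fields = {**DEFAULT_ARGS, **overrides}
    return SimpleNamespace(
        model_path=model_path,
        model_name=model_name,
        query=query,
        image_file=image_file,
        **fields,
    )

args = build_args(
    model_path="liuhaotian/llava-v1.5-7b",
    # Placeholder value; in the notebook use get_model_name_from_path(model_path).
    model_name="llava-v1.5-7b",
    query="What are the things I should be cautious about when I visit here?",
    image_file="https://llava-vl.github.io/static/images/view.jpg",
    temperature=0.0,  # example override of a default
)
print(args.temperature, args.top_p)
```

Keeping the defaults in one dictionary means only the values that change between calls have to be spelled out.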
How does it do with Japanese?
prompt = "この場所を訪れた際に注意すべきことはなんですか?"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
args = type('Args', (), {
    "model_path": model_path,
    "model_base": None,
    "model_name": get_model_name_from_path(model_path),
    "query": prompt,
    "conv_mode": None,
    "image_file": image_file,
    "sep": ",",
    "temperature": 0.2,
    "top_p": 0.8,
    "num_beams": 1,
    "min_new_tokens": None,
    "max_new_tokens": None
})()
displayHTML(f"<img src=\"{image_file}\">")
eval_model(args)
Wrap it in a function.
def evaluate_image(image_file, prompt):
    """Display the image, then run LLaVA on it with the given prompt."""
    # model_path is the notebook-level variable defined in the earlier cell.
    args = type('Args', (), {
        "model_path": model_path,
        "model_base": None,
        "model_name": get_model_name_from_path(model_path),
        "query": prompt,
        "conv_mode": None,
        "image_file": image_file,
        "sep": ",",
        "temperature": 0.2,
        "top_p": 0.8,
        "num_beams": 1,
        "min_new_tokens": None,
        "max_new_tokens": None
    })()
    displayHTML(f"<img src=\"{image_file}\">")
    eval_model(args)
Let's try this image.
evaluate_image("https://1.bp.blogspot.com/-pzkUACogq0E/X5OcHr5ZnSI/AAAAAAABb5Q/xb-j2PQXgu03_vypUL1XNOYv4bhpWEFgQCNcBGAsYHQ/s400/bird_mameruriha_inko_blue.png", "何が描かれていますか?")
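In the calls above, the return value of eval_model is never used, which suggests the answer is printed rather than returned (this is my assumption from the usage here). If the answer is needed as a string, for example to compare the English and Japanese responses, the printed text can be captured with the standard library. The sketch below shows the idea with a stand-in function. capture_printed and fake_eval are hypothetical names used only for illustration.

```python
import contextlib
import io

def capture_printed(func, *args, **kwargs) -> str:
    """Call func and return whatever it printed to stdout as a string."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        func(*args, **kwargs)
    return buffer.getvalue().strip()

# Stand-in for a function that prints its result, used here only for illustration.
def fake_eval(prompt):
    print(f"answer to: {prompt}")

text = capture_printed(fake_eval, "What is depicted?")
```

In the notebook, the corresponding call would be capture_printed(eval_model, args), under the same assumption that the answer goes to standard output.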