LoginSignup
1
1

SGLangでLlava v1.6を利用する on Databricks

Posted at

実はLlava v1.5を動かす記事を書いていたら、丁度1.6が出たという。。。

導入

こちらのさらに続きです。

SGLangは純粋なLLMだけではなく、画像のマルチモーダルモデルであるLlavaにも対応しています。

※ 手前みそですが、以前以下の記事で試していました。

また、Llavaのオンラインデモは既にSGlangで動作しているようです。

丁度、Llavaの最新Versionであるv1.6が出ましたので、SGLangで実行してみます。

検証はDatabricks on AWS上で実施しました。
DBRは14.2ML、クラスタタイプはGPU(g5.xlarge)を使用しています。

Step1. パッケージインストール

前回同様、必要なパッケージをインストールします。

まずはtorchやvllmをインストール。

# torch, xformers
# pytorchのリポジトリから直接インストールする場合
# %pip install -U https://download.pytorch.org/whl/cu118/torch-2.1.2%2Bcu118-cp310-cp310-linux_x86_64.whl
# %pip install -U https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl

%pip install -U /Volumes/training/llm/tmp/torch-2.1.2+cu118-cp310-cp310-linux_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/xformers-0.0.23.post1+cu118-cp310-cp310-manylinux2014_x86_64.whl

# vLLM
# GithubからvLLMを直接インストールする場合
# %pip install https://github.com/vllm-project/vllm/releases/download/v0.2.7/vllm-0.2.7+cu118-cp310-cp310-manylinux1_x86_64.whl
%pip install /Volumes/training/llm/tmp/vllm-0.2.7+cu118-cp310-cp310-manylinux1_x86_64.whl

最新のSGLangを利用するため、ソースからSGLangをインストール。

今回は、Databricks Reposを使ってSGLangのリポジトリをクローンしました。
Reposを使わない場合、git cloneを実行してリポジトリをクローンしてください。

image.png

その後、クローンしたリポジトリを利用し、ソースからインストール。
tritonも2.2.0以降をインストールします。

!cd /Workspace/Repos/install_path/sglang && pip install -e "python[srt]"

%pip install -U "triton>=2.2.0"

dbutils.library.restartPython()

Step2. torch設定

前回同様、torchのmultiprocessingの処理種別を変更。

import torch
torch.multiprocessing.set_start_method('spawn', force=True)

Step3. ランタイム起動

Llavaモデルをダウンロードして、SGLangのランタイムを起動します。
今後の再利用性を考慮して、一旦モデルのスナップショットをUnity Catalog Volumes上に保管してから読み込みます。

なお、Llava v1.6はベースモデル違いで数種類公開されていますが、今回は以下のVicunaベースのモデルを利用しました。

from typing import Optional

def download_model(model_id:str, revision:Optional[str]=None):
    import os
    from huggingface_hub import snapshot_download

    UC_VOLUME = "/Volumes/training/llm/model_snapshots"
    access_token = dbutils.secrets.get("huggingface", "access_token")

    rev_dir = ("--" + revision) if revision else ""
    local_dir = f"/tmp/{model_id}{rev_dir}"
    uc_dir = f"/models--{model_id.replace('/', '--')}"
    
    snapshot_location = snapshot_download(
        repo_id=model_id,
        revision=revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        token=access_token,
    )

    dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}{rev_dir}", recurse=True)

model_id = "liuhaotian/llava-v1.6-vicuna-7b"
download_model(model_id)    
import sglang as sgl

model_path = "/Volumes/training/llm/model_snapshots/models--liuhaotian--llava-v1.6-vicuna-7b"
runtime = sgl.Runtime(model_path, tokenizer_path="llava-hf/llava-1.5-7b-hf")
sgl.set_default_backend(runtime)

Tokenizerだけllava-hf/llava-1.5-7b-hfを使用しています。
(こちらもダウンロードしておいたほうが再利用性は高い)
これはSGLangの問題だと思うのですが、現時点でllava-v1.6-vicuna-7bのTokenizerそのままを使おうとするとエラーでランタイムを起動できませんでした。

Step4. 画像に対する質問

では、適当な画像を使って、画像に対する問い合わせをしてみます。

画像はいらすとや様の以下画像を利用させていただきました。

import requests
from PIL import Image

@sgl.function
def image_qa(s, image_path, question):
    s += sgl.user(sgl.image(image_path) + question)
    s += sgl.assistant(sgl.gen("answer"))

def save_image_from_url(url, save_path):
    response = requests.get(url, stream=True)
    
    if response.status_code == 200:
        with open(save_path, 'wb') as file:
            file.write(response.content)    


def show_image(image_path: str):

    # 画像の読み込み
    image = Image.open(image_path)
    
    # 画像のサイズを半分にする
    size = image.size
    new_size = (size[0] // 2, size[1] // 2)
    image = image.resize(new_size)

    # 画像の表示
    display(image)


def single():

    image_url = "https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEha5wpu5v4qbr6eK6hypAd3LiuXjCaO_JAU3Ts7g49lhqTLBXD0l6d7n1sXgicDD5-v7eIx3owuJ0ahyKcnR9SfM878pL-Uw5QLCm5nkqgCPHgExhbqfzl00JU4Z7EoT-gd5Oo1c8v-ujAhaF3jZvOZtdUG8hSLXOnwpn8HI4cGk56HKpt2vvU9gcCOkfxX/s929/eto_tatsu_banzai.png"
    save_path = "/tmp/test.png"
    save_image_from_url(image_url, save_path)

    show_image(save_path)

    question = "これは何?"
    state = image_qa.run(
        image_path=save_path,
        question=question,
        max_new_tokens=128,
        temperature=0,
    )
    print("Q:", question)
    print("A:", state["answer"])


single()

結果は以下のような感じ。

image.png

竜だと判別できるんだ。

まとめ

SGLangを使ってLlava 1.6を動かしてみました。
Llavaはtransformers等多くのフレームワークでサポートされており、かなり容易に試せるようになってはいますが、SGLangでもかなり容易に利用することができます。

ただ、まだMistralベースのLlava 1.6は動作しませんでした(やり方が間違っているのかは不明。。。)
SGLangはまだまだ発展途上ですが、急速にメンテが進んでおり、今後のさらなる進化が楽しみです。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1