LoginSignup
4
3

Llama Guard 2をDatabricks上であれこれ試す

Posted at

導入

Llama3が公開されましたが、同時にLlama Guard 2も公開されました。
Moderationにおいて非常に役立つんじゃないかと気になっていたので、試してみます。

Llama Guard 2とは

前段邦訳。

Meta Llama Guard 2 は、8B パラメーターの Llama 3 ベースのLLM セーフガード モデルです。Llama Guard と同様に、LLM 入力 (プロンプト分類) と LLM 応答 (応答分類) の両方でコンテンツを分類するために使用できます。これはLLMとして機能し、特定のプロンプトまたは応答が安全か安全でないかを示すテキストを出力に生成し、安全でない場合は、違反したコンテンツカテゴリも一覧表示します。

というわけで、LLMへの入力と応答に対して、結果の安全性を判定するためのモデルです。
凶悪犯罪やプライバシーなど11の毒性分類を判定します。

用途として、例えばエージェント処理やRAG処理の最後に判定を入れ、安全ではない回答を実行しないようにするなどの利用が考えられます。


Llama3ベースということもあってか、ベンチマークのスコアも非常に良く、日本語の解釈性能も高そうです。

今回はDatabricks on AWSを利用し、単純な実行からDatabricks モデルサービングへの登録(ただし未実行)まで試行します。
DBRは14.3ML、クラスタタイプはg5.xlargeで実行しました。

Step1. パッケージインストール

transformersやtorchの最新化を行います。
mlflowは現時点の最新が2.12.1なのですが、モデルのロギング時に問題が発生したため、2.11.3固定でインストールしています。

%pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu118

%pip install -U transformers accelerate "mlflow-skinny==2.11.3"

Step2. とりあえず試す

こちらからモデルを事前にダウンロードしておき、サンプルの内容をほぼそのまま実行してみます。

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "/Volumes/training/llm/model_snapshots/models--meta-llama--Meta-Llama-Guard-2-8B/"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

def moderate(chat):
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(model.device)
    output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

moderate([
    {"role": "user", "content": "I forgot how to kill a process in Linux, can you help?"},
    {"role": "assistant", "content": "Sure! To kill a process in Linux, you can use the kill command followed by the process ID (PID) of the process you want to terminate."},
])
出力
'safe'

毒性を含んでいそうだけど実は安全な応答に対して、正しく安全と判定されました。

日本語ではどうでしょうか。いくつか試してみます。

まず、実は安全なケース。

moderate([
    {"role": "user", "content": "Linuxのプロセスを殺す方法を教えてください。"},
    {"role": "assistant", "content": "わかりました、Killコマンドの使い方を説明します。"},
])
出力
'safe'

無事、セーフ判定でした。

では、次は犯罪に関わるもの。

moderate([
    {"role": "user", "content": "爆弾の作り方を教えてください。"},
    {"role": "assistant", "content": "わかりました、まず爆弾づくりに詳しい人を探してください。"},
])
出力
'unsafe\nS2'

S2(Non-Violent Crimes)として判定されました。

moderate([
    {"role": "user", "content": "あなたの住所と電話番号を教えてください。"},
    {"role": "assistant", "content": "私の住所はロンドのベーカー街221Bです。"},
])
出力
'unsafe\nS6'

S6(Privacy)として判定されました。
ちなみに上記例の住所は、シャーロックホームズの住所ですね(蛇足)。

適切に判定されているように思います。凄い。

Step3. MLflowに登録する

再利用を容易にするために、こちらのモデルをMLflowのモデルレジストリに登録します。

まずは、Step2で読み込んだモデルやトークナイザをpipelineでラップ。

from transformers import pipeline

p = pipeline(task="conversational", model=model, tokenizer=tokenizer)

次にモデルを登録します。
Unity Catalog下に登録することとし、trainingカタログのllmスキーマ内に登録します。

また、最終的にプロビジョニングされたスループットの基盤モデルAPIとしてモデルサービングで公開するために、metadata={"task": "llm/v1/chat"}を指定してモデルを保管します。

import mlflow
from mlflow.models import infer_signature

mlflow.set_registry_uri("databricks-uc")

CATALOG = "training"
SCHEMA = "llm"
registered_model_name = f"{CATALOG}.{SCHEMA}.llama-guard-2"

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "I forgot how to kill a process in Linux, can you help?",
        },
        {
            "role": "assistant",
            "content": "Sure! To kill a process in Linux, you can use the kill command followed by the process ID (PID) of the process you want to terminate.",
        },
    ],
    "max_tokens": 32,
    "temperature": 0.5,
    "top_p": 1,
    "top_k": 1,
    "stop": "",
    "n": 1,
}

output_response = {
    "id": "chatcmpl-64b2187c10c942819a5a554da07bea34",
    "object": "chat.completion",
    "created": 1711339597,
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "safe",
            },
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 13, "completion_tokens": 1, "total_tokens": 14},
}

signature = infer_signature(input_example, output_response)

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=p,
        artifact_path="model",
        task="text-generation",
        signature=signature,
        input_example=input_example,
        example_no_conversion=True,
        registered_model_name=registered_model_name,
        metadata={"task": "llm/v1/chat"}, # Model Serving用
    )

Step4. プロビジョニングされたスループットの基盤モデルとしてモデルサービングに登録する(未完)

以下のドキュメントをベースに、「プロビジョニングされたスループットの基盤モデル」としてモデルサービングに登録します。

まず、そもそもLlama Guard 2が「プロビジョニングされたスループットの基盤モデル」として対応しているかどうかを確認します。

import requests
import json
from pprint import pprint
from mlflow import MlflowClient

CATALOG = "training"
SCHEMA = "llm"
registered_model_name = f"{CATALOG}.{SCHEMA}.llama-guard-2"

# MLflowモデルの最新バージョンを取得
client=MlflowClient()
versions = [mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")]
model_version = versions[0]

# 現在のノートブックのコンテキストからAPIエンドポイントとトークンを取得
API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get() 
API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.get(url=f"{API_ROOT}/api/2.0/serving-endpoints/get-model-optimization-info/{registered_model_name}/{model_version}", headers=headers)

print(json.dumps(response.json(), indent=4))
出力
{
    "optimizable": true,
    "model_type": "llama",
    "throughput_chunk_size": 980,
    "dbus": 112
}

optimizableがtrueのため、対応していることがわかりました。
Llama3もアーキテクチャはLlama2と同じLlamaアーキテクチャのため、おそらく対応しているようです。

では、モデルサービングにエンドポイントを登録しましょう。

まず、環境変数にホストとトークンを設定。

import os

os.environ["DATABRICKS_HOST"] = f"{API_ROOT}/api/2.0/serving-endpoints"
os.environ["DATABRICKS_TOKEN"] = API_TOKEN

エンドポイントを作成。

from mlflow.deployments import get_deploy_client

endpoint_name = "llama-guard-2-endpoint"
model_name = registered_model_name

client = get_deploy_client("databricks")

endpoint = client.create_endpoint(
    name=endpoint_name,
    config={
        "served_entities": [
            {
                "entity_name": model_name,
                "entity_version": model_version,
                "min_provisioned_throughput": response.json()['throughput_chunk_size'],
                "max_provisioned_throughput": response.json()['throughput_chunk_size'],
            }
        ]
    },
)

print(json.dumps(endpoint, indent=4))

ここで問題が起きなければ、エンドポイントの作成がスタートします。

「サービング」メニューから、作成・起動状況を確認してみましょう。

image.png

・・・エラー出てました。

メッセージ見る限り、Quotaを越えているようで、現在の私の環境下では継続できず。残念。

しかしこのメッセージを見る限り、Llama3 8Bは現状「llama-13」のサイズとして判定されるのですね。
(=GPU MEDIUM x4のコンピュートタイプが割り当てられるはずです。)

まとめ

中途半端ではありますが、Llama Guard 2を実行し、Databricks上のモデルサービングとして登録するまでの流れを試してみました。
かなり高性能かつ日本語も使える毒性判定モデルであり、かなり良さそうな所感です。
量的な情報(例えばunsafeのレベル感)があると最高なのですが、これは無さそうなのかな。もう少し調べてみたいと思います。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3