1
0

N番煎じでGoogleのGemmaをDatabricksで動かす

Posted at

導入

GoogleからローカルLLMであるGemmaが公開されました。
Geminiと同種の技術を活用した軽量版LLMとのこと。
サイズは2Bと7Bが公開されています。

GoogleがオープンなLLMを出したということで、既に様々な方が試されています。

が、気にせずこちらもDatabricksで試してみます。それがN番煎じシリーズ。

検証はDatabricks on AWS上で実施しました。
DBRは14.3ML、クラスタタイプはg5.xlarge(GPUクラスタ)を利用しています。

Step1. パッケージインストール

最新のtransformersaccelerateパッケージ等をインストールします。
transformersのバージョンは4.38以降が必要です。

%pip install -U transformers accelerate bitsandbytes
dbutils.library.restartPython()

Step2. モデルのダウンロード

モデルをhuggingfaceからダウンロードし、Unity Catalog Volumesへ保管します。
今回はInstruct tuningされた7Bモデルを試してみました。

事前にhuggingface上でGemmaのライセンス認証をしておく必要があります。
また、huggingfaceのAccess tokenをDatabricks Secretsに登録しておいてください。
以下のコードでは、huggingfaceスコープにaccess_tokenという名前のキーからAccess Tokenを取得しています。

from typing import Optional

def download_model(model_id:str, revision:Optional[str]=None):
    import os
    from huggingface_hub import snapshot_download

    UC_VOLUME = "/Volumes/training/llm/model_snapshots"
    access_token = dbutils.secrets.get("huggingface", "access_token")

    rev_dir = ("--" + revision) if revision else ""
    local_dir = f"/tmp/{model_id}{rev_dir}"
    uc_dir = f"/models--{model_id.replace('/', '--')}"
    
    snapshot_location = snapshot_download(
        repo_id=model_id,
        revision=revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        token=access_token,
    )

    dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}{rev_dir}", recurse=True)

model_id = "google/gemma-7b-it"
download_model(model_id)

Step3. モデルのロード

transformers上でモデルをロードします。
BF16の精度でロードしました。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "/Volumes/training/llm/model_snapshots/models--google--gemma-7b-it"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

Step4. 推論実行

推論用の関数を定義。

def generate_text(instruct:str):
    chat = [
        {"role": "user", "content": instruct},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
    outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=150)

    return tokenizer.decode(outputs[0])

では、いくつか推論を実行してみます。


まずはシンプルに日本語で質問。

print(generate_text("Databricksとは何?"))
出力
<bos><start_of_turn>user
Databricksとは何?<end_of_turn>
<start_of_turn>model
Databricksは、ビッグデータ分析のためのオープンソースプラットフォームです。これは、データ分析の作業をより効果的に実行するためのツールを提供します。

**主な特徴:**

* **オープンソース:** Databricksはオープンソースプラットフォームであり、無料で使用できます。
* **データ統合:** Databricksは、さまざまなデータソースからデータを統合し、分析するための統合されたデータワークスペースを提供します。
* **データ分析:** Databricksは、ビッグデータ分析のための強力な分析機能を提供します。
* **データビジュアル:** Databricksは、データ分析の結果に基づいてビジュアルを生成するためのツールを提供します。
* **スケーラビリティ:** Databricksは、ビッグ

Markdown形式で適切に出力されていますね。

同様の質問ももう一つ。

print(generate_text("東京の観光名所を3か所、特徴と共に教えて"))
出力
<bos><start_of_turn>user
東京の観光名所を3か所、特徴と共に教えて<end_of_turn>
<start_of_turn>model
**1. 東部銀座**

* **特徴:** 古くからの繁華街、伝統的な建築物、現代の建築物、そして多くの観光スポットがある。
* **アクセス:** 東部線、銀座線、中央線を利用してアクセスできます。

**2. 東京スカイツリー**

* **特徴:** 高さから見た東京の景色を望むことができる高い建築物。
* **アクセス:** 東部線を利用してアクセスできます。

**3. 神 Temple**

* **特徴:** 古くからの寺、伝統的な建築物、そして多くの観光スポットがある。
* **アクセス:** 東部線を利用してアクセスできます。<eos>

東部線とは。。。


JSON形式出力を単純にできるのか?

print(generate_text("カレーの材料をJSON形式で出力して。結果のみ出力してください。"))
出力
<bos><start_of_turn>user
カレーの材料をJSON形式で出力して。結果のみ出力してください。<end_of_turn>
<start_of_turn>model
```json
{
  "ingredients": [
    "野菜",
    "カレー粉",
    "こしょう",
    "塩",
    "こぼこ",
    "牛乳",
    "ポテト",
    "たまご"
  ]
}
```<eos>

内容はともかく、JSON形式出力はできていますね。

まとめ

簡易ですが、Gemmaを試してみました。

同じ7BモデルということでMistral-7Bとの比較評価をRedditなどでも見かけます。
Gemmaが明らかに優れているわけではなさそうなのですが、多言語性能はGemmaが良さそうですね。

GoogleもオープンなLLM領域に参入してきてますます市場の活性化が進みそうです!

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0