1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

N番煎じでStockmark社のstockmark-13bをDatabricksで動かす

Posted at

最近N番煎じシリーズばかりですが、次々出て来てるのでしょうがない(言い訳)。

導入

Stockmark社が新たに日本語LLMを公開しました。

プレスリリース中にもあるように、

本モデルの特徴は、ビジネス用途での信頼性と速度です。

ということで、最近までのビジネス知識を保有しているようです。
また、MITライセンスであり、商用等の利用に対しても使いやすいライセンス体系となっています。

現状、インストラクトモデルは出されていないようです。

既に下記記事のように多くの方が試されていますが、N番煎じシリーズなので気にせずに試してみます。

検証環境

Databricks on AWSを使って検証しています。
DBRは14.0ML、g5.2xlargeのクラスタを利用して推論します。

Step1. モデルの取得と保管

こちらとほぼ同じです。
最終的にUnity CatalogのVolumesに保管します。

%pip install -U -qq transformers accelerate

dbutils.library.restartPython()
import os
from huggingface_hub import snapshot_download

UC_VOLUME = "/Volumes/モデルの保管場所"

model = "stockmark/stockmark-13b"
local_dir = f"/tmp/{model}"
uc_dir = "/models--stockmark--stockmark-13b"

snapshot_location = snapshot_download(
    repo_id=model,
    local_dir=local_dir,
    local_dir_use_symlinks=False,
)

dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}", recurse=True)

Step2. 推論準備

モデルを読み込んで、推論用の関数を定義します。

%pip install -U -qq transformers accelerate bitsandbytes 
dbutils.library.restartPython()
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

UC_VOLUME = "/Volumes/モデルの保管場所"

tokenizer_path = f"{UC_VOLUME}/models--stockmark--stockmark-13b"
model_path = f"{UC_VOLUME}/models--stockmark--stockmark-13b"

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map={"": 0}, 
    load_in_8bit=True,    
)

def generate_batch(prompt: str):

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.1
        )
        
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return output

Step3. 推論

いくつかやってみます。

単純生成1

print(
    generate_batch(
        "2022年にイーロンマスクによって買収されたSNSサービスは、",
    )
)
出力
2022年にイーロンマスクによって買収されたSNSサービスは、TwitterからXに名称変更されました。
Twitterは、2022年10月27日に、Twitter Blueという有料サブスクリプションサービスを開始しました。
Twitter Blueは、Twitterの有料サブスクリプションサービスです。
Twitter Blueは、月額980円(税込)で、Twitterをより快適に利用するためのさまざまな特典を利用できます。
Twitter Blueは、Twitterをより快適に利用するためのさまざまな特典を利用できます。
Twitter Blueは、Twitterをより快適に利用するためのさまざまな特典を利用できます。Twitter Blueは、Twitterをより快適に利用するためのさまざまな特典を利用できます。
Twitter Blueは、Twitterをより快適に利用するための

単純生成2

print(
    generate_batch(
        "宮崎駿の代表作は、",
    )
)
出力
宮崎駿の代表作は、1984年に公開された「風の谷のナウシカ」です。
「風の谷のナウシカ」は、宮崎駿が1982年にアニメ雑誌「アニメージュ」に連載していた漫画が原作です。
宮崎駿は、1984年に「風の谷のナウシカ」をアニメ映画化しました。
「風の谷のナウシカ」は、宮崎駿の代表作であり、スタジオジブリの代表作でもあります。
「風の谷のナウシカ」は、宮崎駿が1982年にアニメ雑誌「アニメージュ」に連載していた漫画が原作です。宮崎駿は、1984

コード生成

print(
    generate_batch(
        "ランダムな10個の要素を作成し、ランダムに並び替えてるプログラムコードをPythonで作成すると、",
    )
)
出力
ランダムな10個の要素を作成し、ランダムに並び替えてるプログラムコードをPythonで作成すると、以下のようになります。
# ランダムな10個の要素を作成し、ランダムに並び替える
# ランダムな10個の要素を作成し、ランダムに並び替える(10個ずつ)
# ランダムな10個の要素を作成し、ランダムに並び替える(100個ずつ)
# ランダムな10個の要素を作成し、ランダムに並び替える(1000個ずつ)
# ランダムな10個の要素を作成し、ランダムに並び替える(10000個ずつ)
# ランダムな10個の要素を作成し、ランダムに並び替える(1000

表作成指示

print(
    generate_batch(
        "東京、名古屋、大阪、京都、福岡の「都道府県」「おすすめスポット」「おすすめグルメ」をMarkdownの表形式で出力すると、",
    )
)
出力
東京、名古屋、大阪、京都、福岡の「都道府県」「おすすめスポット」「おすすめグルメ」をMarkdownの表形式で出力すると、こんな感じになります。
東京、名古屋、大阪、京都、福岡の「都道府県」「おすすめスポット」「おすすめグルメ」をMarkdownの表形式で出力すると、こんな感じになります。
東京、名古屋、大阪、京都、福岡の「都道府県」「おすすめスポット」「おすすめグルメ」をMarkdownの表形式で出力すると、こんな感じになります。
東京、名古屋、大阪、京都、福岡の「都道府県」「おすすめスポット」「おすすめグルメ」をMarkdownの表形式で出力すると、こんな感じになります。
東京、名古屋、大阪、京都、福岡の「都道府県」「おすすめスポット」「おすすめグルメ」をMarkdownの表

おまけ

print(
    generate_batch(
        "まどか☆マギカで一番かわいいのは、",
    )
)
出力
まどか☆マギカで一番かわいいのは、ほむらちゃんです。
ほむらちゃんかわいいよほむらちゃん。
ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。
ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。ほむらちゃんかわいいよほむらちゃん。ほむらちゃん

こわい。。。

おまけのおまけ

Llama2ベースなので、CTranslate2で変換・推論可能です。

8bit量子化で変換したモデルでも推論してみました。

import ctranslate2
import transformers
import torch
from ct2chat import ChatCTranslate2

model_path = "/Volumes/CT2モデルの保管場所/stockmark/stockmark-13b"

# ジェネレーターとトークナイザーの準備
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

llm = ChatCTranslate2(
    generator=generator,
    tokenizer=tokenizer,
    max_length=256,
    human_message_template="{}",
    ai_message_template="{}",
    prompt_line_separator="",
)
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)

output = llm(
    [
        HumanMessage(content="魔法少女まどか☆マギカで一番可愛いのキャラは、"),
    ]
)
print(output.content)
出力
やっぱり暁美ほむらちゃんだよね。
魔法少女まどか☆マギカで一番可愛いのキャラは、やっぱり暁美ほむらちゃんだよね。 へのコメントはまだありません
魔法少女まどか☆マギカで一番可愛いのキャラは、やっぱり暁美ほむらちゃんだよね。
魔法少女まどか☆マギカで一番可愛いのキャラは、やっぱり暁美ほむらちゃんだよね。 へのコメントはまだありません への2件のコメント

まとめ

8bit量子化した状態で試したため、量子化しない場合とは結果が変わるかもしれません。
ベースがLlama2 13Bなので、基本性能はそれなりに高いのではないかと思います。
インストラクトチューニングされたモデルを期待しています。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?