2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Stability AIのJapanese StableLM Instruct Alpha 7BをDatabricksで動かしてみる

Posted at

どんどん日本語対応LLMが出てきてすごいくらいしか感想が出てきません。

ベースモデルは商用利用可能で、指示応答言語モデルは研究用途で使用可能ですね。

Japanese StableLM Base Alpha 7Bは商用利用可能なApache License 2.0での公開となります。Japanese StableLM Instruct Alpha 7Bは研究目的で作成されたモデルであり、研究目的での利用に限定した公開となります。詳細は Hugging Face Hub のページをご確認ください。

リポジトリへのアクセスをリクエストします。
Screenshot 2023-08-11 at 7.27.25.png

ライブラリをインストールします。いつもと同じようにGPUクラスター使ってます。

%pip install sentencepiece einops
dbutils.library.restartPython()

Huggingfaceにログインします。自分のアクセストークンを指定します。

from huggingface_hub import notebook_login

# Login to Huggingface to get access to the model
notebook_login()

モデルのダウンロードとプロンプト生成の関数。

import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1", additional_special_tokens=['▁▁'])

model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/japanese-stablelm-instruct-alpha-7b",    
    trust_remote_code=True,
)
model.half()
model.eval()

if torch.cuda.is_available():
    model = model.to("cuda")

def build_prompt(user_query, inputs="", sep="\n\n### "):
    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    p = sys_msg
    roles = ["指示", "応答"]
    msgs = [": \n" + user_query, ": "]
    if inputs:
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + inputs)
    for role, msg in zip(roles, msgs):
        p += sep + role + msg
    return p
# this is for reproducibility.
# feel free to change to get different result
seed = 42
torch.manual_seed(seed)

# Infer with prompt without any additional input
user_inputs = {
    "user_query": "VR とはどのようなものですか?",
    "inputs": ""
}
prompt = build_prompt(**user_inputs)

input_ids = tokenizer.encode(
    prompt, 
    add_special_tokens=False, 
    return_tensors="pt"
)

tokens = model.generate(
    input_ids.to(device=model.device),
    max_new_tokens=256,
    temperature=1,
    top_p=0.95,
    do_sample=True,
)

out = tokenizer.decode(tokens[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
print(out)

バーチャルリアリティは、現実の世界のように見える仮想世界の 3D 仮想現実のシミュレーションです。これは、ヘッドセットを介して、ユーザーが見たり、聞いたり、体験できるものです。

すごい(語彙力)。

関数化します。

def generate_test(input):
  # Infer with prompt without any additional input
  user_inputs = {
    "user_query": input,
    "inputs": ""
  }
  prompt = build_prompt(**user_inputs)

  input_ids = tokenizer.encode(
    prompt, 
    add_special_tokens=False, 
    return_tensors="pt"
  )

  tokens = model.generate(
    input_ids.to(device=model.device),
    max_new_tokens=256,
    temperature=1,
    top_p=0.95,
    do_sample=True,
  )

  out = tokenizer.decode(tokens[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
  return out

お約束の質問。

generate_test("Databricksとは?")

Out[6]: 'Databricksは、データ・分析・AIプラットフォームです。ビジネス・組織がデータの価値を解き放つために、データの分析、探索、利用、管理、管理を可能にするための包括的な基盤を提供します。'

合ってます。

generate_test("Delta Lakeとは?")

Out[7]: 'Delta Lakeは、Amazon Web Services(AWS)によりホストされた、データのプライバシー、セキュリティ、およびコスト効率を強化するために設計されたデータレイクの一種です。\nデータサイエンティストやデータエンジニアは、Delta Lakeを使用してデータレイクにデータを統合し、他のユーザーのニーズに合わせてデータを整理し、さまざまなユースケースに活用します。'

Delta Lakeも知ってる…。

generate_test("MLOpsとは?")

MLOpsまで...。

Out[8]: 'MLOpsは、自動化と継続的な学習を可能にするMLのソフトウェア開発と運用モデルを指します。MLOpsは、機械学習パイプラインが機能的に自動化されていること、MLトレーニングとデプロイメントがスムーズかつ効率的に展開されること、MLの予測能力と品質を確保するためのプロセスとテクノロジーがあることを確実にすることを含みます。'

今度はベースモデルも試してみます。

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?