LoginSignup
4
1

最初はなんのこっちゃと思いました。モデルサービング以外の選択肢があったとは。

早速やってみます。

注意
元記事にもあるようにインタラクティブな開発での使用が推奨されています。プロダクションでの用途ではモデルサービングを使ってください。

プロキシーアプリの起動

ノートブックをクラスターにアタッチして、以下を実行します。

from flask import Flask, request, jsonify
import torch
from transformers import pipeline, AutoTokenizer, StoppingCriteria

model = "databricks/dolly-v2-3b"
tokenizer = AutoTokenizer.from_pretrained(model, padding_side="left")
dolly = pipeline(model=model, tokenizer=tokenizer, trust_remote_code=True, device_map="auto")
device = dolly.device

class CheckStop(StoppingCriteria):
    def __init__(self, stop=None):
        super().__init__()
        self.stop = stop or []
        self.matched = ""
        self.stop_ids = [tokenizer.encode(s, return_tensors='pt').to(device) for s in self.stop]
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for i, s in enumerate(self.stop_ids):
            if torch.all((s == input_ids[0][-s.shape[1]:])).item():
                self.matched = self.stop[i]
                return True
        return False

def llm(prompt, stop=None, **kwargs):
  check_stop = CheckStop(stop)
  result = dolly(prompt, stopping_criteria=[check_stop], **kwargs)
  return result[0]["generated_text"].rstrip(check_stop.matched)

app = Flask("dolly")

@app.route('/', methods=['POST'])
def serve_llm():
  resp = llm(**request.json)
  return jsonify(resp)

app.run(host="0.0.0.0", port="7777")

これだけで、Dollyがサービングされるようになります。

 * Serving Flask app "dolly" (lazy loading)
 * Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on all addresses.
   WARNING: This is a development server. Do not use it in a production deployment.
 * Running on http://10.0.0.69:7777/ (Press CTRL+C to quit)
10.0.14.228 - - [18/Jun/2023 06:21:20] "POST / HTTP/1.1" 200 -

LLMへのアクセス

同じクラスターに別のノートブックをアタッチして以下を実行します。

最新のlangchainをインストールします。

%pip install langchain==0.0.202
dbutils.library.restartPython()

これだけでLLM(Dolly)にアクセスできます。

from langchain.llms import Databricks

# アプリを稼働しているのと同じクラスターにアタッチしたDatabricksノートブックを実行する場合、
# `Databricks`インスタンスの作成に必要なのはドライバーのポート番号を指定することだけです。
llm = Databricks(cluster_driver_port="7777")

llm("How are you?")
Out[1]: "I'm doing well, thank you."

動きました!

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1