最初はなんのこっちゃと思いました。モデルサービング以外の選択肢があったとは。
早速やってみます。
注意
元記事にもあるようにインタラクティブな開発での使用が推奨されています。プロダクションでの用途ではモデルサービングを使ってください。
プロキシーアプリの起動
ノートブックをクラスターにアタッチして、以下を実行します。
from flask import Flask, request, jsonify
import torch
from transformers import pipeline, AutoTokenizer, StoppingCriteria
model = "databricks/dolly-v2-3b"
tokenizer = AutoTokenizer.from_pretrained(model, padding_side="left")
dolly = pipeline(model=model, tokenizer=tokenizer, trust_remote_code=True, device_map="auto")
device = dolly.device
class CheckStop(StoppingCriteria):
def __init__(self, stop=None):
super().__init__()
self.stop = stop or []
self.matched = ""
self.stop_ids = [tokenizer.encode(s, return_tensors='pt').to(device) for s in self.stop]
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
for i, s in enumerate(self.stop_ids):
if torch.all((s == input_ids[0][-s.shape[1]:])).item():
self.matched = self.stop[i]
return True
return False
def llm(prompt, stop=None, **kwargs):
check_stop = CheckStop(stop)
result = dolly(prompt, stopping_criteria=[check_stop], **kwargs)
return result[0]["generated_text"].rstrip(check_stop.matched)
app = Flask("dolly")
@app.route('/', methods=['POST'])
def serve_llm():
resp = llm(**request.json)
return jsonify(resp)
app.run(host="0.0.0.0", port="7777")
これだけで、Dollyがサービングされるようになります。
* Serving Flask app "dolly" (lazy loading)
* Environment: production
WARNING: This is a development server. Do not use it in a production deployment.
Use a production WSGI server instead.
* Debug mode: off
* Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
* Running on http://10.0.0.69:7777/ (Press CTRL+C to quit)
10.0.14.228 - - [18/Jun/2023 06:21:20] "POST / HTTP/1.1" 200 -
LLMへのアクセス
同じクラスターに別のノートブックをアタッチして以下を実行します。
最新のlangchainをインストールします。
%pip install langchain==0.0.202
dbutils.library.restartPython()
これだけでLLM(Dolly)にアクセスできます。
from langchain.llms import Databricks
# アプリを稼働しているのと同じクラスターにアタッチしたDatabricksノートブックを実行する場合、
# `Databricks`インスタンスの作成に必要なのはドライバーのポート番号を指定することだけです。
llm = Databricks(cluster_driver_port="7777")
llm("How are you?")
Out[1]: "I'm doing well, thank you."
動きました!