ある意味、GWの宿題的な取り組みとして。
やりたかったのは以下のフローです。
- rinnaのファインチューニング
- MLflowによるメトリクスやモデルのトラッキング
- モデルサービングによるRESTエンドポイントの構築
- Streamlitからモデルを呼び出し
結論として、上のエンドツーエンドのフローをDatabricksで完結させることができました。嬉しい。
自分の参考資料はこちらです。途中で触れていますが、他の方の記事も参考にさせていただいています。
- DatabricksのモデルサービングでLLMを用いたチャットボットを動かす
- MLflow 2.3のHugging Faceトランスフォーマーのサポートを試す
- [翻訳] Hugging Face transformersのクイックスタート
- Databricksで日本語GPT-2モデルをファインチューニングして文章生成をやってみる
- Databricksでrinnaの日本語GPT-2モデルのファインチューニングを試す
- MLflow 2.3のご紹介:ネイティブLLMのサポートと新機能による強化
ライブラリのインストール
transformersやMLflowは最新バージョンをインストールします。Databricksランタイムも最新の13.0MLです。
%pip install transformers==4.25.1
%pip install sentencepiece
%pip install mlflow==2.3.1
DBRを13.0にするとライブラリインストール時にNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
といったメッセージが出るので念の為kernelを再起動します。
dbutils.library.restartPython()
Tokenizerの設定
(多分)ベースモデルに含まれていないDatabricksの用語などを追加します。ここはまだ試行錯誤の段階です。
from transformers import T5Tokenizer,AutoModelForCausalLM
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
DATABRICKS_TOKENS = ["MLflow", "Databricks", "Delta Lake", "Spark"]
num_added_toks = tokenizer.add_tokens(DATABRICKS_TOKENS)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
rinnaのファインチューニング
引き続き、こちらの記事を参考にさせていただいています。
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead
train_data_path = "/dbfs/FileStore/shared_uploads/takaaki.yayoi@databricks.com/dolly/taka_qiita_cleansed.csv"
# データセットの設定
train_dataset = TextDataset(
tokenizer = tokenizer,
file_path = train_data_path,
block_size = 128 # 文章の長さを揃える必要がある
)
# データの入力に関する設定
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm= False
)
# 訓練に関する設定
training_args = TrainingArguments(
output_dir="/dbfs/tmp/takaaki.yayoi@databricks.com/rinna/output_20230502/", # 関連ファイルを保存するパス
overwrite_output_dir=True, # ファイルを上書きするかどうか
num_train_epochs=3, # エポック数
per_device_train_batch_size=8, # バッチサイズ
logging_steps=100, # 途中経過を表示する間隔
save_steps=800 # モデルを保存する間隔
)
# トレーナーの設定
trainer = Trainer(
model =model,
args=training_args,
data_collator = data_collator,
train_dataset = train_dataset
)
以下が今回の肝です。MLflowにメトリクスとモデルをトラッキングしてもらいます。MLflowでtransformersがサポートされたので実装は簡単です。これで、今後のチューニングが捗ります。GPUクラスターの場合はパイプラインでdevice=0
の設定が必要です。
import mlflow
from transformers import pipeline
with mlflow.start_run():
trainer.train()
task = "text-generation"
architecture = "rinna/japanese-gpt2-small"
# fine tune後のモデルとトークナイザー、GPUを指定
sentence_pipeline = pipeline(
task=task, tokenizer=tokenizer, model=model, device=0
)
# Validate that the overrides function
prompts = ["生成型モデルは", "今日は天気がいいので"]
# validation of config prior to save or log
inference_config = {
"top_k": 2,
"num_beams": 5,
"max_length": 30,
"temperature": 0.62,
"top_p": 0.85,
"repetition_penalty": 1.15,
}
# Verify that no exceptions are thrown
sentence_pipeline(prompts, **inference_config)
mlflow.transformers.log_model(
transformers_model=sentence_pipeline,
artifact_path="my_sentence_generator",
task=task,
inference_config=inference_config,
registered_model_name="taka-hugging-face" # モデルレジストリに登録
)
エポック3
にしているのでしばし待ちます。なお、DBR13.0ですとメトリクスの画面がアップデートされています。
トラッキングされているモデルの画面でlossも確認できます。
最後にモデルレジストリに登録しているので、すぐにREST APIでモデルを呼び出せるようになります。前のバージョンがあるので今回はバージョン3になっています。LLMのチューニングは従来の機械学習より奥深いものですので、これまで以上に試行錯誤が発生すると思います。その際に、MLflowのトラッキングやモデルレジストリを活用することで、パラメーター、メトリクス、モデルのバージョンなどを簡単に管理できるようになります。
バッチ推論
デバッグの目的で、トラッキングされたモデルをロードして推論してみます。
loaded_model = mlflow.transformers.load_model('runs:/4dec269b93074f59a75fb7dcb66b0768/my_sentence_generator')
loaded_model("Databricksとは")
Out[6]: [{'generated_text': 'Databricksとは?parkはビッグデータ構造化データや、大規模データを取り扱う際に直面する最大の課題であるの課題に正面から取り組んでいきます。データ処理とデータサイエンスのゴールは同一です: 機械'}]
まだ怪しいですが、ファイチューニングの成果が現れています。ファインチューニングで使用しているデータは、私のQiitaの記事です。
GUI経由モデルの活用
ここまででは、ノートブックからモデルを呼び出せるだけで、広くユーザーの方に活用いただくというところまで到達していません。そこで、サービングエンドポイントの出番です。これを活用することで、モデルをREST API経由で呼び出せるようになります。詳細手順はこちらに。
エンドポイントに上記バージョン3のモデルをデプロイします。これでURL
経由でモデルにアクセスできるようになります。
GUI、前処理、後処理をstreamlitを使って実装します。これまでに作ったものの使い回しですが。
import streamlit as st
import numpy as np
from PIL import Image
import base64
import io
import os
import requests
import numpy as np
import pandas as pd
import json
st.header('Databricks Q&A Chatbot on Databrikcs')
st.write('''
- [MLflow 2\.3のご紹介:ネイティブLLMのサポートと新機能による強化 \- Qiita](https://qiita.com/taka_yayoi/items/431fa69430c5c6a5e741)
- [MLflow 2\.3のHugging Faceトランスフォーマーのサポートを試す \- Qiita](https://qiita.com/taka_yayoi/items/ad370a7f57c4eae58800)
''')
def score_model(prompt):
# 1. パーソナルアクセストークンを設定してください
# 今回はデモのため平文で記載していますが、実際に使用する際には環境変数経由で取得する様にしてください。
token = "<パーソナルアクセストークン>"
# 2. モデルエンドポイントのURLを設定してください
url = '<URL>'
headers = {'Authorization': f'Bearer {token}'}
#st.write(token)
data_json_str = f"""
{{
"inputs": [
[
"{prompt}"
]
]
}}
"""
data_json = json.loads(data_json_str)
response = requests.request(method='POST', headers=headers, url=url, json=data_json)
if response.status_code != 200:
raise Exception(f'Request failed with status {response.status_code}, {response.text}')
return response.json()
prompt = st.text_input("プロンプト")
if prompt != "":
response = score_model(prompt)
st.write(response['predictions'])
streamlit run chatbot_rinna.py
これで一連のフローを実装できたので、モデルのファインチューニングの沼に潜ります。