LoginSignup
2
0

More than 1 year has passed since last update.

DatabricksでrinnaのファインチューニングからGUI構築までやってみる

Posted at

ある意味、GWの宿題的な取り組みとして。

やりたかったのは以下のフローです。

  1. rinnaのファインチューニング
  2. MLflowによるメトリクスやモデルのトラッキング
  3. モデルサービングによるRESTエンドポイントの構築
  4. Streamlitからモデルを呼び出し

結論として、上のエンドツーエンドのフローをDatabricksで完結させることができました。嬉しい。

自分の参考資料はこちらです。途中で触れていますが、他の方の記事も参考にさせていただいています。

ライブラリのインストール

transformersやMLflowは最新バージョンをインストールします。Databricksランタイムも最新の13.0MLです。

%pip install transformers==4.25.1
%pip install sentencepiece
%pip install mlflow==2.3.1

DBRを13.0にするとライブラリインストール時にNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.といったメッセージが出るので念の為kernelを再起動します。

Python
dbutils.library.restartPython()

Tokenizerの設定

(多分)ベースモデルに含まれていないDatabricksの用語などを追加します。ここはまだ試行錯誤の段階です。

Python
from transformers import T5Tokenizer,AutoModelForCausalLM

tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
DATABRICKS_TOKENS = ["MLflow", "Databricks", "Delta Lake", "Spark"]
num_added_toks = tokenizer.add_tokens(DATABRICKS_TOKENS)

model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")

print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.

rinnaのファインチューニング

引き続き、こちらの記事を参考にさせていただいています。

Python
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead

train_data_path = "/dbfs/FileStore/shared_uploads/takaaki.yayoi@databricks.com/dolly/taka_qiita_cleansed.csv"

# データセットの設定
train_dataset = TextDataset(
    tokenizer = tokenizer,
    file_path = train_data_path,
    block_size = 128 # 文章の長さを揃える必要がある
)

# データの入力に関する設定
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm= False
)

# 訓練に関する設定
training_args = TrainingArguments(
    output_dir="/dbfs/tmp/takaaki.yayoi@databricks.com/rinna/output_20230502/",  # 関連ファイルを保存するパス
    overwrite_output_dir=True,  # ファイルを上書きするかどうか
    num_train_epochs=3,  # エポック数
    per_device_train_batch_size=8,  # バッチサイズ
    logging_steps=100,  # 途中経過を表示する間隔
    save_steps=800  # モデルを保存する間隔
)

# トレーナーの設定
trainer = Trainer(
    model =model,
    args=training_args,
    data_collator = data_collator,
    train_dataset = train_dataset
)

以下が今回の肝です。MLflowにメトリクスとモデルをトラッキングしてもらいます。MLflowでtransformersがサポートされたので実装は簡単です。これで、今後のチューニングが捗ります。GPUクラスターの場合はパイプラインでdevice=0の設定が必要です。

Python
import mlflow
from transformers import pipeline

with mlflow.start_run():
  trainer.train()

  task = "text-generation"
  architecture = "rinna/japanese-gpt2-small"

  # fine tune後のモデルとトークナイザー、GPUを指定
  sentence_pipeline = pipeline(
    task=task, tokenizer=tokenizer, model=model, device=0
  )

  # Validate that the overrides function
  prompts = ["生成型モデルは", "今日は天気がいいので"]

  # validation of config prior to save or log
  inference_config = {
    "top_k": 2,
    "num_beams": 5,
    "max_length": 30,
    "temperature": 0.62,
    "top_p": 0.85,
    "repetition_penalty": 1.15,
  }

  # Verify that no exceptions are thrown
  sentence_pipeline(prompts, **inference_config)

  mlflow.transformers.log_model(
    transformers_model=sentence_pipeline,
    artifact_path="my_sentence_generator",
    task=task,
    inference_config=inference_config,
    registered_model_name="taka-hugging-face" # モデルレジストリに登録
  )

エポック3にしているのでしばし待ちます。なお、DBR13.0ですとメトリクスの画面がアップデートされています。
Screenshot 2023-05-07 at 10.07.59.png

トラッキングされているモデルの画面でlossも確認できます。
Screenshot 2023-05-07 at 11.18.44.png

最後にモデルレジストリに登録しているので、すぐにREST APIでモデルを呼び出せるようになります。前のバージョンがあるので今回はバージョン3になっています。LLMのチューニングは従来の機械学習より奥深いものですので、これまで以上に試行錯誤が発生すると思います。その際に、MLflowのトラッキングやモデルレジストリを活用することで、パラメーター、メトリクス、モデルのバージョンなどを簡単に管理できるようになります。
Screenshot 2023-05-07 at 11.30.46.png

バッチ推論

デバッグの目的で、トラッキングされたモデルをロードして推論してみます。

Python
loaded_model = mlflow.transformers.load_model('runs:/4dec269b93074f59a75fb7dcb66b0768/my_sentence_generator')
loaded_model("Databricksとは")
Out[6]: [{'generated_text': 'Databricksとは?parkはビッグデータ構造化データや、大規模データを取り扱う際に直面する最大の課題であるの課題に正面から取り組んでいきます。データ処理とデータサイエンスのゴールは同一です: 機械'}]

まだ怪しいですが、ファイチューニングの成果が現れています。ファインチューニングで使用しているデータは、私のQiitaの記事です。

GUI経由モデルの活用

ここまででは、ノートブックからモデルを呼び出せるだけで、広くユーザーの方に活用いただくというところまで到達していません。そこで、サービングエンドポイントの出番です。これを活用することで、モデルをREST API経由で呼び出せるようになります。詳細手順はこちらに。

エンドポイントに上記バージョン3のモデルをデプロイします。これでURL経由でモデルにアクセスできるようになります。
Screenshot 2023-05-07 at 11.38.17.png

GUI、前処理、後処理をstreamlitを使って実装します。これまでに作ったものの使い回しですが。

chatbot_rinna.py
import streamlit as st 
import numpy as np 
from PIL import Image
import base64
import io

import os
import requests
import numpy as np
import pandas as pd

import json

st.header('Databricks Q&A Chatbot on Databrikcs')
st.write('''
- [MLflow 2\.3のご紹介:ネイティブLLMのサポートと新機能による強化 \- Qiita](https://qiita.com/taka_yayoi/items/431fa69430c5c6a5e741)
- [MLflow 2\.3のHugging Faceトランスフォーマーのサポートを試す \- Qiita](https://qiita.com/taka_yayoi/items/ad370a7f57c4eae58800)
''')

def score_model(prompt):
  # 1. パーソナルアクセストークンを設定してください
  # 今回はデモのため平文で記載していますが、実際に使用する際には環境変数経由で取得する様にしてください。
  token = "<パーソナルアクセストークン>"

  # 2. モデルエンドポイントのURLを設定してください
  url = '<URL>'
  headers = {'Authorization': f'Bearer {token}'}
  #st.write(token)
  
  data_json_str = f"""
  {{
  "inputs": [
    [
      "{prompt}"
    ]
  ]
}}
  """

  data_json = json.loads(data_json_str)
   
  response = requests.request(method='POST', headers=headers, url=url, json=data_json)
  if response.status_code != 200:
    raise Exception(f'Request failed with status {response.status_code}, {response.text}')
  return response.json()

prompt = st.text_input("プロンプト")

if prompt != "":
  response = score_model(prompt)
  st.write(response['predictions'])
streamlit run chatbot_rinna.py

動きました!
Screenshot 2023-05-07 at 11.43.41.png

これで一連のフローを実装できたので、モデルのファインチューニングの沼に潜ります。

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0