15
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

サイバーエージェントの日本語LLM OpenCALMをDatabricksでファインチューニングしてみる

Posted at

こちらの続きです。

これまでにも他のモデルでファインチューニングしてはいますが、とりあえず動くというレベルでした。今一度ドキュメントを勉強してトライしてみます。

そのために、結構Hugging Faceのドキュメントを読みました。まだまだ奥深いですが。

そして、引き続き先人たちの知恵を借りています。今回特に参考にさせていただいたのはこちら。

こちらの記事を写経させていただきつつ、自分の理解をメモしていきます。使っているのはDatabricksランタイム13.0MLとg5.4xlarge(64GBメモリー、1GPU)のクラスターです。トレーニングデータセットは、私のQiita記事です。

ライブラリのインストール

%pip install transformers==4.25.1
%pip install mlflow==2.3.1
Python
dbutils.library.restartPython()

ライブラリのインポート

Python
import csv, json, re, os, sys, pickle
import numpy as np
from tqdm import tqdm

import torch
import evaluate
from transformers import AutoModelForCausalLM, AutoTokenizer, T5Tokenizer
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead

変数設定

ここでは、ワークスペースファイルの機能を使って、出力ファイルなどをノートブックと同じパスにしています。

あと、ここで気にすべき変数は:

  • 文章の最小長、最大長(min_len, max_len): これがファインチューニングの際の入力を決定する。多くの場合、大きくした方が良いがOut of Memoryになりがち。入力データの形式も見直し必要かもしれない。今は一行がほぼ一記事で長いので。
  • エポック数(epoch): トレーニング回数増やした方が良いのでしょうが、今は動作確認が先なので1エポックにしています。
  • 予測時のパラメータ: これも大事。ファインチューニングしたモデルの挙動を調整するためには更なる調査が必要。
  • use_cpu: 内部ではPyTorchやTF使っているので、CPU/GPUにも注意が必要。
Python
# モデル名
model_name = "cyberagent/open-calm-medium"

# ファイチューニング用データセット
fine_tune_csv_path = "/dbfs/FileStore/shared_uploads/takaaki.yayoi@databricks.com/dolly/taka_qiita_cleansed.csv"

# クレンジング済みファインチューニングデータセットのパス
training_data_path = "/Workspace/Users/takaaki.yayoi@databricks.com/20230517_opencalm/train_data_file.txt"

# 文章の最小長、最大長
min_len = 32
max_len = 100

# エポック数
#epoch = 100
epoch = 1
# パッチサイズ
batch_size = 8

# 途中経過を表示する学習回数の間隔
logging_steps = 200

# モデルを保存する間隔
save_freq = 100000

# 結果の出力先
#output_dir = "./ouput_dir/"
output_dir = "/Workspace/Users/takaaki.yayoi@databricks.com/20230517_opencalm/opencalm_output/"

# 予測時のパラメータ
top_k = 40 # top-k検索の閾値
top_p = 1 # top-pの閾値
num_text = 1 #出力する文の数
temp = 1.0
repeat_ngram_size = 1

# 推論にCPUを使用するか
use_cpu = False

GPU使っているので、以下のコマンドは

Python
print("GPU is available : {}".format(torch.cuda.is_available()))
GPU is available : True

となります。

トークナイザー

モデルに同梱されているものを使用。

Python
tokenizer = AutoTokenizer.from_pretrained(model_name)

モデル

Python
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

データセットのロード

Python
# csvファイルの読み込み
data = []

with open(fine_tune_csv_path, "r") as f:
  reader = csv.reader(f, delimiter = ",")
  for row in reader:
    data.append([row[0]])

dataを確認します。

Out[7]: [['これまではノートブック名による検索しかできませんでしたが、今回のエンハンスでノートブックの中身に対しても検索ができるようになりました。プレビュー本機能はパブリックプレビューです。注意ここで説明する検索機能は、暗号化に顧客管理キーを用いているワークスペースではサポートされていません。これらのワークスペースにおいては、サイドバーのSearchをクリックし、Search Workspaceフィールドに検索文字列をタイプします。タイプするたびに、名前に検索文字列を含むオブジェクトが一覧されます。ワークスペースでオブジェクトを開くには名前をクリックします。サイドメニューの検索をクリックします。検索ダイアログが表示されます。検索キーワードを入力しEnterを押すと、検索結果が表示されます。名称をクリックすることでオープンすることができます。ドロップダウンから種別(Type)を選択することで絞り込みを行うことができます。'],
 ['Databricksとdbt Cloudの連携のステップ3: モデルを作成して実行するまでを実践した内容です。クラスターの準備ここでは、Databricksクラスターに接続します。事前にパーソナルアクセストークンを取得しておきます。サイドメニューの設定 > ユーザー設定 > アクセストークンに移動し、新規トークンの作成をクリックして、トークンの名前、存続期間を指定します。作成をクリックするとトークンが表示されるのでコピーしておきます。注意 パーソナルアクセストークンは厳重に管理してください。第三者に教えたりしないでください。任意のスペックのクラスターを作成します。dbtからこのクラスターにアクセスすることになります。高度なオプションのJDBC/ODBCタブに表示されるサーバーのホスト名、ポート、HTTPパスをメモしておきます。データベースとテーブルの準備...

テキストのクレンジング

要件に応じて処理を追加、変更します。

Python
def text_cleansing(texts):
  # タブ、改行、改ページを削除
  texts = [re.sub("[\u3000 \t \s \n]", "", t) for t in texts]

  return texts
Python
with open(training_data_path, "w") as f:
  for row in tqdm(data):
    ret = text_cleansing(row)
    text = ret[0]

    # no cleansing
    #text = row[0]

    print(text)
    f.write(text + "\n")

トレーニングデータセットの設定

TextDatasetはまもなくdeprecatedなので、Datasetsへの移行が必要です。

Python
# データセットの設定
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=training_data_path,
    block_size=max_len,  # 文章の長さを揃える必要がある
)

Data Collatorの設定

そもそもData Collatorって何?と思いましたが、バッチでデータロードしてくれるそうで。

Data Collator

Data collatorは、入力としてデータセット要素のリストを用いて、バッチ処理を行うオブジェクトです。これらの要素は、train_datasetあるいはeval_datasetの要素として同じタイプのものとなります。

トレーニングの設定

Python
os.makedirs(output_dir,  exist_ok=True)
Python
training_args = TrainingArguments(
    output_dir = output_dir + "opencalm",
    overwrite_output_dir = True,
    num_train_epochs = epoch,
    per_device_train_batch_size = batch_size,
    logging_steps = logging_steps,
    save_steps = save_freq,
)

トレーナー

上の設定やデータ、Data Collator、モデルを参照してトレーナーを作成します。

Python
trainer = Trainer(
    model = model,
    args = training_args,
    data_collator = data_collator,
    train_dataset = train_dataset,
)

トレーニング

Python
trainer.train()

MLflowのオートロギングが動作するので、自動でメトリクスやパラメーターが記録されます。
Screenshot 2023-05-21 at 15.11.38.png

モデルの保存

transformersの機能、あるいはMLflowのモデルロギングで記録できます。自分は最終的にはMLflowに寄せていきます。

transformerの機能で保存

Python
# モデルをCPU/GPUのどちらかに移す
if use_cpu:  # 推論時にCPUの利用を強制する場合の処理
    device = torch.device("cpu")
else:  # 特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model.to(device)

# モデルを保存する
trainer.save_model()

MLflowの機能で保存

ただ、こちらは推論の設定を見直す必要があります。

Python
import mlflow
from transformers import pipeline

with mlflow.start_run():
  trainer.train()

  task = "text-generation"

  # fine tune後のモデルとトークナイザー、GPUを指定
  sentence_pipeline = pipeline(
    task=task, tokenizer=tokenizer, model=model, device=0
  )

  prompts = ["生成型モデルは", "今日は天気がいいので"]

  # 推論の設定
  inference_config = {
    "top_k": top_k,
    "num_beams": 5,
    "max_length": max_len,
    "temperature": temp,
    "top_p": top_p,
    "repetition_penalty": 1.15,
  }

  # 例外が起きないことを確認
  sentence_pipeline(prompts, **inference_config)

  mlflow.transformers.log_model(
    transformers_model=sentence_pipeline,
    artifact_path="my_sentence_generator",
    task=task,
    inference_config=inference_config,
    #registered_model_name="taka-hugging-face" # モデルレジストリに登録
  )

推論

Python
def generate_response(prompt):
    # 文章をtokenizerでエンコード
    x = tokenizer.encode(prompt, return_tensors="pt")

    if use_cpu: #CPUの利用を強制する場合の処理
      device = torch.device('cpu')
    else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
      device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    x = x.to(device)
    # 推論
    with torch.no_grad():
      y = model.generate(
                        x,
                        min_length=min_len,  # 文章の最小長
                        max_length=max_len,  # 文章の最大長
                        do_sample=True,   # 次の単語を確率で選ぶ
                        top_k=top_k, # Top-Kサンプリング
                        top_p=top_p,  # Top-pサンプリング
                        temperature=temp,  # 確率分布の調整
                        no_repeat_ngram_size = repeat_ngram_size, #同じ単語を何回繰り返していいか
                        num_return_sequences=num_text,  # 生成する文章の数
                        pad_token_id=tokenizer.pad_token_id,  # パディングのトークンID
                        bos_token_id=tokenizer.bos_token_id,  # テキスト先頭のトークンID
                        eos_token_id=tokenizer.eos_token_id,  # テキスト終端のトークンID
                        early_stopping=True
                      )
    
    # 特殊トークンをスキップして推論結果を文章にデコード
    res = tokenizer.batch_decode(y, skip_special_tokens=True)
    return res

推論します。

Python
ret  = generate_response(prompt = "<s>MLflowとは")
print(ret)
['<s>MLflowとは?様々な方法で自動でモデルを登録できるようにし、それぞれの特徴量に対して適切な分類を行うためにトレーニングデータを収集します。モデルをモジュール化する機械学習ライブラリです。(本書より)本章ではこれらを学ぶ過程で理解するポイントを見ていきますDeltaLiveTablesはAWSのGCPにあるのでまずはここからトライしてみました!DatabricksはPythonベースで動作するクラウドネイティブパッケージング基盤(CDCの基盤)。そしてGithubで公開されたVersion1.17がインストールされているDocker環境']

後半はあれですが、それっぽく回答が返ってきています!まだ、エポック1ですし。

今後のアクション

  • MLflow連携(ロギング、サービング)
  • エポック数増やした際の精度の変化
  • 推論パラメーターの理解
  • 入力データの見直し
    • サイズと精度のトレードオフ
    • クレンジング
  • deepspeedなどを用いたチューニングの高速化
  • プロンプトエンジニアリングの理解
  • 問題タイプ(Q&Aなど)ごとのチューニング手法の理解

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

15
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?