1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Databricksで日本語GPT-2モデルをファインチューニングして文章生成をやってみる

Last updated at Posted at 2023-05-02

こちらの続きです。単に試行錯誤しながら勉強中な訳で。

今度はこちらの記事を参考に。

ライブラリのインストール

最新のMLflowをインストールしているのは今後の布石です。

%pip install transformers==4.20.1
%pip install sentencepiece
%pip install mlflow==2.3.1

トークナイザーの設定

Databricks固有の単語が<unk>になってしまうのでこちらを参考に。

Tokenizer — transformers 2.11.0 documentation

Python
from transformers import T5Tokenizer,AutoModelForCausalLM

tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small")
DATABRICKS_TOKENS = ["MLflow", "Databricks", "Delta Lake", "Spark"]
num_added_toks = tokenizer.add_tokens(DATABRICKS_TOKENS)

model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-small")

print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.

MLflowの設定

以前と同じようにトラッキングサーバーの設定をします。

Python
import os

os.environ['DATABRICKS_TOKEN'] = dbutils.notebook.entry_point.\
getDbutils().notebook().getContext().apiToken().get()
os.environ['DATABRICKS_HOST'] = "https://" + spark.conf.get("spark.databricks.workspaceUrl")
os.environ['MLFLOW_EXPERIMENT_NAME'] = "/Users/takaaki.yayoi@databricks.com/20230422_rinna/japanese-gpt2-small"
os.environ['MLFLOW_FLATTEN_PARAMS'] = "true"

トレーナーの設定

これは元記事準拠で。

Python
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead

# データセットの設定
train_dataset = TextDataset(
    tokenizer = tokenizer,
    file_path = train_data_path,
    block_size = 128 # 文章の長さを揃える必要がある
)

# データの入力に関する設定
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm= False
)

# 訓練に関する設定
training_args = TrainingArguments(
    output_dir="/dbfs/tmp/takaaki.yayoi@databricks.com/rinna/output_20230502/",  # 関連ファイルを保存するパス
    overwrite_output_dir=True,  # ファイルを上書きするかどうか
    num_train_epochs=3,  # エポック数
    per_device_train_batch_size=8,  # バッチサイズ
    logging_steps=100,  # 途中経過を表示する間隔
    save_steps=800  # モデルを保存する間隔
)

# トレーナーの設定
trainer = Trainer(
    model =model,
    args=training_args,
    data_collator = data_collator,
    train_dataset = train_dataset
)

ファインチューニング

Python
%%time
trainer.train()

文書生成

Python
def getarate_sentences(seed_sentence):
    x = tokenizer.encode(seed_sentence, return_tensors="pt", 
    add_special_tokens=False)  # 入力
    x = x.cuda()  # GPU対応
    y = model.generate(x, #入力
                       min_length=50,  # 文章の最小長
                       max_length=100,  # 文章の最大長
                       do_sample=True,   # 次の単語を確率で選ぶ
                       top_k=50, # Top-Kサンプリング
                       top_p=0.95,  # Top-pサンプリング
                       temperature=1.2,  # 確率分布の調整
                       num_return_sequences=3,  # 生成する文章の数
                       pad_token_id=tokenizer.pad_token_id,  # パディングのトークンID
                       bos_token_id=tokenizer.bos_token_id,  # テキスト先頭のトークンID
                       eos_token_id=tokenizer.eos_token_id,  # テキスト終端のトークンID
                       bad_word_ids=[[tokenizer.unk_token_id]]  # 生成が許可されないトークンID
                       )  
    generated_sentences = tokenizer.batch_decode(y, skip_special_tokens=True)  # 特殊トークンをスキップして文章に変換
    return generated_sentences
Python
seed_sentence = "Databricksとは"  
generated_sentences = getarate_sentences(seed_sentence)  # 生成された文章
for sentence in generated_sentences:
    print(sentence)

<unk>は減りましたがまだまだです。

Databricks とは何か?どのように使うのか?とatabricksジョブはpark を用いて定義されています。クラスタポリシーによるpark は、Cやpark のプロシージャのようなpache parkデータフレームを使用する際にはサポートされるものとします。さらに、クラスターポリシーで明示的に許可された利用量のみを保持することは推奨しません。推奨すべきは、許可されていない使用量の監査可能性を持つク
Databricks とは、例えばythonスクリプトを記述するために、以下のコードを使用することはできません。 make_commit() // eature and above to belowing the username into a  のサンプルatabricks eposのクエリーノートブックで、yorchでythonと連携を行うConmlライブラリを作成します。以下のコードでは、rocessing
Databricks とは、すべてのデータをスキーマ名の一部あるいは全部を置き換えて、これらのテーブルを格納すべきかどうかをコントロールすることができます。このモードは、データに互換性のあるいくつかのオプションを提供しており、例えば、スキーマやテーブルの名前スキーマのようなオプションで、これらのテーブルを切り替える柔軟性を制御できるということを実装しています。この変更によって、arquet、sdf、pandas、avu、scalaを含む、多くのデータタイプのデフォルトのar

上のskip_special_tokensFalseにすると一目瞭然です。そもそも、文もイマイチ。

Databricks とは、<unk> がファイルに対して直接データを読み込むことができるのかを決定するコマンド<unk> atabricksにおける、<unk> atabricksにおける構造化ストリーミングの構築手順を以下に示します。本書では、どのように<unk> parkストリーミングコマンド、データフレームを用いてストリームを<unk> eltaキャッシュするために<unk> eltaキャッシュコマンドを使用するのかを説明します。これらのオプションを用いることで、<unk> parkストリーミングファイルを読み込むための数多くのオプションを追加することができます:<unk> artner Connectによってサポートが
Databricks とはどういうものですか?<unk> elta <unk> akeの<unk> elta <unk> akeでは、データが外部データに流れ込まないように自動的に処理を行うようにします。そして、これらのトランザクションを取り扱えるようにデータを格納する<unk> elta <unk> akeに対して、スケーラブルかつ柔軟性を最大化しませんか?<unk> atabricksを利用されている方々の多くが、<unk> elta <unk> akeとデータ共有の機能に対して疑問に思ったことは一度はある
Databricks とは)はデフォルト値です。この値を大きくすることで、より細かい情報を表現することを可能にしました。さらに言えば、より細かい情報を伝えるために<unk> を用いることで、このクラスター設定におけるユーザービリティとインスタンス数のキー設定を最小化しました。注意<unk> atabricksランタイム8.0以降でクラスターのContified<unk> utputを無効化しました。新規のCatalystと<unk> atabricksジョブを有効化する際には、C

ToDo

  • MLflowインテグレーションの導入(先にこれやらないと効率悪いです)
  • Tokenizerの理解
  • 精度改善

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?