3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

生成AIAdvent Calendar 2024

Day 8

Gemmaで日本語のファインチューニング

Posted at

Gemma Developer Day

2024年10月3日に開催された Gemma の日本語版モデルの紹介です。

英語と日本語解説です!

How to Fine-tuning Gemma: Best Practices

To illustrate fine-tuning the model for a specific task, let's consider the example of generating a random Japanese title based on a user's instruction such as "Write a title". To make this possible, we curated a manageable dataset that could be manually processed. This approach is feasible because Gemma 2 has prior knowledge of general Japanese language patterns, enabling it to adapt to this specific task effectively.

特定のタスクのためにモデルをfine-tuningすることを説明するために、「タイトルを書いてください」といったユーザーの指示に基づいてランダムな日本語タイトルを生成する例を考えてみよう。これを可能にするために、手作業で処理可能な管理可能なデータセットを作成した。このアプローチが実現可能なのは、Gemma 2が一般的な日本語のパターンに関する予備知識を持ち、この特定のタスクに効果的に適応できるためである。

What is Fine-tuning

In the first place, you have to understand what is fine-tuning. It's a specialized form of transfer learning. It involves taking a pre-trained language model - one that has already been exposed to a vast corpus of text data and learned the general patterns and structures of language - and further training it on a smaller, more specific dataset. This additional training allows the model to adapt and refine its knowledge, making it better suited for a particular task or domain.

Imagine you are a skilled gamer who excels at various genres, from action-adventures to strategy games. Fine-tuning is akin to taking you and having you focus intensely on mastering a specific game, like a complex real-time strategy (RTS) title. You already possess a strong foundation of gaming skills and knowledge, but the dedicated practice and study within the RTS genre sharpens your tactics, understanding of game mechanics, and overall proficiency within that particular realm.

Similarly, pre-trained language models have a broad understanding of language, but fine-tuning helps them specialize. By exposing them to a curated dataset relevant to your desired application, you guide the model to learn the nuances and intricacies specific to that domain. It's like giving the model a crash course in the language of your chosen field, enabling it to perform tasks with greater accuracy and fluency.

そもそもFine-Tuningとは何かを理解する必要がある。これは転移学習の特殊な形態である。学習済みの言語モデル(膨大なテキストデータに触れ、言語の一般的なパターンと構造を学習したモデル)を、より小規模で特定のデータセットでさらに学習させる。この追加トレーニングによって、モデルはその知識を適応させ、洗練させ、特定のタスクやドメインにより適したものにすることができる。

あなたが、アクション・アドベンチャーからストラテジー・ゲームまで、さまざまなジャンルを得意とする熟練ゲーマーだとしよう。ファインチューニングとは、複雑なリアルタイムストラテジー(RTS)のような特定のゲームをマスターすることに集中させるようなものだ。あなたはすでにゲームのスキルや知識の強力な基礎を持っているが、RTSジャンルの中で献身的な練習と研究を行うことで、戦術、ゲームメカニクスの理解、そしてその特定の領域における全体的な熟練度が磨かれる。

同様に、事前に訓練された言語モデルは、言語に関する幅広い理解を持っていますが、微調整を行うことで専門性を高めることができます。希望するアプリケーションに関連するキュレートされたデータセットにモデルをさらすことで、そのドメイン特有のニュアンスや複雑さを学習するようモデルを導くことができる。これは、モデルにあなたの選んだ分野の言語を学ばせるようなもので、より正確で流暢にタスクを実行できるようになる。

Set environemnt variables

import os
from google.colab import userdata, drive

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")

# Mounting gDrive for to store artifacts
#drive.mount("/content/drive")

Install dependencies

!pip install -q -U keras-nlp datasets
!pip install -q -U keras

# Set the backbend before importing Keras
os.environ["KERAS_BACKEND"] = "jax"
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import keras_nlp
import keras

# Run at half precision.
#keras.config.set_floatx("bfloat16")

# Training Configurations
token_limit = 128
num_data_limit = 100
lora_name = "my_lora"
lora_rank = 4
lr_value = 1e-3
train_epoch = 5
model_id = "gemma2_instruct_2b_jpn"

Load Model

Why Fine-tuning?

Before embarking on fine-tuning, it's crucial to evaluate if its benefits align with the specific requirements of your application. Fine-tuning involves meticulous data preparation and extensive training, making it an arduous process. Therefore, it's essential to assess whether the potential gains justify the significant effort required.

fine-tuningに着手する前に、その利点がアプリケーションの特定の要件に合致しているかどうかを評価することが重要です。fine-tuningには、綿密なデータ準備と広範なトレーニングが必要であり、骨の折れるプロセスです。したがって、潜在的な利益が、必要な多大な労力を正当化できるかどうかを評価することが不可欠です。

Try "Prompt Engineering" first.

Would you like to enable Gemma's multilingual capabilities? Please note that Gemma 2 already has some multilingual capabilities. Here's the example output from Gemma 2 2B instruction-tuned model.

Do you wish to adjust the tone or writing style? Gemma 2 might be familiar with the writing style you have in mind. Here's another output from the same model.

Gemmaの多言語機能を有効にしますか? Gemma 2にはすでに多言語機能があります。以下は、Gemma 2 2Bインストラクションチューンドモデルの出力例です。 口調や文体を調整したいですか? Gemma 2は、あなたが考えている文体に慣れているかもしれません。同じモデルからの別の出力です。

import keras
import keras_nlp

import time

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
gemma_lm.summary()

tick_start = 0

def tick():
    global tick_start
    tick_start = time.time()

def tock():
    print(f"TOTAL TIME ELAPSED: {time.time() - tick_start:.2f}s")

def text_gen(prompt):
    tick()
    input = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
    output = gemma_lm.generate(input, max_length=token_limit)
    print("\nGemma output:")
    print(output)
    tock()

# inference before fine-tuning
text_gen("Translate the text below to Japanese.\n\"Hi, how can I get to the Tokyo museum?\"")
text_gen("Speak like a pirate. Teach me why the earth is flat.")
text_gen("Write a title")
text_gen("Write a poem")

Tokenizer designed with multilingual in mind

Another important thing is the tokenizer.

Gemma tokenizer is based on SentencePiece. The size of the vocabulary is predetermined before training. SentencePiece then learns the optimal subword segmentation based on the chosen vocabulary size and the training data. Gemma's large 256k vocabulary allows it to handle diverse text inputs and potentially improve performance on various tasks, e.g. handling multilingual text inputs.

(example text: “Hi, Nice to meet you. The weather is really nice today.”)

もうひとつ重要なのはトークナイザーだ。 GemmaのトークナイザーはSentencePieceをベースにしている。語彙のサイズは学習前にあらかじめ決められている。そしてSentencePieceは、選択された語彙サイズと学習データに基づいて最適なサブワード分割を学習します。Gemmaは256kの大きな語彙を持つため、多様なテキスト入力を扱うことができ、多言語テキスト入力の処理など、様々なタスクのパフォーマンスを向上させる可能性がある。 (例:「Hi, Nice to meet you. The weather is really nice today.」)

tokenizer = keras_nlp.models.GemmaTokenizer.from_preset(model_id)
import jax

def detoken(tokens):
  print(tokens)
  for x in tokens:
    word = tokenizer.detokenize(jax.numpy.array([x]))
    print(f"{x:6} -> {word}")

detoken(tokenizer("こんにちは。初めまして。今日は本当にいい天気ですね。"))

Load Dataset

How many datasets do you need? You can start with relatively small number of datasets, approximately 10 to 20, those can have a significant impact on a model's behavior.

For improved the output quality, a target of around 200 total examples is recommended. Nevertheless, the amount of data required for tuning really depends on how much you want to influence the model's behavior. Our recommendation is to commence with a limited amount of data and gradually incorporate additional data into the training process until the desired behavior is achieved.

いくつのデータセットが必要ですか?モデルの挙動に大きな影響を与える可能性があります。 出力品質を向上させるためには、200例程度を目標にすることをお勧めします。とはいえ、チューニングに必要なデータ量は、モデルの挙動にどの程度影響を与えたいかに大きく左右されます。私たちが推奨するのは、限られたデータ量で開始し、望ましい挙動が達成されるまで、徐々に追加データを学習プロセスに組み込むことです。

tokenizer = keras_nlp.models.GemmaTokenizer.from_preset(model_id)

# example titles
data = [
    "星空の旅人",   # (Hoshizora no Tabibito) - Starry Sky Traveler
    "沈黙の剣",     # (Chinmoku no Tsurugi) - Silent Sword
    "風の歌",       # (Kaze no Uta) - Song of the Wind
    "永遠の炎",     # (Eien no Honoo) - Eternal Flame
    "幻の都市",     # (Maboroshi no Toshi) - Phantom City
    "深海の秘密",   # (Shinkai no Himitsu) - Secret of the Deep Sea
    "運命の出会い", # (Unmei no Deai) - Destined Encounter
    "禁断の果実",   # (Kindan no Kajitsu) - Forbidden Fruit
    "影の狩人",     # (Kage no Karyudo) - Shadow Hunter
    "希望の光",     # (Kibou no Hikari) - Light of Hope
    "蒼い海の伝説", # (Aoi Umi no Densetsu)
    "黒曜石の涙",   # (Kokuyouseki no Namida) - Obsidian Tears
    "砂漠の薔薇",   # (Sabaku no Bara) - Desert Rose
    "黄金の鍵",     # (Ougon no Kagi) - Golden Key
    "嵐の夜に",     # (Arashi no Yoru ni) - On a Stormy Night
]

train = []

for x in data:
  item = f"<start_of_turn>user\nWrite a title<end_of_turn>\n<start_of_turn>model\n{x}<end_of_turn>"
  length = len(tokenizer(item))
  # skip data if the token length is longer than our limit
  if length < token_limit:
    train.append(item)
    if(len(train)>=num_data_limit):
      break

print(len(train))
print(train[0])
print(train[1])
print(train[2])

LoRA Fine-tuning

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=lora_rank)
gemma_lm.summary()

# Limit the input sequence length (to control memory usage).
gemma_lm.preprocessor.sequence_length = token_limit
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=lr_value,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

class CustomCallback(keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    #model_name = f"/content/drive/MyDrive/{lora_name}_{lora_rank}_epoch{epoch+1}.lora.h5"
    #gemma_lm.backbone.save_lora_weights(model_name)

    # Evaluate
    text_gen("Write a title")
    text_gen("Write a poem")

history = gemma_lm.fit(train, epochs=train_epoch, batch_size=2, callbacks=[CustomCallback()])

import matplotlib.pyplot as plt
plt.plot(history.history['loss'])
plt.show()

Try a different sampler

gemma_lm.compile(sampler="top_k")
text_gen("Write a title")
text_gen("Write a title")
text_gen("Write a title")
text_gen("Write a title")
text_gen("Write a title")
3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?