1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Gemma 2 2B 日本語ファインチューニング & TPUv3-8 + Kaggle Hub公開

Last updated at Posted at 2024-08-05

このノートブックでは、Googleが新たにリリースした軽量ながらも高性能な言語モデル Gemma 2 2B を、日本語データセット databricks-dolly-15k-ja でファインチューニングする方法を紹介します。さらに、KaggleのTPU v3-8を活用することで、効率的な学習を実現します。ファインチューニング後、モデルをKaggle Hubにアップロードする手順までを解説します。

この記事は、大規模言語モデル(LLM)の学習に興味がある初心者の方々を対象としています。 各ステップで丁寧な解説を加え、コードブロックには詳細なコメントを付与することで、スムーズに理解を進められるように工夫しました。

環境設定

まずは必要なライブラリをインストールし、TPUを使用するための環境設定を行います。

# 必要なライブラリのインストール
!pip install -q -U keras-nlp tensorflow-text
!pip install -q -U tensorflow-cpu
!pip install -q datasets kagglehub kaggle_secrets rich

# Keras JAXバックエンドの設定 (TPUを使用するため)
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"  # TPUメモリの使用効率を最大化

# 必要なモジュールのインポート
import jax
import keras
import keras_nlp
from datasets import load_dataset
import pandas as pd
from rich.console import Console
from rich.tree import Tree
from rich.panel import Panel
from rich.text import Text
from rich.markdown import Markdown
from rich.box import ROUNDED

# Rich Consoleのインスタンスを作成
console = Console()

TPUの確認

Kaggleでは、TPUv3-8デバイスが提供されています。以下のコードで、TPUが正しく認識されているか確認しましょう。

# 利用可能なTPUデバイスの一覧を表示
jax.devices()

出力結果にTPUデバイスが表示されれば、TPUを使用する準備が整っています。

モデルのロードと分散設定

Gemma 2 2Bモデルをロードし、TPUでの分散学習のための設定を行います。

# デバイスメッシュの作成 (4つのTPUコアを使用)
device_mesh = keras.distribution.DeviceMesh(
    (1, 4),  # TPUコアの形状 (行, 列)
    ["batch", "model"],  # 分散する次元
    devices=keras.distribution.list_devices()[:4],  # 最初の4つのデバイスを使用
)

# レイアウトマップの設定 (モデルの重みをTPUコアにどのように配置するかを指定)
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (model_dim, None)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)

# モデル並列化の設定
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

# 分散設定の適用
keras.distribution.set_distribution(model_parallel)

# Gemma 2 2Bモデルのロード (事前学習済みの重みを使用)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

解説:

  • デバイスメッシュ: TPUコアをどのように配置するかを定義します。ここでは1行4列のメッシュを作成し、4つのTPUコアを使用します。
  • レイアウトマップ: モデルの各層の重みをどのTPUコアに配置するかを指定します。
  • モデル並列化: モデルを複数のTPUコアに分散して学習するための設定を行います。

モデルの構造確認 (create_layer_tree)

モデルの構造を可視化するために、create_layer_tree 関数を使用します。

# モデルの構造をツリー形式で表示する関数
def create_layer_tree(layer):
    tree = Tree(f"[bold blue]{layer.name}[/bold blue]")
    
    important_attrs = ['_layers', 'transformer_layers', 'layer_norm', '_token_embedding']
    for attr, value in vars(layer).items():
        if attr in important_attrs:
            if isinstance(value, list):
                subtree = tree.add(f"[yellow]{attr}[/yellow]")
                for item in value:
                    subtree.add(f"[green]{item.name}[/green]: {item.__class__.__name__}")
            else:
                tree.add(f"[yellow]{attr}[/yellow]: [green]{value.name}[/green] ({value.__class__.__name__})")
        elif not attr.startswith('_') and not callable(value):
            tree.add(f"[cyan]{attr}[/cyan]: {value}")
    
    return tree

# GemmaDecoderBlockの内部構造を確認
for layer in gemma_lm.layers:
    console.print(Panel(create_layer_tree(layer), title=f"[bold red]{layer.__class__.__name__}[/bold red]", expand=True))

# 埋め込み層の確認
embedding_layer = gemma_lm.get_layer('token_embedding')
console.print(Panel(
    f"[bold magenta]Embedding Layer[/bold magenta]\n"
    f"Name: [green]{embedding_layer.name}[/green]\n"
    f"Shape: [yellow]{embedding_layer.embeddings.shape}[/yellow]",
    title="Embedding Layer Info",
    border_style="blue"
))

データの準備

日本語の指示応答データセット databricks-dolly-15k-ja を使用して、Gemma 2 2Bモデルをファインチューニングします。

# データセットの読み込み
dataset = load_dataset('kunishou/databricks-dolly-15k-ja')
df_databricks = pd.DataFrame(dataset['train'])
df_databricks = df_databricks[["instruction", "output"]]

# ファインチューニング用のデータ形式に変換
data = []
for _, row in df_databricks.iterrows():
    template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
    data.append(template.format(instruction=row['instruction'], response=row['output']))

# データ量を制限 (オプション: リソースが少ない場合はデータ量を減らす)
data = data[:2500]

# データの内容を表示 (最初の3件)
for i, item in enumerate(data[:3], 1):
    console.print(Panel(Text(item), title=f"[bold blue]Data Item {i}[/bold blue]", expand=False))

解説:

  • databricks-dolly-15k-ja: 日本語の指示応答データセット。様々なタスクに対応する指示と応答のペアが含まれています。
  • データ形式の変換: Gemma 2 2Bモデルのファインチューニングに適した形式に変換しています。

ファインチューニング前の推論 (generate_and_display)

ファインチューニングを行う前に、generate_and_display 関数を使用して、モデルの現在の性能を確認しましょう。

# テキスト生成と表示を行う関数
def generate_and_display(prompt, max_length=512):
    generated_text = gemma_lm.generate(prompt, max_length=max_length)
    
    prompt_panel = Panel(
        Text(prompt, style="bold magenta"),
        title="[blue]Input Prompt[/blue]",
        border_style="blue",
        box=ROUNDED,
    )

    generated_md = Markdown(generated_text)
    
    token_count = len(generated_text.split())
    token_info = Text(f"\n\nGenerated {token_count} tokens.", style="italic cyan")

    output_panel = Panel(
        generated_md,
        title="[green]Generated Response[/green]",
        border_style="green",
        box=ROUNDED,
    )

    console.print(prompt_panel)
    console.print()
    console.print(output_panel)
    console.print(token_info)

# 関数を呼び出して結果を表示
generate_and_display("ヴァージン・オーストラリア航空はいつから運航を開始したのですか? ", max_length=512)

LoRA(Low-Rank Adaptation)の設定

LoRA (Low-Rank Adaptation) を使用することで、モデルのパラメータの大部分を凍結し、少数の学習可能なパラメータを追加することで、効率的なファインチューニングを実現します。

# LoRAの有効化 (rank=8: 学習可能なパラメータのランク)
gemma_lm.backbone.enable_lora(rank=8)

# モデルのコンパイル
gemma_lm.preprocessor.sequence_length = 512  # 入力シーケンス長を制限
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# モデルのサマリを表示
gemma_lm.summary()

解説:

  • LoRA: モデルの大部分を凍結し、少数の学習可能なパラメータを追加することで、効率的なファインチューニングを可能にする手法。
  • rank: LoRAで追加する学習可能なパラメータのランク。

ファインチューニングの実行

準備が整ったので、Gemma 2 2Bモデルのファインチューニングを実行します。

# ファインチューニングの実行
gemma_lm.fit(data, epochs=1, batch_size=4)  # epochs: 学習回数, batch_size: バッチサイズ

解説:

  • epochs: 学習を繰り返す回数。
  • batch_size: 1回の学習で使用するデータの数。

ファインチューニング後の推論

ファインチューニング後のモデルを使って、実際にテキストを生成してみましょう。

# ファインチューニング後の推論
generate_and_display("Instruction:\n日本の首都はどこですか?\n\nResponse:\n", max_length=512)

モデルの保存とKaggle Hubへのアップロード

ファインチューニングしたモデルを保存し、Kaggle Hubにアップロードします。

# モデルの保存先ディレクトリ
FINETUNED_MODEL_DIR = f"./gemma2_2_demo"

# モデル情報
MODEL_BASE = "gemma2_2b_demo"
MODEL_NAME = f"{MODEL_BASE}_train_finetuning_h5"
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/{MODEL_NAME}.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"
FRAMEWORK = "jax"
VER = 1

# ディレクトリ作成
os.makedirs(FINETUNED_MODEL_DIR, exist_ok=True)

# モデルの重みとトークナイザーのアセットを保存
gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)
gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_MODEL_DIR)

# Kaggle Secretsから認証情報を取得
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
KAGGLE_USERNAME = user_secrets.get_secret("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
os.environ["KAGGLE_USERNAME"] = KAGGLE_USERNAME

# Kaggle Hubへのアップロード
import kagglehub
handle = f'{KAGGLE_USERNAME}/{MODEL_BASE}/{FRAMEWORK}/{MODEL_NAME}'
kagglehub.model_upload(handle, FINETUNED_WEIGHTS_PATH, license_name='Apache 2.0', version_notes=f'v{VER}')

解説:

  • モデルの保存: ファインチューニングしたモデルの重みとトークナイザーのアセットを保存します。
  • Kaggle Hubへのアップロード: kagglehub ライブラリを使用して、保存したモデルをKaggle Hubにアップロードします。

これで、Gemma 2 2Bモデルの日本語データセットでのファインチューニングとKaggle Hubへのアップロードが完了しました。お疲れ様でした!

📒ノートブック

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?