はじめに
Transformerベースの言語モデルを開発,学習させる業務を行っていると,様々なモデル学習のプロセスパイプラインが考えられます。例えば,
- 大規模なモデルの事前学習 (Pre-Training: PT)
- 事前学習したモデルのドメイン特化のための継続事前学習 (Conteinued Pre-training: CP)
- 事前学習済みのモデルのタスク特化のためのファインチューニング (Fine-Tuning: FT)
- 事前学習済みのモデルのタスク特化のためのプロンプトエンジニアリング
近年LLM,生成モデルが大流行している大LLM時代では,2.や3.を行うのではなく,4.のみを実行して,中〜大規模なリソース消費を行わずに高度な特化モデルの運用や開発が可能になりました。(タスクによるとは思います)
しかしながらPre-LLM時代では,2. -> 3.や3.のみを行うプロセスパイプラインが主流であり,現在でも特定のタスクを解くためのLMのファインチューニングの需要は依然としてあるかと思われます。
そんな中,2024年12月に,Pre-LLM時代を切り開いたとも言えるBERTの現代版,ModernBERTが登場しました。ModernBERTの解説記事は世の中にたくさんありますので,この記事では割愛します。
今回はTransformersのModernBERTの実装において,独自にモデルの定義をして3.のFTを行う際に,困ったことがあったので記事にしてみました。
やりたいこと
ModernBERTが発表された後,SB Intuitionsさんが日本語に特化したModernBERTを開発してくださったので,ModernBERTで日本語のタスクを色々と試せるようになりました。
今回実施しようとしたタスクは,マルチタスクのTokenPredicitonです。
概要
モデルは同時に二つのタスクに対して予測を行います。
以下のような文章
text = "日本で最も高い山は何ですか?"
があったときに,これをトークナイズして,
tokens = ['日本で', '最も高い', '山', 'は何ですか', '?']
それぞれのトークンにラベル付をします。
モデルに入力する際はToken idに変換されます。
今回はマルチタスクなので,二つのラベルクラスを持ちます。
task_1 = [0, 0, 1, 0, 2]
task_2 = [0, 5, 3, 0, 0]
同一トークンにおける異なるタスクで,それぞれにラベルを付与することが許可されていて,シングルタスクではできないような複雑なタスクもこれで解ける可能性があります。
モデルはそれぞれのCrossEntropyLossを計算して,最終的に損失を合算して誤差とし,それを逆伝播させて学習します。
実現方法
冒頭に,今回使ったライブラリのバージョンとPythonバージョンの共有です。
Python==3.12.8
torch==2.7.0+cu126
transformers==4.57.3
そして,ライブラリのimportを最初にしておきます。
import torch
from torch import nn
from transformers import AutoConfig, ModernBertPreTrainedModel
from transformers.models.modernbert.modeling_modernbert import ModernBertPredictionHead
このような実装をしようと思ったら,Tramnsformersで提供されているModenBertTokenClassificationからfrom_pretrainedするだけでは不可能です。
したがって,ModernBertPreTrainedModelを継承して自分でマルチタスク分類用のクラスを書く必要があります。
試しに以下のような実装をします。
class ModernBERTForMultiTokenClassification(ModernBertPreTrainedModel):
def __init__(self, config, num_labels: list[int, int], num_tasks: int):
super().__init__(config)
self.num_labels = num_labels
self.num_tasks = num_tasks
self.model = ModernBertModel(config)
self.head = ModernBertPredictionHead(config)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier_list = nn.ModuleList(
[nn.Linear(config.hidden_size, self.num_labels[task]) for task in range(self.num_tasks)]
)
# Initialize weights and apply final processing
self.post_init()
ModernBertPreTrainedModelを継承してModernBERTForMultiTokenClassificationという新しいクラスを作成します。初期化メソッド内では以下の変更をしています。
- num_labelsで各タスクの分類ラベル数のリストを受け取る
- num_tasksでタスク数を受け取る
- タスクの数ぶんの出力層を,ModuleListの中にLinearを入れて,最後に結合する
このメソッドの使い方はこんな感じになります。
model_name_or_path = "/path/to/your_model/"
config = AutoConfig.from_pretrained(model_name_or_path)
model = ModernBERTForMultiTokenClassification.from_pretrained(
model_name_or_path,
config=config,
num_labels=[3, 5],
num_tasks=2,
from_tf=False,
ignore_mismatched_sizes=False,
)
num_labelsは,[タスク1の分類数, タスク2の分類数, ...]とに入れましょう。
困ったこと
ところが,このモデルを学習させると損失値にいきなりnanが発生しました。何が起こった?
調査
色々調査した結果,classifier_listの出力層部分の重みが変な値に初期化されていました。
initialize.py実行後,以下のコードを実行して重みの値を見てみます。
param_list = []
for name, param in model.classifier_list.named_parameters():
print(name)
print(param)
param_list.append(param)
0.weight
Parameter containing:
tensor([[ 2.2959e-39, nan, nan, ..., 4.5792e-41,
-2.3621e-29, 4.5792e-41],
[-2.3328e-29, 4.5792e-41, -2.3620e-29, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
1.2981e-38, 0.0000e+00]], requires_grad=True)
0.bias
Parameter containing:
tensor([5.3810e-43, 0.0000e+00, 1.4013e-45], requires_grad=True)
1.weight
Parameter containing:
tensor([[3.4580e-34, 4.5796e-41, 5.0301e+00, ..., 0.0000e+00, 1.2981e-38,
0.0000e+00],
[1.4013e-45, 0.0000e+00, 9.8347e-07, ..., 0.0000e+00, 1.5498e-42,
2.2316e-41],
[0.0000e+00, 0.0000e+00, 1.6115e-43, ..., 0.0000e+00, 1.5512e+01,
2.2316e-41],
[0.0000e+00, 0.0000e+00, 9.8091e-45, ..., 0.0000e+00, 2.4819e+02,
2.2316e-41],
[0.0000e+00, 0.0000e+00, 9.8091e-45, ..., 0.0000e+00, 3.9710e+03,
2.2316e-41]], requires_grad=True)
1.bias
Parameter containing:
tensor([5.3810e-43, 0.0000e+00, 1.4013e-45, 0.0000e+00, 0.0000e+00],
requires_grad=True)
何やら重みとバイアスがnanや0ばっかりになっています...どうやらここが原因のようです。
基本的にModernBERTForMultiTokenClassification.pyの__init()__の最後にself.post_init()にて全てのモデル内の入れ子になっているモジュールを含め,ランダムに初期化されるはずが,されていません。
どうやらテンソルが初期化されているだけ(メモリを確保しただけ)です。つまりtorch.empty(size=(x,y))みたいなことしかされていないようです。
原因
ModernBertPreTrainedModelの実装を見てみます。
class ModernBertPreTrainedModel(PreTrainedModel):
config: ModernBertConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = False
@torch.no_grad()
def _init_weights(self, module: nn.Module):
cutoff_factor = self.config.initializer_cutoff_factor
if cutoff_factor is None:
cutoff_factor = 3
def init_weight(module: nn.Module, std: float):
init.trunc_normal_(
module.weight,
mean=0.0,
std=std,
a=-cutoff_factor * std,
b=cutoff_factor * std,
)
if isinstance(module, nn.Linear):
if module.bias is not None:
init.zeros_(module.bias)
stds = {
"in": self.config.initializer_range,
"out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
"embedding": self.config.initializer_range,
"final_out": self.config.hidden_size**-0.5,
}
if isinstance(module, ModernBertEmbeddings):
init_weight(module.tok_embeddings, stds["embedding"])
elif isinstance(module, ModernBertMLP):
init_weight(module.Wi, stds["in"])
init_weight(module.Wo, stds["out"])
elif isinstance(module, ModernBertAttention):
init_weight(module.Wqkv, stds["in"])
init_weight(module.Wo, stds["out"])
elif isinstance(module, ModernBertPredictionHead):
init_weight(module.dense, stds["out"])
elif isinstance(module, ModernBertForMaskedLM):
init_weight(module.decoder, stds["out"])
elif isinstance(
module,
(
ModernBertForSequenceClassification,
ModernBertForMultipleChoice,
ModernBertForTokenClassification,
ModernBertForQuestionAnswering,
),
):
init_weight(module.classifier, stds["final_out"])
elif isinstance(module, nn.LayerNorm):
init.ones_(module.weight)
if module.bias is not None:
init.zeros_(module.bias)
elif isinstance(module, ModernBertRotaryEmbedding):
for layer_type in module.layer_types:
rope_init_fn = module.compute_default_rope_parameters
if module.rope_type[layer_type] != "default":
rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
この_init_weightsメソッドを辿っていくと最終的にpost_init()で呼ばれるため,ここで初期化の定義をしていると考えてよいです。
あれ,if isinstance(module, ModernBertEmbeddings):からのif文を見ると,この条件に合致しないモジュールのtypeだと何もされずに終わるのか...
念のため,出力層の前にModernBERTの中間層との間に挟んでいるPredictionHeadが初期化されているか確認しましょう。
param_list = []
for name, param in model.head.named_parameters():
print(name)
print(param)
param_list.append(param)
すると,ちゃんと初期化されていました。elif isinstance(module, ModernBertPredictionHead):ここの条件式に引っかかったのでinit_weightで初期化されていますね。
dense.weight
Parameter containing:
tensor([[ 0.0347, 0.0053, -0.0175, ..., -0.0127, 0.0206, 0.0148],
[ 0.0157, 0.0217, -0.0239, ..., 0.0099, 0.0064, 0.0086],
[ 0.0256, 0.0193, 0.0525, ..., 0.0054, 0.0154, 0.0354],
...,
[ 0.0228, 0.0295, -0.0080, ..., 0.0228, 0.0050, -0.0064],
[-0.0021, -0.0117, 0.0015, ..., -0.0031, 0.2021, -0.0140],
[ 0.0014, 0.0139, -0.0011, ..., -0.0132, -0.0137, 0.1523]],
requires_grad=True)
norm.weight
Parameter containing:
tensor([ 8.7500, 9.9375, 11.7500, 13.3750, 9.8750, 13.1875, 13.7500,
8.3750, 9.1875, 11.5000, 14.6875, 11.5625, 12.5625, 13.7500,
9.2500, 10.6250, 13.3125, 11.8125, 15.4375, 11.5625, 8.8750,
15.3125, 10.8125, 10.5000, 14.5625, 15.5625, 15.0625, 14.6875,
13.0625, 14.2500, 10.1250, 14.1250, 11.3125, 14.5000, 11.9375,
10.8750, 12.5000, 12.8125, 9.2500, 12.8125, 11.8750, 10.9375,
11.3750, 10.6250, 10.8750, 9.7500, 15.0625, 13.2500, 9.1250,
9.2500, 14.0000, 9.8750, 14.6875, 13.4375, 13.2500, 11.7500,
14.5625, 15.2500, 11.8750, 16.1250, 11.0000, 11.4375, 16.1250,
10.1875, 9.3125, 12.8125, 8.6875, 13.5625, 9.2500, 11.5625,
9.8125, 11.2500, 14.7500, 9.4375, 8.6875, 12.0000, 14.5625,
8.4375, 12.7500, 11.1250, 14.1875, 8.3750, 11.1875, 9.5000,
10.1875, 16.1250, 9.5000, 16.0000, 4.9062, 17.2500, 9.1875,
13.5000, 14.0000, 12.8125, 18.6250, 11.4375, 13.3125, 14.0000,
9.7500, 11.6250, 8.5000, 10.1250, 14.5000, 10.1875, 16.8750,
9.0625, 12.3750, 14.5000, -38.0000, 16.0000, 14.2500, 8.8125,
12.6250, 13.0625, 14.0000, 9.7500, 13.0625, 8.1875, 9.7500,
10.3750, 10.3750, 9.2500, 12.4375, 13.1875, 11.1250, 9.5000,
14.0625, 13.3750, 10.1875, 9.4375, 13.0000, 10.1875, 11.7500,
10.8750, 13.5625, 15.0000, 15.5000, 12.1875, 9.5625, 8.6875,
9.3750, 15.8750, 8.8750, 13.8750, 14.8750, 9.3750, 8.5625,
12.4375, 13.0000, 14.3750, 11.1875, 14.3125, 11.7500, 11.1875,
13.4375, 15.4375, 10.0000, 9.1875, 12.1875, 13.1875, 9.4375,
11.9375, 12.6875, 15.2500, 14.9375, 11.3750, 9.3750, 14.2500,
9.9375, 14.2500, 9.0000, 9.6250, 7.2188, 9.0625, 12.1250,
13.0625, 13.1875, 13.8750, 11.3125, 13.2500, 10.0000, 10.3125,
10.3750, 9.3750, 12.8125, 8.2500, 11.3125, 10.6875, 10.7500,
13.9375, 10.4375, 10.4375, 12.9375, 14.9375, 12.2500, 11.9375,
13.1875, 10.5625, 12.1875, 14.0625, 12.8125, 17.8750, 8.8750,
14.3125, 13.2500, 11.2500, 13.8125, 13.8750, 10.6250, 8.2500,
17.1250, 11.0625, 9.1875, 10.3125, 12.4375, 7.8125, 8.7500,
10.5000, 13.5625, 16.3750, 9.3125, 13.4375, 15.8125, 10.2500,
11.5625, 14.7500, 13.3750, 14.8125, 12.5625, 12.7500, 13.5000,
9.1875, 11.0625, 9.1875, 14.1875, 9.1875, 10.2500, 14.9375,
13.2500, 12.1875, 14.7500, 14.6250, 9.0625, 14.9375, 13.1250,
8.9375, 11.1250, 9.9375, 9.1250, 14.1250, 13.9375, 16.1250,
9.1875, 9.6875, 10.6875, 12.4375, 32.7500, 13.6250, 10.1250,
12.1250, 14.5625, 12.3750, 11.1250, 12.8125, 9.2500, 15.5625,
10.5625, 12.1250, 10.0000, 13.5625, 10.3125, 14.6250, 12.7500,
14.5000, 10.4375, 13.6875, 11.0000, 14.3125, 16.5000, 12.8125,
16.7500, 14.3750, 10.5000, 14.0000, 14.2500, 11.6250, 10.3750,
18.3750, 8.4375, 10.5625, 8.6250, 11.8125, 8.9375, 11.1875,
11.5625, 14.3750, 7.7500, 12.3750, 13.0000, 10.9375, 12.2500,
13.8750, 11.7500, 9.5000, 9.9375, 10.1875, 11.5625, 15.5000,
10.6875, 8.0625, 13.4375, 11.7500, 10.4375, 12.5000, 12.9375,
9.8750, 11.3750, 9.6875, 10.1875, 16.1250, 12.8750, 14.6875,
11.7500, 11.0625, 13.9375, 13.1250, 8.8750, 12.6250, 14.0625,
11.8125, 11.7500, 10.6250, 15.2500, 11.1250, 9.3125, 12.0625,
14.8125, 10.2500, 12.0625, 10.1875, 8.9375, 7.9375, 10.4375,
9.8750, 16.8750, 10.5625, 14.4375, 11.0625, 17.8750, 12.7500,
13.5000, 12.4375, 14.1875, 13.2500, 10.8750, 11.4375, 8.5000,
10.4375, 9.6875, 8.5625, 12.3125, 9.0625, 9.4375, 14.1250,
8.3125, 12.1875, 14.0625, 12.3125, 9.1875, 13.9375, 9.8750,
12.8125, 13.8125, 13.0000, 9.6875, 12.5625, 15.0625, 8.7500,
13.6250, 8.6250, 11.9375, 11.6250, 14.1875, 13.4375],
requires_grad=True)
一方でマルチタスクトークン分類のために勝手に足した出力層のself.classifier_listは,ModuleList型なのでどこにも引っかからずに,メモリ領域を確保されただけで終わっているようでした。その中に入っている線形層もnn.Linearなので,入れ子のモジュールを走査していても,このif文には引っかかりません。正規化層nn.LayerNormはあるのに...?
対策
モデルを呼び出した後に,以下で明示的に初期化してやります。
model = ModernBERTForMultiTokenClassification.from_pretrained(
model_name_or_path,
config=config,
num_labels=[3, 5],
num_tasks=2,
from_tf=False,
ignore_mismatched_sizes=False,
)
for linear in model.classifier_list:
if hasattr(linear, 'weight'):
nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
if hasattr(linear, 'bias') and linear.bias is not None:
nn.init.zeros_(linear.bias)
方法はKamigでもXaviorのどちらでもよいと思いますが,初期化の漏れがわかっているmodel.classifier_listに対してだけKamingの初期化で,正規分布で明示的に行います。
このコードを挟んで学習ループを開始したところ,正常にLossが下がることが確認できました!
余談
ちなみに発見した限りTransformerのModernBERTの特有の実装でした。BERTやRoBERTaを確認したところ以下のようになっていました。
例えばBERTなら,
class BertPreTrainedModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": BertLayer,
"attentions": BertSelfAttention,
"cross_attentions": BertCrossAttention,
}
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
super()._init_weights(module)
if isinstance(module, BertLMPredictionHead):
init.zeros_(module.bias)
elif isinstance(module, BertEmbeddings):
init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
init.zeros_(module.token_type_ids)
という実装でして,super()._init_weights(module)のようにPreTrainedModelの_init_weightsをメソッド継承しています。なのでBERTで同じような実装を行うと今回の事象は発生しません。
まとめ
- ModernBERTでファインチューニングをする際,Transformerが提供する分類クラスを使わずに自前で実装する際は,新しく足した重みの初期化がされないことがある
- 特に,
ModernBertPreTrainedModelを継承する場合,_init_weightsメソッドの中で定義されていないモジュールは初期化プロセスをすっ飛ばされるので,自分で初期化が必要になる