LoginSignup
1
1

More than 3 years have passed since last update.

HuggingFace Transformers で BERT (風モデル)を pretrain する(ソースコード解説)

Posted at

概要

前回の続き

Github に pretrain のソースコードを上げた。
https://github.com/elm200/bert_pretrainer/blob/master/pretrain.py

蛇足かもしれないが、このソースコードに解説を加えていく。
HuggingFace Transformers のおかげで、比較的簡単に実装できる。
Tokenizer としては、東北大日本語BERTモデルのものを使用させていただいている。

メイン部分(抄)

Next Sentence Prediction (NSP) をするかどうかの選択

BertForPreTrainingWithoutNSP の実体は、transformers からコピーした modeling_bert.py の BertForPreTraining クラスであり、この中の forward メソッドの一部を書き換えている。具体的には、NSPの部分を loss から外している。
実験した結果 NSP が悪影響を及ぼしていることがあったので、こういうオプションを追加した。

メインループ

create_sent_pairs() で BATCH_SIZE 文の文ペアのリストを作り、encode_sent_pairs() で token id にエンコード。あとは、普通に loss を逆伝播しているだけ。

    if USE_NSP:
        model = BertForPreTraining(config)
    else:
        model = BertForPreTrainingWithoutNSP(config)
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=2e-5)
    model.train()

    for i in range(1, MAX_STEPS + 1):
        optimizer.zero_grad()
        sent_pairs = create_sent_pairs(sents_list, batch_size=BATCH_SIZE)
        encoded = encode_sent_pairs(sent_pairs)
        res = model(
            encoded["input_ids"].to(device),
            token_type_ids=None,
            attention_mask=encoded["attention_mask"].to(device),
            labels=encoded["labels"].to(device),
            next_sentence_label=encoded["next_sentence_label"].to(device),
        )
        loss = res.loss       
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

create_sent_pairs()

sents_listの各要素には日本語 Wikipedia の1つの項目の各文がJSON配列の形で入っている(このデータの作り方は後日説明予定)。seq_pair_ratioの確率で、出力される2文は連続したものとなり、そうでないときは、全く無関係の2文が選ばれる。これは Next Sentence Prediction タスクのための工夫である。

def create_sent_pairs(
    sents_list: list[str], batch_size: int, seq_pair_ratio: float = 0.5
) -> list[tuple[str, str, int]]:
    n_seq_pair = int(batch_size * seq_pair_ratio)
    n_random_pair = batch_size - n_seq_pair
    pairs = []

    for i in range(n_seq_pair):
        while True:
            sents = json.loads(random.choice(sents_list))
            if len(sents) >= 2:
                break
        st = random.randint(0, len(sents) - 2) if len(sents) > 2 else 0
        pair = tuple(sents[st : st + 2]) + (0,)
        pairs.append(pair)

    for i in range(n_random_pair):
        ss1, ss2 = random.sample(sents_list, 2)
        sents1 = json.loads(ss1)
        sents2 = json.loads(ss2)
        pair = (random.choice(sents1), random.choice(sents2), 1)
        pairs.append(pair)

    pairs = random.sample(pairs, len(pairs))
    return pairs

mask_token_ids_and_labels()

Masked Language Model(MLM) タスクのための教師データを自動的に作るメソッド。
トークン1つごとに確率 mask_rate で入力がマスク化され、さらにそのマスク化されたトークンが

  • 80%の確率で[MASK]に置換
  • 10%の確率でランダムに他のトークンに置換
  • 10%の確率でそのまま

となる。 MLM では、マスク化されたトークンに関してのみ loss を集計するようにする。

def mask_token_ids_and_labels(
    token_ids: list[int], mask_rate: float = 0.15
) -> tuple[list[int], list[int]]:
    indexes = []
    for i in range(len(token_ids)):
        r = random.random()
        if r < mask_rate:
            indexes.append(i)
    labels = [IGNORE_ID] * len(token_ids)
    if len(indexes) == 0:
        return token_ids, labels
    labels = np.array(labels)
    indexes = np.array(indexes)
    token_ids = np.array(token_ids)
    masked_token_ids = np.array(token_ids)
    labels[indexes] = token_ids[indexes]
    masked_token_ids[indexes] = MASK_ID
    for i in indexes:
        r = random.random()
        if r < 0.1:
            masked_token_ids[i] = token_ids[i]
        elif r < 0.2:
            masked_token_ids[i] = random.choice(token_ids)
    return masked_token_ids.tolist(), labels.tolist()

次回予告

Wikipedia のダンプからどうやって pretrain 用のデータを作るかということを解説したい。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1