はじめに
Wiki-40Bデータセットを使ってゼロから日本語BERT事前学習モデルを構築してみました。ライブラリはHugging FaceのTransformersを使用しています。備忘録として、残しておきます。
以下が手順になります。
- 環境構築
- データセットのダウンロード
- データセットの前処理
- トークナイザーの学習
- BERTの事前学習
BERTとは?
BERTはBidirectional Encoder Representations from Transformersの略で、2018年にGoogleによって開発された自然言語処理の事前学習モデルです。BERTは、テキストの事前学習を通じて、さまざまな自然言語処理タスクでSOTAを達成しました。
BERTは、Transformerと呼ばれるニューラルネットワークのアーキテクチャを使用しており、双方向エンコーダーと呼ばれる複数のレイヤーで構成されています。BERTは、大量のテキストデータを事前学習することにより、文脈を考慮した自然言語処理タスクに優れた性能を発揮することができます。BERTは、質問応答、自然言語生成、文書分類などのタスクで広く使用されています。
1. 環境構築
今回はAnacondaを使用して、仮想環境を構築しました。検証した時点でのtransformersや関連するライブラリのバージョンは以下の通りです。
pip list | grep transformers
# transformers 4.27.1
pip list | grep tokenizers
# tokenizers 0.13.2
pip list | grep torch
# torch 2.0.0+cu118
pip list | grep tensorflow
# tensorflow 2.11.1
# tensorflow-datasets 4.8.3
実験環境については、以下を使用しています。
Ubuntu 20.04
Specification
CPU: Intel Core i9-13900K (24 cores, 32 threads, 3.0 / 2.8GHz, Passmark 59763)
GPU: GeForce RTX-3090 Ti
2. データセットのダウンロード
データセットは、TensorFlowDatasetsのWiki-40Bを使用しました。Wiki-40Bは、40言語以上のWikipediaを前処理して作られたデータセットで、以下のように、訓練/検証/テスト用データセットに分かれています。
・訓練データ:74万5392件
・検証データ:4万1576件
・テストデータ:4万1268件
・合計82万8236件(2.19GB)
データはtrain(90%)/validation(5%)/test(5%)に3分割されています。
# データセットの取得
import tensorflow_datasets as tfds
ds_train = tfds.load('wiki40b/ja', split='train')
# データセットをテキスト形式で出力する関数
def create_txt(file_name, tf_data):
start_paragraph = False
# ファイルの書き込み
with open(file_name, 'w') as f:
for wiki in tf_data.as_numpy_iterator():
for text in wiki['text'].decode().split('\n'):
if start_paragraph:
text = text.replace('_NEWLINE_', '') # _NEWLINE_は削除
f.write(text + '\n')
start_paragraph = False
if text == '_START_PARAGRAPH_': # _START_PARAGRAPH_のみ取得
start_paragraph = True
# データセットをテキスト形式で出力
create_txt('data/wiki_40b_train.txt', ds_train)
3. データセットの前処理
./dataに移動して、wiki_40b_train.txtがあるか確認し、前処理をするためにpreprocess.shを用意します。preprocess.shでは、以下の処理をしています。
- 行末の空白は除去、空白のみの行は削除
- "。” の後が"」"、")“、")”,“]"だった場合、"。"の後で改行
- "。"で始まる行は削除
FILE=$1
if [ $# -ne 1 ]; then
echo "Usage: ./preprocess.sh INPUT_TEXT"
exit 1
fi
echo "Processing ${FILE}"
sed -i -e '/<doc id/,+1d; s/<\/doc>//g' ${FILE}
sed -i -e 's/ *$//g; s/。\([^」|)|)|"]\)/。\n\1/g; s/^[ ]*//g' ${FILE}
sed -i -e '/^。/d' ${FILE}
preprocess.shができたら、以下のコマンドを実行して前処理を実行します。
chmod u+x preprocess.sh
./preprocess.sh wiki_40b_train.txt
4. トークナイザーの学習
モデルを学習するために、テキストをトークン化を行います。ほとんどのトランスフォーマーモデルには、事前に学習されたトークナイザーが付属していますが、モデルをゼロから事前学習するため、トーケナイザーをインポートしたデータ上で学習します。
from sentencepiece import SentencePieceTrainer
data_dir = "./data/"
SentencePieceTrainer.Train(
'--input='+data_dir+'corpus/wiki_40b_train.txt, --model_prefix='+data_dir+'wiki40b_sentencepiece --character_coverage=0.9995 --vocab_size=32000 --pad_id=3 --add_dummy_prefix=False'
)
5. BERTの事前学習
from transformers import BertConfig
from transformers import BertForMaskedLM
from transformers import LineByLineTextDataset
from transformers import DataCollatorForLanguageModeling
from transformers import AlbertTokenizer
from transformers import TrainingArguments
from transformers import Trainer
import matplotlib.pyplot as plt
import scienceplots
import matplotlib
matplotlib.get_cachedir()
plt.style.use(['science', 'ieee', 'no-latex'])
matplotlib.rc('font', family='times new roman')
data_dir = "./data/"
tokenizer = AlbertTokenizer.from_pretrained(data_dir+'wiki40b_sentencepiece.model', keep_accents=True)
config = BertConfig(vocab_size=32003, num_hidden_layers=12, intermediate_size=768, num_attention_heads=12)
model = BertForMaskedLM(config)
dataset = LineByLineTextDataset(
tokenizer=tokenizer,
file_path=data_dir + 'corpus/train_data.txt',
block_size=256,
)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=True,
mlm_probability= 0.15
)
training_args = TrainingArguments(
output_dir= data_dir + 'SousekiBERT/',
overwrite_output_dir=True,
num_train_epochs=100,
per_device_train_batch_size=32,
save_steps=10000,
save_total_limit=2,
prediction_loss_only=True
)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=dataset
)
trainer.train()
epoch_lst = []
loss_lst = []
for log in trainer.state.log_history:
try:
loss_lst.append(log['loss'])
epoch_lst.append(log['epoch'])
except KeyError:
pass
fig, ax = plt.subplots()
ax.plot(epoch_lst, loss_lst)
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
plt.savefig("log/output.png")
trainer.save_model('wiki40b_BERT/')
6. 実行結果
こちらの結果は、wiki_40b_train.txtのデータすべてを使用したのではなく、先頭から1万行抽出したデータを100エポック学習した結果になっています。学習には、10時間くらいかかったと思います(たぶん...)。すべての学習データを使用すると、わずか4エポックで、36時間くらいかかりました💦
サンプルコードを公開してあります👇