Building a Japanese ModernBERT model with transformers and aozorabunko-clean


I decided to try building a Japanese ModernBERT model with transformers and aozorabunko-clean. The ModernBERT-base tokenizer, however, is so poorly suited to Japanese that I reused the DebertaV2TokenizerFast built in my article of January 2 last year. Also, since ModernBERT accepts inputs up to 8192 tokens wide, I packed train.txt into lines of roughly 10,000 characters each.

#! /usr/bin/python3
#pip3 install transformers accelerate deepspeed triton datasets fugashi unidic-lite
import os,json
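# Fetch transformers (for the ModernBERT sources) and ModernBERT-base, rewriting
# PEP 604 "X | Y" return hints to Union[X, Y] and inlining the triton check so the
# copied modeling files run standalone under trust_remote_code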
os.system("""
if test -d transformers
then :
else git clone --depth=1 https://github.com/huggingface/transformers transformers-all
     ln -s transformers-all/src/transformers transformers
     sed 's/-> \\(.*\\) | \\(.*\\):/-> Union[\\1, \\2]:/' transformers/models/modernbert/modeling_modernbert.py > modeling_modernbert.py
     cp modeling_modernbert.py transformers/models/modernbert
fi
test -d ModernBERT-base || git clone --depth=1 https://huggingface.co/answerdotai/ModernBERT-base
test -f ModernBERT-base/configuration_modernbert.py || sed 's/^from \\.\\.\\./from transformers./' transformers/models/modernbert/configuration_modernbert.py > ModernBERT-base/configuration_modernbert.py
test -f ModernBERT-base/modeling_modernbert.py || sed -e 's/^from \\.\\.\\./from transformers./' -e 's/^from .* import is_triton_available/import importlib\\nis_triton_available = lambda: importlib.util.find_spec("triton") is not None/' transformers/models/modernbert/modeling_modernbert.py > ModernBERT-base/modeling_modernbert.py
""")
with open("ModernBERT-base/config.json","r",encoding="utf-8") as r:
  d=json.load(r)
if not "auto_map" in d:
  d["auto_map"]={
    "AutoConfig":"configuration_modernbert.ModernBertConfig",
    "AutoModel":"modeling_modernbert.ModernBertModel",
    "AutoModelForMaskedLM":"modeling_modernbert.ModernBertForMaskedLM",
    "AutoModelForSequenceClassification":"modeling_modernbert.ModernBertForSequenceClassification",
    "AutoModelForTokenClassification":"modeling_modernbert.ModernBertForTokenClassification"
  }
  with open("ModernBERT-base/config.json","w",encoding="utf-8") as w:
    json.dump(d,w,indent=2)
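# Build train.txt from aozorabunko-clean: split on 。 and pack sentences
# into lines of just under 10000 characters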
if not os.path.isfile("train.txt"):
  import datasets
  with open("train.txt","w",encoding="utf-8") as w:
    d,i=datasets.load_dataset("globis-university/aozorabunko-clean"),0
    for t in d["train"]:
      for s in t["text"].replace("。","。\n").replace("\u3000"," ").split("\n"):
        if i+len(s)<10000:
          print(s,end="",file=w)
          i+=len(s)
        else:
          print("\n"+s,end="",file=w)
          i=len(s)
    print("",file=w)
os.system("test -s token.txt || fugashi -Owakati < train.txt > token.txt")

from transformers import DebertaV2TokenizerFast
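# Train a 65000-piece Unigram tokenizer, seeding its alphabet with the kanji
# code points from JapaneseCoreKanji.txt, then wrap it as DebertaV2TokenizerFast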
if not os.path.isfile("tokenizer.json"):
  import urllib.request
  from tokenizers import Tokenizer,models,pre_tokenizers,normalizers,processors,decoders,trainers
  with urllib.request.urlopen("https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt") as r:
    joyo=[chr(int(t,16)) for t in r.read().decode().strip().split("\n") if not t.startswith("#")]
  spt=Tokenizer(models.Unigram())
  spt.pre_tokenizer=pre_tokenizers.Sequence([pre_tokenizers.Whitespace(),pre_tokenizers.Punctuation()])
  spt.normalizer=normalizers.Sequence([normalizers.Nmt(),normalizers.NFKC()])
  spt.post_processor=processors.TemplateProcessing(single="[CLS] $A [SEP]",pair="[CLS] $A [SEP] $B:1 [SEP]:1",special_tokens=[("[CLS]",0),("[SEP]",2)])
  spt.decoder=decoders.WordPiece(prefix="",cleanup=True)
  spt.train(trainer=trainers.UnigramTrainer(vocab_size=65000,max_piece_length=4,initial_alphabet=joyo,special_tokens=["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"],unk_token="[UNK]",n_sub_iterations=2),files=["token.txt"])
  spt.save("tokenizer.json")
tkz=DebertaV2TokenizerFast(tokenizer_file="tokenizer.json",split_by_punct=True,do_lower_case=False,keep_accents=True,vocab_file="/dev/null")
tkz.save_pretrained("modernbert-base-japanese-aozora")
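# Write train.py, to be run under deepspeed: masked-LM pretraining of
# ModernBertForMaskedLM from scratch with the tokenizer saved above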
with open("train.py","w",encoding="utf-8") as w:
  print('''#! /usr/bin/env deepspeed
from transformers import DebertaV2TokenizerFast,ModernBertForMaskedLM,AutoConfig,DataCollatorForLanguageModeling,TrainingArguments,Trainer
tkz=DebertaV2TokenizerFast.from_pretrained("modernbert-base-japanese-aozora")
c={"trust_remote_code":True,"vocab_size":len(tkz),"tokenizer_class":type(tkz).__name__}
for k,v in tkz.special_tokens_map.items():
  c[k+"_id"]=tkz.convert_tokens_to_ids(v)
cfg=AutoConfig.from_pretrained("ModernBERT-base",**c)
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=1,output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2,save_safetensors=False)
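# Simple dataset: one line of train.txt per example, tokenized on the fly
# and capped at 8190 tokens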
class ReadLineDS(object):
  def __init__(self,file,tokenizer):
    self.tokenizer=tokenizer
    with open(file,"r",encoding="utf-8") as r:
      self.lines=[s.strip() for s in r if s.strip()>""]
  __len__=lambda self:len(self.lines)
  __getitem__=lambda self,i:self.tokenizer(self.lines[i],truncation=True,add_special_tokens=True,max_length=8190)
trn=Trainer(args=arg,data_collator=DataCollatorForLanguageModeling(tkz),model=ModernBertForMaskedLM(cfg),train_dataset=ReadLineDS("train.txt",tkz))
trn.train()
trn.save_model("modernbert-base-japanese-aozora")''',file=w)
os.system("""
chmod 755 train.py
./train.py
cp ModernBERT-base/*.py modernbert-base-japanese-aozora
""")

from transformers import AutoTokenizer,AutoModelForMaskedLM,FillMaskPipeline
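# Sanity check: reload via the auto classes (requires trust_remote_code=True)
# and try a fill-mask query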
tkz=AutoTokenizer.from_pretrained("modernbert-base-japanese-aozora")
mdl=AutoModelForMaskedLM.from_pretrained("modernbert-base-japanese-aozora",trust_remote_code=True)
fmp=FillMaskPipeline(model=mdl,tokenizer=tkz)
print(fmp("夜の底が[MASK]なった。"))

I threw eight NVIDIA A100-SXM4-40GB cards at it, but it still missed the outbound leg of the Hakone Ekiden: after roughly 5 hours and 24 minutes it produced the following output.

[{'score': 0.08243247866630554, 'token': 2108, 'token_str': '白く', 'sequence': '夜の底が白くなった。'}, {'score': 0.05484746769070625, 'token': 28, 'token_str': 'く', 'sequence': '夜の底がくなった。'}, {'score': 0.02398742362856865, 'token': 3016, 'token_str': '黒く', 'sequence': '夜の底が黒くなった。'}, {'score': 0.02300094999372959, 'token': 7, 'token_str': 'に', 'sequence': '夜の底がになった。'}, {'score': 0.02049252949655056, 'token': 3267, 'token_str': 'わるく', 'sequence': '夜の底がわるくなった。'}]

For 「夜の底が[MASK]なった。」 the model fills [MASK] with 「白く」, 「く」, 「黒く」, 「に」, and 「わるく」. The top candidate 「白く」 is superb, but 「く」 is a mystery. I am publishing the model for now, though at present trust_remote_code=True is required. And if ModernBERT eventually comes to support not only GemmaRotaryEmbedding but also FlexAttention, inputs_embeds, and the like, the model will presumably have to be rebuilt.
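Incidentally, a quick way to tell whether the installed transformers already ships ModernBERT natively (in which case trust_remote_code=True should eventually become unnecessary) is simply to try the import. This is a minimal sketch of my own, not part of the recipe above:

try:
  # ModernBertForMaskedLM is only exported once transformers bundles ModernBERT itself
  from transformers import ModernBertForMaskedLM
  print("native ModernBERT support found; trust_remote_code may no longer be needed")
except ImportError:
  print("no native ModernBERT; keep loading with trust_remote_code=True")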
