I decided to run the program from the day before yesterday's article, which builds a Japanese DeBERTa model from aozorabunko-clean, on Google Colaboratory, and deliberately on a TPU at that. However, PyTorch/XLA has apparently dropped its support for Google Colaboratory, so the installation has become a bit convoluted.
# PyTorch 2.0.1 plus the torch_xla 2.0 wheel built for Colab (CPython 3.10)
!pip install torch==2.0.1 torchvision==0.15.2 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
!pip install -U transformers accelerate datasets fugashi unidic-lite
# Answer the interactive prompts of `accelerate config` non-interactively, then show the generated config
!echo 0:5::8::: | tr : '\012' | accelerate config ; cat /root/.cache/huggingface/accelerate/default_config.yaml
import os,datasets,urllib.request,jax.tools.colab_tpu
from transformers import DebertaV2TokenizerFast,DebertaV2Config,DebertaV2ForMaskedLM,DataCollatorForLanguageModeling,TrainingArguments,Trainer
from tokenizers import Tokenizer,models,pre_tokenizers,normalizers,processors,decoders,trainers
with open("train.txt","w",encoding="utf-8") as w:
d,i=datasets.load_dataset("globis-university/aozorabunko-clean"),0
for t in d["train"]:
for s in t["text"].replace("。","。\n").replace("\u3000"," ").split("\n"):
if i+len(s)<700:
print(s,end="",file=w)
i+=len(s)
else:
print("\n"+s,end="",file=w)
i=len(s)
print("",file=w)
os.system("fugashi -Owakati < train.txt > token.txt")
# Jōyō kanji code points, used as the initial alphabet of the Unigram tokenizer
with urllib.request.urlopen("https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt") as r:
  joyo=[chr(int(t,16)) for t in r.read().decode().strip().split("\n") if not t.startswith("#")]
# Train a Unigram tokenizer (vocab_size 65000, pieces of at most 4 characters) on the segmented text
spt=Tokenizer(models.Unigram())
spt.pre_tokenizer=pre_tokenizers.Sequence([pre_tokenizers.Whitespace(),pre_tokenizers.Punctuation()])
spt.normalizer=normalizers.Sequence([normalizers.Nmt(),normalizers.NFKC()])
spt.post_processor=processors.TemplateProcessing(single="[CLS] $A [SEP]",pair="[CLS] $A [SEP] $B:1 [SEP]:1",special_tokens=[("[CLS]",0),("[SEP]",2)])
spt.decoder=decoders.WordPiece(prefix="",cleanup=True)
spt.train(trainer=trainers.UnigramTrainer(vocab_size=65000,max_piece_length=4,initial_alphabet=joyo,special_tokens=["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"],unk_token="[UNK]",n_sub_iterations=2),files=["token.txt"])
spt.save("tokenizer.json")
# Bring up the Colab TPU runtime
jax.tools.colab_tpu.setup_tpu("0.1.dev20230825")
# Wrap the trained tokenizer as a DebertaV2TokenizerFast and build a base-sized DeBERTa config
tkz=DebertaV2TokenizerFast(tokenizer_file="tokenizer.json",split_by_punct=True,do_lower_case=False,keep_accents=True,vocab_file="/dev/null",model_max_length=512)
t=tkz.convert_tokens_to_ids(["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"])
cfg=DebertaV2Config(hidden_size=768,num_hidden_layers=12,num_attention_heads=12,intermediate_size=3072,relative_attention=True,position_biased_input=False,pos_att_type=["p2c","c2p"],max_position_embeddings=tkz.model_max_length,vocab_size=len(tkz),tokenizer_class=type(tkz).__name__,bos_token_id=t[0],pad_token_id=t[1],eos_token_id=t[2])
# Shard the model across the TPU cores with XLA FSDP
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=8,fsdp="full_shard",fsdp_config={"xla":True},output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2)
# Minimal line-by-line dataset: one training example per line of train.txt
class ReadLineDS(object):
  def __init__(self,file,tokenizer):
    self.tokenizer=tokenizer
    with open(file,"r",encoding="utf-8") as r:
      self.lines=[s.strip() for s in r if s.strip()>""]
  __len__=lambda self:len(self.lines)
  __getitem__=lambda self,i:self.tokenizer(self.lines[i],truncation=True,add_special_tokens=True,max_length=self.tokenizer.model_max_length-2)
# Train with the masked-language-modeling collator, then save model and tokenizer
trn=Trainer(args=arg,data_collator=DataCollatorForLanguageModeling(tkz),model=DebertaV2ForMaskedLM(cfg),train_dataset=ReadLineDS("train.txt",tkz))
trn.train()
trn.save_model("deberta-base-aozorabunko-clean")
tkz.save_pretrained("deberta-base-aozorabunko-clean")
# Quick fill-mask sanity check (in the original sentence the masked word is 「白く」)
from transformers import pipeline
fmp=pipeline("fill-mask","deberta-base-aozorabunko-clean")
print(fmp("夜の底が[MASK]なった。"))
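For reference, a fill-mask pipeline returns not a single string but a list of candidate dictionaries. A minimal sketch, assuming the model above actually finished training, that prints only the predicted tokens and their scores:

# Sketch (not in the original program): each candidate dict has "score", "token_str" and "sequence"
for candidate in fmp("夜の底が[MASK]なった。"):
  print(candidate["token_str"],round(candidate["score"],3))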
I put a lot of effort into writing this, but on a v2-8 TPU the program takes roughly 200 hours. In the end, I (安岡孝一) myself never made it to a finished Japanese DeBERTa. Sad.
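Given that estimate, a single Colab session would time out long before the three epochs finish. One possible workaround, not part of the program above, is to point output_dir at persistent storage and resume from the newest checkpoint in a fresh session:

# Hedged sketch: assumes output_dir in TrainingArguments was changed from "/tmp" to a
# persistent path (e.g. a mounted Google Drive folder) so checkpoints survive the session
trn.train(resume_from_checkpoint=True)  # resumes from the newest checkpoint found in output_dir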