Can a Japanese DeBERTa Model Be Built on Google Colaboratory's Free TPU?

In my article from two days ago I described a program that trains a Japanese DeBERTa model on aozorabunko-clean; here I try to run it on Google Colaboratory, and deliberately on a TPU. However, PyTorch/XLA has apparently dropped support for Google Colaboratory, so the installation is a bit convoluted.

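# PyTorch/XLA no longer supports Colab, so pin torch 2.0.1 and install the matching torch_xla 2.0 Colab wheel directly from its URL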
!pip install torch==2.0.1 torchvision==0.15.2 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
!pip install -U transformers accelerate datasets fugashi unidic-lite
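# Answer the `accelerate config` prompts non-interactively so the saved config targets the 8 TPU cores (the cat shows the result)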
!echo 0:5::8::: | tr : '\012' | accelerate config ; cat /root/.cache/huggingface/accelerate/default_config.yaml
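# Optional check (not in the original): the torch_xla wheel should now import without errors
!python -c "import torch_xla; print('torch_xla OK')"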
import os,datasets,urllib.request,jax.tools.colab_tpu
from transformers import DebertaV2TokenizerFast,DebertaV2Config,DebertaV2ForMaskedLM,DataCollatorForLanguageModeling,TrainingArguments,Trainer
from tokenizers import Tokenizer,models,pre_tokenizers,normalizers,processors,decoders,trainers
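# Repack aozorabunko-clean: split at sentence ends (。) and write lines of just under 700 characters to train.txt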
with open("train.txt","w",encoding="utf-8") as w:
  d,i=datasets.load_dataset("globis-university/aozorabunko-clean"),0
  for t in d["train"]:
    for s in t["text"].replace("。","。\n").replace("\u3000"," ").split("\n"):
      if i+len(s)<700:
        print(s,end="",file=w)
        i+=len(s)
      else:
        print("\n"+s,end="",file=w)
        i=len(s)
  print("",file=w)
os.system("fugashi -Owakati < train.txt > token.txt")
with urllib.request.urlopen("https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt") as r:
  joyo=[chr(int(t,16)) for t in r.read().decode().strip().split("\n") if not t.startswith("#")]
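# Train a Unigram tokenizer: 65,000 pieces of at most 4 characters, with [CLS]/[PAD]/[SEP]/[UNK]/[MASK] as special tokens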
spt=Tokenizer(models.Unigram())
spt.pre_tokenizer=pre_tokenizers.Sequence([pre_tokenizers.Whitespace(),pre_tokenizers.Punctuation()])
spt.normalizer=normalizers.Sequence([normalizers.Nmt(),normalizers.NFKC()])
spt.post_processor=processors.TemplateProcessing(single="[CLS] $A [SEP]",pair="[CLS] $A [SEP] $B:1 [SEP]:1",special_tokens=[("[CLS]",0),("[SEP]",2)])
spt.decoder=decoders.WordPiece(prefix="",cleanup=True)
spt.train(trainer=trainers.UnigramTrainer(vocab_size=65000,max_piece_length=4,initial_alphabet=joyo,special_tokens=["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"],unk_token="[UNK]",n_sub_iterations=2),files=["token.txt"])
spt.save("tokenizer.json")
jax.tools.colab_tpu.setup_tpu("0.1.dev20230825")
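# Wrap the trained tokenizer for transformers and build a DeBERTa-base-sized config (12 layers, hidden size 768, max length 512)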
tkz=DebertaV2TokenizerFast(tokenizer_file="tokenizer.json",split_by_punct=True,do_lower_case=False,keep_accents=True,vocab_file="/dev/null",model_max_length=512)
t=tkz.convert_tokens_to_ids(["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"])
cfg=DebertaV2Config(hidden_size=768,num_hidden_layers=12,num_attention_heads=12,intermediate_size=3072,relative_attention=True,position_biased_input=False,pos_att_type=["p2c","c2p"],max_position_embeddings=tkz.model_max_length,vocab_size=len(tkz),tokenizer_class=type(tkz).__name__,bos_token_id=t[0],pad_token_id=t[1],eos_token_id=t[2])
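# Train for 3 epochs at batch size 8 per TPU core, sharding the model over the 8 cores with FSDP on XLA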
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=8,fsdp="full_shard",fsdp_config={"xla":True},output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2)
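# Minimal dataset: one tokenized example per non-empty line of train.txt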
class ReadLineDS(object):
  def __init__(self,file,tokenizer):
    self.tokenizer=tokenizer
    with open(file,"r",encoding="utf-8") as r:
      self.lines=[s.strip() for s in r if s.strip()>""]
  __len__=lambda self:len(self.lines)
  __getitem__=lambda self,i:self.tokenizer(self.lines[i],truncation=True,add_special_tokens=True,max_length=self.tokenizer.model_max_length-2)
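# Masked-language-model pretraining, then save the model and the tokenizer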
trn=Trainer(args=arg,data_collator=DataCollatorForLanguageModeling(tkz),model=DebertaV2ForMaskedLM(cfg),train_dataset=ReadLineDS("train.txt",tkz))
trn.train()
trn.save_model("deberta-base-aozorabunko-clean")
tkz.save_pretrained("deberta-base-aozorabunko-clean")
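# Sanity check: the masked word in this line from Kawabata's "Snow Country" should come out as 白く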
from transformers import pipeline
fmp=pipeline("fill-mask","deberta-base-aozorabunko-clean")
print(fmp("夜の底が[MASK]なった。"))

I put a lot of effort into writing this, but on a v2-8 TPU the program would take roughly 200 hours. In the end, I (Koichi Yasuoka) never made it all the way to a finished Japanese DeBERTa. Sad.
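
Since no Colab session stays alive anywhere near 200 hours, one possible workaround, which is not part of the program above, is to keep the checkpoints on Google Drive instead of /tmp and resume training in a fresh session. Below is a minimal sketch reusing cfg, tkz and ReadLineDS from the program above; the directory MyDrive/deberta-ckpt is my own placeholder.

# Sketch only, not part of the original program: store checkpoints on Google Drive
# so they survive the Colab runtime being recycled, then resume from the newest one.
from google.colab import drive
drive.mount("/content/drive")
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=8,fsdp="full_shard",fsdp_config={"xla":True},output_dir="/content/drive/MyDrive/deberta-ckpt",overwrite_output_dir=False,save_total_limit=2)
trn=Trainer(args=arg,data_collator=DataCollatorForLanguageModeling(tkz),model=DebertaV2ForMaskedLM(cfg),train_dataset=ReadLineDS("train.txt",tkz))
trn.train(resume_from_checkpoint=True)  # requires at least one checkpoint-* already saved under output_dir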
