
Building a Japanese DeBERTa Model with Fugaku's PyTorch-1.7.0 and aozorabunko-clean


I tried porting the Japanese DeBERTa building program based on aozorabunko-clean, described in yesterday's article, to Horovod with PyTorch-1.7.0 on Fugaku.

#! /bin/bash
#PJM -L rscgrp=small
#PJM -L elapse=6:00:00
#PJM -L node=12x12:torus
#PJM -j
#PJM -S

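# Pick a /vol* data directory named after the user's group (falling back to $HOME)
# and put the user-installed Python packages, the Hugging Face cache, and a scratch
# directory under it.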
G=`id | sed 's/^.*gid=[0-9]*(\([^)]*\)).*$/\1/'`
set `ls -d /vol*/$G /vol*/data/$G` $HOME
export PYTHONUSERBASE=$1/deberta-aozora
export PATH=/home/apps/oss/PyTorch-1.7.0/bin:$PYTHONUSERBASE/bin:$PATH
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/home/apps/oss/PyTorch-1.7.0/lib64
export HF_HOME=$PYTHONUSERBASE
export TMPDIR=$PYTHONUSERBASE/tmp
mkdir -p $TMPDIR
pip3.8 install transformers==4.28.1 tokenizers==0.13.3 protobuf==3.20.3 accelerate==0.20.3 --user
pip3.8 install -U tqdm packaging typing_extensions datasets fugashi unidic-lite --user
ln -s $TMPDIR/deberta-base-aozorabunko-clean .
cd $TMPDIR

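# Step 1: flatten aozorabunko-clean into train.txt, packing whole sentences
# into lines of fewer than 700 characters each.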
P=dataset.$PJM_JOBID.$$.py
cat << 'EOF' > $P
import datasets
with open("train.txt","w",encoding="utf-8") as w:
  d,i=datasets.load_dataset("globis-university/aozorabunko-clean"),0
  for t in d["train"]:
    for s in t["text"].replace("。","。\n").replace("\u3000"," ").split("\n"):
      if i+len(s)<700:
        print(s,end="",file=w)
        i+=len(s)
      else:
        print("\n"+s,end="",file=w)
        i=len(s)
  print("",file=w)
EOF
python3.8 $P

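# Step 2: split train.txt evenly across the allocated nodes and run fugashi
# word segmentation on each chunk in parallel, then concatenate into token.txt.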
set ${PJM_NODE-1} `wc train.txt` 0
split -d -a `expr $1 : '.*'` -l `expr '(' $2 + $1 - 1 ')' / $1` train.txt train.$PJM_JOBID.
P=fugashi.$PJM_JOBID.$$.sh
cat << 'EOF' > $P
#! /bin/bash
F=`ls -1 train.$PJM_JOBID.*${PMIX_RANK-0} | head -1`
fugashi -Owakati < $F > token`expr $F : 'train\(.*\)$'`
EOF
chmod a+x $P
env LD_PRELOAD=libtcmalloc.so mpirun -np $1 bash $P
cat token.$PJM_JOBID.* > token.txt

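# Step 3: train a Unigram tokenizer (vocab_size=65000) on the segmented text,
# seeding its initial alphabet with the Joyo kanji list.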
P=tokenizer.$PJM_JOBID.$$.py
cat << 'EOF' > $P
import urllib.request
from tokenizers import Tokenizer,models,pre_tokenizers,normalizers,processors,decoders,trainers
with urllib.request.urlopen("https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt") as r:
  joyo=[chr(int(t,16)) for t in r.read().decode().strip().split("\n") if not t.startswith("#")]
spt=Tokenizer(models.Unigram())
spt.pre_tokenizer=pre_tokenizers.Sequence([pre_tokenizers.Whitespace(),pre_tokenizers.Punctuation()])
spt.normalizer=normalizers.Sequence([normalizers.Nmt(),normalizers.NFKC()])
spt.post_processor=processors.TemplateProcessing(single="[CLS] $A [SEP]",pair="[CLS] $A [SEP] $B:1 [SEP]:1",special_tokens=[("[CLS]",0),("[SEP]",2)])
spt.decoder=decoders.WordPiece(prefix="",cleanup=True)
spt.train(trainer=trainers.UnigramTrainer(vocab_size=65000,max_piece_length=4,initial_alphabet=joyo,special_tokens=["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"],unk_token="[UNK]",n_sub_iterations=2),files=["token.txt"])
spt.save("tokenizer.json")
EOF
python3.8 $P

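# Step 4: pre-train a DeBERTa-base masked language model across the nodes,
# hooking Horovod into the Trainer's optimizer and samplers.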
P=trainer.$PJM_JOBID.$$.py
cat << 'EOF' > $P
from transformers import DebertaV2TokenizerFast,DebertaV2Config,DebertaV2ForMaskedLM,DataCollatorForLanguageModeling,TrainingArguments,Trainer
from torch.utils.data.distributed import DistributedSampler
import horovod.torch as hvd
hvd.init()
tkz=DebertaV2TokenizerFast(tokenizer_file="tokenizer.json",split_by_punct=True,do_lower_case=False,keep_accents=True,vocab_file="/dev/null",model_max_length=512)
t=tkz.convert_tokens_to_ids(["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"])
cfg=DebertaV2Config(hidden_size=768,num_hidden_layers=12,num_attention_heads=12,intermediate_size=3072,relative_attention=True,position_biased_input=False,pos_att_type=["p2c","c2p"],max_position_embeddings=tkz.model_max_length,vocab_size=len(tkz),tokenizer_class=type(tkz).__name__,bos_token_id=t[0],pad_token_id=t[1],eos_token_id=t[2])
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=8,output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2)
class ReadLineDS(object):
  def __init__(self,file,tokenizer):
    self.tokenizer=tokenizer
    with open(file,"r",encoding="utf-8") as r:
      self.lines=[s.strip() for s in r if s.strip()>""]
  __len__=lambda self:len(self.lines)
  __getitem__=lambda self,i:self.tokenizer(self.lines[i],truncation=True,add_special_tokens=True,max_length=self.tokenizer.model_max_length-2)
mdl=DebertaV2ForMaskedLM(cfg)
trn=Trainer(args=arg,data_collator=DataCollatorForLanguageModeling(tkz),model=mdl,train_dataset=ReadLineDS("train.txt",tkz))
trn.create_optimizer()
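# Hand gradient averaging to Horovod, shard the dataset across ranks via
# DistributedSampler, and broadcast the initial weights from rank 0 so that
# every rank starts from the same model.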
trn.optimizer=hvd.DistributedOptimizer(trn.optimizer,mdl.named_parameters())
trn._get_train_sampler=lambda:DistributedSampler(trn.train_dataset,num_replicas=hvd.size(),rank=hvd.rank())
trn._get_eval_sampler=lambda x:DistributedSampler(x,num_replicas=hvd.size(),rank=hvd.rank())
hvd.broadcast_parameters(mdl.state_dict(),root_rank=0)
trn.train()
if hvd.rank()==0:
  trn.save_model("deberta-base-aozorabunko-clean")
  tkz.save_pretrained("deberta-base-aozorabunko-clean")
EOF
env LD_PRELOAD=libtcmalloc.so mpirun -np ${PJM_NODE-1} python3.8 $P

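# Step 5: quick fill-mask smoke test of the finished model.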
P=fillmask.$PJM_JOBID.$$.py
cat << 'EOF' > $P
from transformers import pipeline
fmp=pipeline("fill-mask","deberta-base-aozorabunko-clean")
print(fmp("夜の底が[MASK]なった。"))
EOF
python3.8 $P

I did my best to parallelize it, but running this program with #PJM -L node=16x24:torus, i.e. 384 nodes (one Fugaku rack), still takes about 2 hours 40 minutes. Since 384 nodes × 2 hours 40 minutes = 1024 node-hours, that no longer fits within Fugaku's first-touch option. After trying a few configurations, #PJM -L node=12x12:torus, i.e. 144 nodes, turned out to be a good compromise at roughly 5 and a half hours, which is about on par with 4× NVIDIA A100-SXM4-40GB. Since the time to build a Japanese DeBERTa model is proportional to the number of lines in the source data, this program packs 226793936 characters into 338727 lines and then scatters those lines across the nodes, but I do wonder whether that part is really working as well as it should.
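As a rough sanity check on those figures, here is a back-of-the-envelope sketch (not part of the job script above; the quoted wall-clock times are taken as given):

# Back-of-the-envelope check of the numbers quoted above.
chars=226793936
lines=338727
print(chars/lines)    # roughly 670 characters per line, consistent with the <700-character packing
print(384*(2+40/60))  # about 1024 node-hours on 384 nodes (one rack)
print(144*5.5)        # about 792 node-hours on 144 nodes
print(lines/144)      # roughly 2350 lines per node when scattered across 144 nodes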
