昨日の記事で使ったWikiText-JAのtrain.txt
を確かめていたところ、「か゚行」という文字列が目にとまった。しまった、私(安岡孝一)も他人の事は言えない、Split(Regex("."),"isolated")
だと、「か゚」が2トークンに分かれてしまう。「が」や「ぱ」が1トークンなのに、「か゚」が2トークンなのはマズイ。半濁点と濁点だけでも直してみよう。
!pip install transformers accelerate
import re,urllib.request,unicodedata
from transformers import BertTokenizerFast,DebertaV2Config,DebertaV2ForMaskedLM,DataCollatorForLanguageModeling,TrainingArguments,Trainer
from tokenizers import pre_tokenizers,normalizers,Regex
url="http://www.lsta.media.kyoto-u.ac.jp/resource/data/wikitext-ja/"
p,i={"*447*":"\u30ab\u309a","*7003*":"","*8050*":"","*10789*":""},0
with open("train.txt","w",encoding="utf-8") as w:
for t in ["Featured_Contents.txt","Good_Contents.txt"]:
with urllib.request.urlopen(url+"Exception_"+t[0]+".txt") as r:
e={"*"+s[0:-1]+"*":s[-1] if unicodedata.name(s[-1],False) else p["*"+s[0:-1]+"*"] for s in r.read().decode("utf-8").split("\n") if s.strip()>""}
with urllib.request.urlopen(url+t) as r:
for s in r.read().decode("utf-8").replace("。","。\n").split("\n"):
for t in re.findall(r"\*[1-9][0-9]*\*",s):
s=s.replace(t,e[t])
if i+len(s)<128:
print(s,end="",file=w)
i+=len(s)
else:
print("\n"+s,end="",file=w)
i=len(s)
print("",file=w)
n=normalizers.Sequence([normalizers.Nmt(),normalizers.NFKC()])
p=pre_tokenizers.Sequence([pre_tokenizers.Whitespace(),pre_tokenizers.Split(Regex(".[\u3099\u309a]?"),"isolated")])
with open("train.txt","r",encoding="utf-8") as r:
v=set(c for c,_ in p.pre_tokenize_str(n.normalize_str(r.read())) if not c.isspace())
with urllib.request.urlopen("https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt") as r:
_=[v.add(chr(int(t,16))) for t in r.read().decode().strip().split("\n") if not t.startswith("#")]
with open("vocab.txt","w",encoding="utf-8") as w:
print("\n".join(["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"]+sorted(v)),file=w)
tkz=BertTokenizerFast(vocab_file="vocab.txt",never_split=["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"],do_lower_case=False,strip_accents=False,tokenize_chinese_chars=True,model_max_length=128)
tkz.backend_tokenizer.normalizer=n
tkz.backend_tokenizer.pre_tokenizer=p
tkz.backend_tokenizer.decoder.prefix=tkz.backend_tokenizer.model.continuing_subword_prefix=""
cfg=DebertaV2Config(hidden_size=256,num_hidden_layers=12,num_attention_heads=4,intermediate_size=768,relative_attention=True,position_biased_input=False,pos_att_type=["p2c","c2p"],max_position_embeddings=tkz.model_max_length,vocab_size=len(tkz),tokenizer_class=type(tkz).__name__,bos_token_id=tkz.cls_token_id,pad_token_id=tkz.pad_token_id,eos_token_id=tkz.sep_token_id)
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=64,output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2)
class ReadLineDS(object):
def __init__(self,file,tokenizer):
self.tokenizer=tokenizer
with open(file,"r",encoding="utf-8") as r:
self.lines=[s.strip() for s in r if s.strip()!=""]
__len__=lambda self:len(self.lines)
__getitem__=lambda self,i:self.tokenizer(self.lines[i],truncation=True,add_special_tokens=True,max_length=self.tokenizer.model_max_length-2)
trn=Trainer(args=arg,data_collator=DataCollatorForLanguageModeling(tkz),model=DebertaV2ForMaskedLM(cfg),train_dataset=ReadLineDS("train.txt",tkz))
trn.train()
trn.save_model("deberta-mini-wikitext-ja")
tkz.save_pretrained("deberta-mini-wikitext-ja")
from transformers import pipeline
fmp=pipeline("fill-mask","deberta-mini-wikitext-ja")
print(fmp("酸素ボ[MASK]ベを充塡する。"))
Google ColaboratoryのTesla T4だと、私の手元では1時間弱で日本語DeBERTaミニモデルが完成し、以下の結果が出力された。
[{'score': 0.1691729724407196, 'token': 950, 'token_str': 'ル', 'sequence': '酸素ボルベを充塡する。'}, {'score': 0.10775759816169739, 'token': 963, 'token_str': 'ー', 'sequence': '酸素ボーベを充塡する。'}, {'score': 0.10529794543981552, 'token': 949, 'token_str': 'リ', 'sequence': '酸素ボリベを充塡する。'}, {'score': 0.0870039239525795, 'token': 958, 'token_str': 'ン', 'sequence': '酸素ボンベを充塡する。'}, {'score': 0.05879103019833565, 'token': 900, 'token_str': 'ス', 'sequence': '酸素ボスベを充塡する。'}]
昨日に較べて、微妙に結果が悪くなっているものの、token_idは3つほど大きくなっている。vocab.txt
を眺めてみたところ「か゚」「カ゚」「キ゚」が増えているようだ。まあ、ちゃんと動いてるってことかな。