0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LLMAdvent Calendar 2023

Day 23

SwallowのLlamaForCausalLMにおける追加トークンの単語ベクトルはどうすべきか

Posted at

昨日の記事で、Swallowへの常用漢字51字の追加に使ったresize_token_embeddingsは、追加したトークンの単語ベクトルをランダムにしか決めてくれない。異体字のinput_embeddingsならば、同じ値をコピーする手もありうるが、言語生成モデルのoutput_embeddingsでそれをやってしまうと、異体字同士で衝突が起こる。しかも追加する51字のうち、異体字があるのは「𠮟」「塡」「剝」「頰」の4字だけだ。何とか3バイトないし4バイトのUTF-8から、追加トークンの単語ベクトルをデッチあげたい。
論文を探してみたところ、Adam Ek, Jean-Philippe Bernardy『Composing Byte-Pair Encodings for Morphological Sequence Classification』が同じ悩みにブチ当たっていて、input_embeddingsにはRNNが最も良くて次がsumらしいが、output_embeddingsに果たしてRNNが使えるのか、この論文からはイマイチわからなかった。RNNで実装すると前後が面倒そうだし、とりあえずsumでやってみることにした。

#! /usr/bin/python3
import urllib.request,json,torch
from transformers import LlamaTokenizerFast,LlamaForCausalLM
with urllib.request.urlopen("https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt") as r:
  joyo=[chr(int(t,16)) for t in r.read().decode().strip().split("\n") if not t.startswith("#")]
tkz=LlamaTokenizerFast.from_pretrained("tokyotech-llm/Swallow-7b-instruct-hf",pad_token="</s>")
c={i:j[2:] for i,j in zip(joyo,tkz(joyo)["input_ids"]) if len(j)>3}
tkz.save_pretrained("exSwallow-7b-instruct-hf")
d=json.loads(tkz.backend_tokenizer.to_str())
for i,j in enumerate(c,len(tkz)):
  d["model"]["vocab"][j]=i
tkz.backend_tokenizer.from_str(json.dumps(d)).save("exSwallow-7b-instruct-hf/tokenizer.json")
mdl=LlamaForCausalLM.from_pretrained("tokyotech-llm/Swallow-7b-instruct-hf",device_map="auto")
tkz=LlamaTokenizerFast.from_pretrained("exSwallow-7b-instruct-hf",modex_max_length=mdl.config.max_position_embeddings)
e=mdl.resize_token_embeddings(len(tkz))
f=mdl.get_output_embeddings()
with torch.no_grad():
  for k,v in c.items():
    e.weight[d["model"]["vocab"][k],:]=e.weight[v,:].sum(0)
    f.weight[d["model"]["vocab"][k],:]=f.weight[v,:].sum(0)
mdl.set_input_embeddings(e)
mdl.set_output_embeddings(f)
mdl.save_pretrained("exSwallow-7b-instruct-hf")
from transformers import TextGenerationPipeline
tgn=TextGenerationPipeline(model=mdl,tokenizer=tkz,max_new_tokens=128)
nlp=lambda txt:tgn(f"以下に、あるタスクを説明する指示があります。リクエストを適切に完了するための回答を記述してください。\n\n### 指示:{txt}\n\n### 応答:",do_sample=True)[0]["generated_text"]
print(nlp("頬と頰はどちらが子供の名づけに使えますか?"))

Swallow-7b-instruct-hfに常用漢字51字のトークン追加をおこないつつ、追加学習はおこなわずに様子を見てみることにした。私(安岡孝一)の手元で出力された結果は以下のとおり。

以下に、あるタスクを説明する指示があります。リクエストを適切に完了するための回答を記述してください。

### 指示:頬と頰はどちらが子供の名づけに使えますか?

### 応答:頬は、男の子の名前としても女の子の名前としても使えます。

うーん、やはり追加学習しないと「頰」が生成されたりしないようだ。さて、どうしようかな。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?