未知語から深層学習モデルを守る党。
今回は定頻度語をbyte-levelでぶっ壊します。
BPEで語彙圧縮した「低頻度語をぶっ壊す」の続きです。
byte-level BPE語彙圧縮とは?
Language Models are Unsupervised Multitask Learnersで登場した手法です。
(単純な手法なのでもしかしたらその前からあるかもしれません。)
前回紹介した語彙圧縮はUnicode文字ベースでした。
対して今回は、byteベースで圧縮しただけです。
(BPE(Byte Pair Encoding)としてはこっちの方が本来正しい。)
文字列をただbyte文字列に変換しただけでも語彙数を256*2=512種類に圧縮できます。
しかし、それだけだと単語ベースの入力よりもシーケンス長が数倍ロングになってしまうので、
BPE圧縮して語彙数とシーケンス長のバランスを取ることで、深層学習モデルに優しい入力を作ります。
とりあえず実行結果から見て掴む。
今回も前回と同様に吾輩は猫であるをBPEの学習に使います。
textの一部: 吾輩わがはいは猫である。名前はまだ無い。 どこで生れたかとんと見当けんとうがつかぬ。何でも薄暗いじめじめした所でニャ
ーニャー泣いていた事だけは記憶している。吾輩はここで始めて人間というものを見た。
分かち書きを行います。
tokensの一部:['吾輩', 'わがはい', 'は', '猫', 'で', 'ある', '。', '名前', 'は', 'まだ', '無い', '。', ' ', '\u3000', 'どこ', 'で',
'生れ', 'た', 'か', 'とんと', '見当', 'けん', 'とう', 'が', 'つか', 'ぬ', '。', '何', 'でも', '薄暗い', 'じめじめ', 'し', 'た', '所', 'で', 'ニャーニャー', '泣い', 'て', 'いた事', 'だけ', 'は', '記憶', 'し', 'て', 'いる', '。', '吾輩', 'は', 'ここ', 'で']
語彙を辞書に格納します。
'語彙':(語彙のbyte列, 出現頻度)
で格納してあります。
vocabの一部:{'吾輩': ([229, 144, 190, 232, 188, 169], 481), 'わがはい': ([227, 130, 143, 227, 129, 140, 227, 129, 175, 227, 129, 132], 1), 'は': ([227, 129, 175], 6560), '猫': ([231, 140, 171], 247), 'で': ([227, 129, 167], 3773), 'ある': ([227, 129, 130, 227, 130, 139], 1709), '。': ([227, 128, 130], 7487), '名前': ([229, 144, 141, 229, 137, 141], 47), 'まだ': ([227, 129, 190, 227, 129, 160], 123), '無い': ([231, 132, 161, 227, 129, 132], 19)}
BPEで語彙のbyte列を圧縮します。
encoded_vocabの一部:{'吾輩': ([708], 481), 'わがはい': ([-624, -625, -658, 518], 1), 'は': ([522], 6560), '猫': ([847], 247), 'で': ([534], 3773), 'ある': ([564], 1709), '。': ([519], 7487), '名前': ([1723], 47), 'まだ': ([1074], 123), '無い': ([-804, 518], 19)}
[229, 144, 190, 232, 188, 169]
が[708]
に圧縮されました。
[-624, -625, -658, 518]
のようなマイナス値は次の数値と連結していることを示しています。
(前回の@@
と同意です。)
吾輩は猫であるを圧縮後のbyte列に変換したものがこちら。
bpe_bytesの一部:[708, -624, -625, -658, 518, 522, 847, 534, 564, 519, 1723, 522, 1074, -804, 518, 519, 288, 732, 1109, 534, 2043, 533, 546, -558, 1361, -631, 1348, 1906, 1075, 528]
解凍してbyte列を文字に戻したものが以下。
(冒頭のスペースがないのは分かち書きの仕様で落ちています。)
decoded_textの一部:吾輩わがはいは猫である。名前はまだ無い。 どこで生れたかとんと見当けんとうがつかぬ。何でも薄暗いじめじめした所でニャーニャー泣いていた事だけは記憶している。吾輩はここで始めて人間というものを見た。し
byte-level BPEにより語彙数はかなり増えてしまいましたが、
以下のようにシーケンス長を減らすことができました。
シーケンス(byte)長
圧縮前:1045519 圧縮後:289129
語彙数
圧縮前:209 圧縮後:2590
分かち書きした単語での結果と、前回のUnicode文字ベースのBPEの結果はこのような感じです。
シーケンス(単語)長
圧縮前:219296 圧縮後:244948
語彙数
圧縮前:15717 圧縮後:7515
Unicode文字ベースに比べて効果的に圧縮できている気がします。
(圧縮度をコントロールするパラメーターを最適化していないので純粋な比較はできませんが。)
byteレベルBPE
前回は、subword-nmt
でBPE語彙圧縮できましたが、
byteレベルはサポートされていないので、Neural Machine Translation of Rare Words with Subword Unitsに乗っているアルゴリズムを改良して、byte-level BPEを実装しました。
コード全体は記事に最後においてあります。
コードの軽い解説
分かち書きにはJanome
を使っています。
from janome.tokenizer import Tokenizer
tokens = t.tokenize(text)
Janome
はMeCab
を100%Pythonで再実装した形態素解析機です。
Janome
のバージョンが同じであれば、どの環境でも同じように動く為オススメです。
byte_pair_encoder
でbyteレベルのBPE語彙圧縮を行なっています。
encoded_vocab, log = byte_pair_encoder(vocab)
log
は解凍に必要な圧縮ログを保管しています。
byte_pair_encoder
の中身はこのような感じ。
def byte_pair_encoder(vocab, num_merges = 30000, min_freq = 20):
vocab = copy.deepcopy(in_vocab)
log = {}
log, vocab = add_end_byte(log, vocab)
for i in range(num_merges):
pairs = get_stats(vocab)
if max(pairs.values()) <= min_freq:
break
best = max(pairs, key=pairs.get)
vocab,new_byte = merge_vocab(best, vocab, i)
log[new_byte] = best
log, vocab = add_next_byte(log, vocab)
return vocab, log
num_merges
は共通byte置き換え回数です。(前回の-s
と同意)
また、共通byteの頻度がmin_freq
以下になった場合にも圧縮を終了します。
基本的には、num_merges
に十分大きな値を指定して、min_freq
で圧縮度を調整するのがいいと思います。
(min_freq
が小さいほど、圧縮度が高い。=>シーケンス長が短くなるが、語彙数が増える。)
add_end_byte(log, vocab)
は語彙のbyte列の末尾のbyteに256を足して新規byteを定義しています。
(これで前回の</w>
と同等のことをしています。)
get_stat(vocab)
で語彙のbyte列のbigramと頻度を取得。
max(pairs, key=pairs.get)
で最頻のbigramを取得。
merge_vocab(best, vocab, i)
で最頻bigramに対して新しいbyteを割り当てています。
この繰り返しでBPEで語彙圧縮していきます。
最後にadd_next_byte(log, vocab)
で語彙のbyte列の末尾を除く数値に1を足して正負を反転しています。
(これで前回の@@
と同等のことをしています。)
byte_pair_decoder
で圧縮された語彙を解凍できます。
def decoder(log, symbols):
stock = []
for s in symbols:
pair = log.get(s)
if pair is None:
stock.append(s)
continue
stock.extend(decoder(log, pair))
return stock
def byte_pair_decoder(log, symbols):
numarray = decoder(log, symbols)
return bytes(numarray).decode("utf-8")
logは、例えば523
が(512, 423)
のbigramから成っていることを記録しています。
decoder
で再帰的にbyteをbyteのbigramに解凍しています。
おわり
BPE自体がかなり簡単なアルゴリズムなのでサクッと実装できました。
日本語などのUnicodeが基調となる言語で、byte-level BPEで圧縮したシーケンスを用いて文生成タスクを行った際に、得られた出力がきちんと解凍出来るかは謎です。。。
コード全体
Python3系で書かれています。wagahai.txt
を適当な文章に差し替えれば、コピペで動くと思います。
分かち書きにjanomeのインストールが必要です。pip install janome
import copy
import re, collections
from janome.tokenizer import Tokenizer
def get_stats(vocab):
pairs = collections.defaultdict(int)
for word, (symbols,freq) in vocab.items():
for i in range(len(symbols)-1):
pairs[symbols[i],symbols[i+1]] += freq
return pairs
def merge_vocab(pair, v_in, opr_num):
v_out = {}
new_byte = 256 * 2 + opr_num
for word,(symbols, freq) in v_in.items():
symbols = list(symbols)
for i in reversed(range(len(symbols)-1)):
if (symbols[i],symbols[i+1]) == pair:
symbols[i] = new_byte
del symbols[i+1]
v_out[word] = (symbols, freq)
return v_out, new_byte
def add_end_byte(log, vocab):
for (symbols,_) in vocab.values():
log[symbols[-1] + 256] = (symbols[-1],)
symbols[-1] += 256
return log, vocab
def add_next_byte(log, vocab):
for (symbols,_) in vocab.values():
for i in range(len(symbols)-1):
log[-1 * (symbols[i]+1)] = (symbols[i],)
symbols[i] = -1 * (symbols[i]+1)
return log, vocab
def byte_pair_encoder(vocab, num_merges = 30000, min_freq = 20):
vocab = copy.deepcopy(in_vocab)
log = {}
log, vocab = add_end_byte(log, vocab)
for i in range(num_merges):
pairs = get_stats(vocab)
if max(pairs.values()) <= min_freq:
break
best = max(pairs, key=pairs.get)
vocab,new_byte = merge_vocab(best, vocab, i)
log[new_byte] = best
log, vocab = add_next_byte(log, vocab)
return vocab, log
def decoder(log, symbols):
stock = []
for s in symbols:
pair = log.get(s)
if pair is None:
stock.append(s)
continue
stock.extend(decoder(log, pair))
return stock
def byte_pair_decoder(log, symbols):
numarray = decoder(log, symbols)
return bytes(numarray).decode("utf-8")
t = Tokenizer(wakati = True)
with open('wagahai.txt' ,'r') as f:
text = f.read().replace("\n"," ")
print("textの一部:{}\n".format(text[:100]))
tokens = t.tokenize(text)
print("tokensの一部:{}\n".format(tokens[:50]))
c = collections.Counter(tokens)
vocab = {k:(list(bytearray(k,'utf-8')),v) for k,v in c.items()}
print("vocabの一部:{}\n".format({k:v for k,v in [*vocab.items()][:10]}))
encoded_vocab, log = byte_pair_encoder(vocab)
print("encoded_vocabの一部:{}\n".format({k:v for k,v in [*encoded_vocab.items()][:10]}))
_,vocab = add_next_byte({}, vocab)
original_bytes = []
for token in tokens:
original_bytes.extend(vocab[token][0])
print("original_bytesの一部:{}\n".format(original_bytes[:30]))
bpe_bytes = []
for token in tokens:
bpe_bytes.extend(encoded_vocab[token][0])
print("bpe_bytesの一部:{}\n".format(bpe_bytes[:30]))
decoded_text = byte_pair_decoder(log, bpe_bytes)
print("decoded_textの一部:{}\n".format(decoded_text[:100]))
print("byte長の比較")
print("bpe適用前:{} bpe適用後:{}\n".format(len(original_bytes),len(bpe_bytes)))
print("語彙数の比較")
print("bpe適用前:{} bpe適用後:{}\n".format(len(set(original_bytes)),len(set(bpe_bytes))))