LoginSignup
9
2

More than 3 years have passed since last update.

定頻度語を(byteレベルで)ぶっ壊す

Last updated at Posted at 2019-08-01

未知語から深層学習モデルを守る党。

今回は定頻度語をbyte-levelでぶっ壊します。

BPEで語彙圧縮した「低頻度語をぶっ壊す」の続きです。

byte-level BPE語彙圧縮とは?

Language Models are Unsupervised Multitask Learnersで登場した手法です。
(単純な手法なのでもしかしたらその前からあるかもしれません。)

前回紹介した語彙圧縮はUnicode文字ベースでした。
対して今回は、byteベースで圧縮しただけです。
(BPE(Byte Pair Encoding)としてはこっちの方が本来正しい。)

文字列をただbyte文字列に変換しただけでも語彙数を256*2=512種類に圧縮できます。

しかし、それだけだと単語ベースの入力よりもシーケンス長が数倍ロングになってしまうので、
BPE圧縮して語彙数とシーケンス長のバランスを取ることで、深層学習モデルに優しい入力を作ります。

とりあえず実行結果から見て掴む。

今回も前回と同様に吾輩は猫であるをBPEの学習に使います。

textの一部: 吾輩わがはいは猫である。名前はまだ無い。  どこで生れたかとんと見当けんとうがつかぬ。何でも薄暗いじめじめした所でニャ
ーニャー泣いていた事だけは記憶している。吾輩はここで始めて人間というものを見た。

分かち書きを行います。

tokensの一部:['吾輩', 'わがはい', 'は', '猫', 'で', 'ある', '。', '名前', 'は', 'まだ', '無い', '。', ' ', '\u3000', 'どこ', 'で',
 '生れ', 'た', 'か', 'とんと', '見当', 'けん', 'とう', 'が', 'つか', 'ぬ', '。', '何', 'でも', '薄暗い', 'じめじめ', 'し', 'た', '所', 'で', 'ニャーニャー', '泣い', 'て', 'いた事', 'だけ', 'は', '記憶', 'し', 'て', 'いる', '。', '吾輩', 'は', 'ここ', 'で']

語彙を辞書に格納します。
'語彙':(語彙のbyte列, 出現頻度)で格納してあります。

vocabの一部:{'吾輩': ([229, 144, 190, 232, 188, 169], 481), 'わがはい': ([227, 130, 143, 227, 129, 140, 227, 129, 175, 227, 129, 132], 1), 'は': ([227, 129, 175], 6560), '猫': ([231, 140, 171], 247), 'で': ([227, 129, 167], 3773), 'ある': ([227, 129, 130, 227, 130, 139], 1709), '。': ([227, 128, 130], 7487), '名前': ([229, 144, 141, 229, 137, 141], 47), 'まだ': ([227, 129, 190, 227, 129, 160], 123), '無い': ([231, 132, 161, 227, 129, 132], 19)}

BPEで語彙のbyte列を圧縮します。

encoded_vocabの一部:{'吾輩': ([708], 481), 'わがはい': ([-624, -625, -658, 518], 1), 'は': ([522], 6560), '猫': ([847], 247), 'で': ([534], 3773), 'ある': ([564], 1709), '。': ([519], 7487), '名前': ([1723], 47), 'まだ': ([1074], 123), '無い': ([-804, 518], 19)}

[229, 144, 190, 232, 188, 169][708]に圧縮されました。
[-624, -625, -658, 518]のようなマイナス値は次の数値と連結していることを示しています。
(前回の@@と同意です。)

吾輩は猫であるを圧縮後のbyte列に変換したものがこちら。

bpe_bytesの一部:[708, -624, -625, -658, 518, 522, 847, 534, 564, 519, 1723, 522, 1074, -804, 518, 519, 288, 732, 1109, 534, 2043, 533, 546, -558, 1361, -631, 1348, 1906, 1075, 528]

解凍してbyte列を文字に戻したものが以下。
(冒頭のスペースがないのは分かち書きの仕様で落ちています。)

decoded_textの一部:吾輩わがはいは猫である。名前はまだ無い。  どこで生れたかとんと見当けんとうがつかぬ。何でも薄暗いじめじめした所でニャーニャー泣いていた事だけは記憶している。吾輩はここで始めて人間というものを見た。し

byte-level BPEにより語彙数はかなり増えてしまいましたが、
以下のようにシーケンス長を減らすことができました。

シーケンス(byte)長
圧縮前:1045519 圧縮後:289129

語彙数
圧縮前:209 圧縮後:2590

分かち書きした単語での結果と、前回のUnicode文字ベースのBPEの結果はこのような感じです。

シーケンス(単語)長
圧縮前:219296 圧縮後:244948

語彙数
圧縮前:15717 圧縮後:7515

Unicode文字ベースに比べて効果的に圧縮できている気がします。
(圧縮度をコントロールするパラメーターを最適化していないので純粋な比較はできませんが。)

byteレベルBPE

前回は、subword-nmtでBPE語彙圧縮できましたが、
byteレベルはサポートされていないので、Neural Machine Translation of Rare Words with Subword Unitsに乗っているアルゴリズムを改良して、byte-level BPEを実装しました。
コード全体は記事に最後においてあります。

コードの軽い解説

分かち書きにはJanomeを使っています。

from janome.tokenizer import Tokenizer
tokens = t.tokenize(text)

JanomeMeCab100%Pythonで再実装した形態素解析機です。
Janomeのバージョンが同じであれば、どの環境でも同じように動く為オススメです。

byte_pair_encoderでbyteレベルのBPE語彙圧縮を行なっています。

encoded_vocab, log = byte_pair_encoder(vocab)

logは解凍に必要な圧縮ログを保管しています。

byte_pair_encoderの中身はこのような感じ。

def byte_pair_encoder(vocab, num_merges = 30000, min_freq = 20):
    vocab = copy.deepcopy(in_vocab)

    log = {}
    log, vocab = add_end_byte(log, vocab)

    for i in range(num_merges):
        pairs = get_stats(vocab)
        if max(pairs.values()) <= min_freq:
            break
        best = max(pairs, key=pairs.get)
        vocab,new_byte = merge_vocab(best, vocab, i)
        log[new_byte] = best

    log, vocab = add_next_byte(log, vocab)
    return vocab, log

num_mergesは共通byte置き換え回数です。(前回の-sと同意)
また、共通byteの頻度がmin_freq以下になった場合にも圧縮を終了します。
基本的には、num_mergesに十分大きな値を指定して、min_freqで圧縮度を調整するのがいいと思います。
(min_freqが小さいほど、圧縮度が高い。=>シーケンス長が短くなるが、語彙数が増える。)

add_end_byte(log, vocab)は語彙のbyte列の末尾のbyteに256を足して新規byteを定義しています。
(これで前回の</w>と同等のことをしています。)

get_stat(vocab)で語彙のbyte列のbigramと頻度を取得。
max(pairs, key=pairs.get)で最頻のbigramを取得。
merge_vocab(best, vocab, i)で最頻bigramに対して新しいbyteを割り当てています。
この繰り返しでBPEで語彙圧縮していきます。

最後にadd_next_byte(log, vocab)で語彙のbyte列の末尾を除く数値に1を足して正負を反転しています。
(これで前回の@@と同等のことをしています。)

byte_pair_decoderで圧縮された語彙を解凍できます。

def decoder(log, symbols):
    stock = []
    for s in symbols:
        pair = log.get(s)
        if pair is None:
            stock.append(s)
            continue
        stock.extend(decoder(log, pair))
    return stock

def byte_pair_decoder(log, symbols):
    numarray = decoder(log, symbols)
    return bytes(numarray).decode("utf-8")

logは、例えば523(512, 423)のbigramから成っていることを記録しています。
decoderで再帰的にbyteをbyteのbigramに解凍しています。

おわり

BPE自体がかなり簡単なアルゴリズムなのでサクッと実装できました。

日本語などのUnicodeが基調となる言語で、byte-level BPEで圧縮したシーケンスを用いて文生成タスクを行った際に、得られた出力がきちんと解凍出来るかは謎です。。。

コード全体

Python3系で書かれています。wagahai.txtを適当な文章に差し替えれば、コピペで動くと思います。

分かち書きにjanomeのインストールが必要です。pip install janome

byte_level_BPE.py
import copy
import re, collections
from janome.tokenizer import Tokenizer

def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, (symbols,freq) in vocab.items():
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in, opr_num):
    v_out = {}
    new_byte = 256 * 2 + opr_num
    for word,(symbols, freq) in v_in.items():
        symbols = list(symbols)
        for i in reversed(range(len(symbols)-1)):
            if (symbols[i],symbols[i+1]) == pair:
                symbols[i] = new_byte
                del symbols[i+1]
        v_out[word] = (symbols, freq)
    return v_out, new_byte

def add_end_byte(log, vocab):
    for (symbols,_) in vocab.values():
        log[symbols[-1] + 256] = (symbols[-1],)
        symbols[-1] += 256
    return log, vocab

def add_next_byte(log, vocab):
    for (symbols,_) in vocab.values():
        for i in range(len(symbols)-1):
            log[-1 * (symbols[i]+1)] = (symbols[i],)
            symbols[i] = -1 * (symbols[i]+1)
    return log, vocab

def byte_pair_encoder(vocab, num_merges = 30000, min_freq = 20):
    vocab = copy.deepcopy(in_vocab)

    log = {}
    log, vocab = add_end_byte(log, vocab)

    for i in range(num_merges):
        pairs = get_stats(vocab)
        if max(pairs.values()) <= min_freq:
            break
        best = max(pairs, key=pairs.get)
        vocab,new_byte = merge_vocab(best, vocab, i)
        log[new_byte] = best

    log, vocab = add_next_byte(log, vocab)
    return vocab, log

def decoder(log, symbols):
    stock = []
    for s in symbols:
        pair = log.get(s)
        if pair is None:
            stock.append(s)
            continue
        stock.extend(decoder(log, pair))
    return stock

def byte_pair_decoder(log, symbols):
    numarray = decoder(log, symbols)
    return bytes(numarray).decode("utf-8")

t = Tokenizer(wakati = True)

with open('wagahai.txt' ,'r') as f:
    text = f.read().replace("\n"," ")
print("textの一部:{}\n".format(text[:100]))

tokens = t.tokenize(text)
print("tokensの一部:{}\n".format(tokens[:50]))

c = collections.Counter(tokens)
vocab = {k:(list(bytearray(k,'utf-8')),v) for k,v in c.items()}
print("vocabの一部:{}\n".format({k:v for k,v in [*vocab.items()][:10]}))

encoded_vocab, log = byte_pair_encoder(vocab)
print("encoded_vocabの一部:{}\n".format({k:v for k,v in [*encoded_vocab.items()][:10]}))

_,vocab = add_next_byte({}, vocab)

original_bytes = []
for token in tokens:
    original_bytes.extend(vocab[token][0])

print("original_bytesの一部:{}\n".format(original_bytes[:30]))

bpe_bytes = []
for token in tokens:
    bpe_bytes.extend(encoded_vocab[token][0])

print("bpe_bytesの一部:{}\n".format(bpe_bytes[:30]))

decoded_text = byte_pair_decoder(log, bpe_bytes)
print("decoded_textの一部:{}\n".format(decoded_text[:100]))

print("byte長の比較")
print("bpe適用前:{} bpe適用後:{}\n".format(len(original_bytes),len(bpe_bytes)))

print("語彙数の比較")
print("bpe適用前:{} bpe適用後:{}\n".format(len(set(original_bytes)),len(set(bpe_bytes))))
9
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
2