HuggingFaceのtransformersモジュール、めちゃくちゃ便利ですよね。HuggingFaceに公開されているモデルなどをサクっと使えてしまいます。でも、何も知識がなくても使えてしまうものなので、内部で何が起きているかが分からなくなってしまいます。
本稿では、テキスト生成タスクを自分で実装し、数式とも照らし合わせながら、generate
メソッド内ではどういう風に処理が行われているのかを理解したいと思います。なお、エンコーダやデコーダの中身(アテンション機構とか)については詳しくは触れません。アテンション機構とかについては分かりやすい記事がいっぱいあるので、そちらを参考にして頂けたらと思います。
準備 : 記号定義
はじめに、本稿で使用する記号を定義します。今回は、エンコーダもデコーダも兼ね備えたTransformerモデルを対象としますので、BERTとかは対象外です。あと、筆者は制御工学を専攻している学生なので、書き方が変だったりするかもしれませんが、ご了承ください。(制御の人って何でも数式にしたがるんです。お付き合いくださいませ。)
ベクトルや行列の形状(Tensor.shape)を$(n,m,l,...)$と表すこととします。なお、バッチ数はずっと1なので、バッチの次元については無視します。また、ベクトルや行列の要素を$x[i,j,k,...]$と表すこととします。
まず、エンコーダを$f_E$、デコーダを$f_D$と置きます。
エンコーダ$f_E$は、入力文$\mathtt{sentence}_i$を受け取ってエンコーダの状態$s_E:(n_i,d)$を出力します。なお、$n_i$は入力トークンの個数、$d$はTransformerモデル内で扱われる特徴量の次元数です。
s_E = f_E(\mathtt{sentence}_i)
デコーダ$f_D$は、出力トークン列$t_o:(n_o,)$とエンコーダの状態$s_e$を受け取ってデコーダの状態$s_D:(n_o,d)$を出力します。
s_D = f_D(t_o,s_E)
テキスト生成タスクにおいては、さらに以下のような過程を経て出力トークン列を伸長していきます。
❶ ボキャブラリの総数を$n_v$として、変換$H:(d,)\to(n_v,)$を使って次のトークンのロジット$l:(n_v,)$を求めます。
l=H(s_D[n_o-1])
❷ ロジットを基に次のトークンを求めます。
t_*=\arg\max_{t}~l[t]
❸ 次のトークンを出力トークン列に繋げます。なので、出力トークン列の次元数は$(n_0,)$から$(n_0+1,)$になります。
t_o\leftarrow[t_o,t_*]
ちなみに、テキスト生成タスクにおいては、出力トークン列の初期値は$t=[\texttt{PAD}]$です。PADは生成する最初の種みたいなものです。こいつがどんどん長くなっていって、テキストが生成されていくわけです。
モデルのダウンロード
モデルを準備します。今回使うのはコレ↓日本語の文章をうまーく要約してくれるTransformerです。Git cloneで持ってきます。
準備 : generateメソッドを使ってみる
HuggingFaceのTransformersで用意されているgenerateメソッドを使うと・・・
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from torch import Tensor
tokenizer = T5Tokenizer.from_pretrained(<path-to-model>)
model = T5ForConditionalGeneration.from_pretrained(<path-to-model>)
inputs = tokenizer("今日入った温泉はこれまでのどこの温泉よりも格別だった。とにかく気持ちいい。気持ち良すぎて、まだ出たくない。", return_tensors="pt")
outputs = model.generate(**inputs)
tokenizer.batch_decode(outputs)
['<pad> もう出たくない。気持ち良すぎてまだ出たくない。</s>']
簡単すぎる!今回はこのテキスト生成を自分で作ってみます。
生成部分を実装する
❶ エンコーダ状態を得る
今回のモデルはMT5ForConditionalGeneration
というものです。中にencoder
とdecoder
という2つがいるので、それらを引っ張り出してきて使います。エンコーダ状態$s_E$を得る関数$f_E$に相当するものを作ります。数式でいうと$s_E = f_E(\mathtt{sentence}_i)$です。
def get_encoder_state(sentence:str)->Tensor:
inputs = tokenizer(sentence, return_tensors="pt")
oe = model.encoder.forward(**inputs)
se = oe["last_hidden_state"]
return se
se = get_encoder_state("今日入った温泉はこれまでのどこの温泉よりも格別だった。とにかく気持ちいい。気持ち良すぎて、まだ出たくない。")
se.shape
torch.Size([1, 27, 768])
トークン数$n_i=27$、特徴量の次元数$d=768$です。
❷ 初期の出力トークン列を作る
$t_o$の初期状態$t_o=[\mathtt{PAD}]$を作っておきます。
to = torch.tensor([[tokenizer.pad_token_id]])
❸ デコーダ状態を得る
$f_D$に当たるものを作り、デコーダ状態$s_D$を得ます。数式でいうと$s_D = f_D(t_o,s_E)$です。
od = model.decoder.forward(input_ids=to,encoder_hidden_states=se)
sd = od.last_hidden_state
sd.shape
torch.Size([1, 1, 768])
まだトークン数は$n_o=1$です。
❹ ロジットを得る
先ほど変換$H$と紹介したものを行い、次のトークンに関するロジット$l$を得ます。数式でいうと$l=H(s_D[n_o-1])$です。ちなみに、デコーダ状態の一番最後以外の要素$s_D[0]$から$s_D[n_o-2]$は使いません。
l = model.lm_head(sd[0, -1, :])
l.shape
torch.Size([32128])
ボキャブラリの総数は$n_v=321238$です。
❺ 次のトークンを確定させる
$t_*=\arg\max_{t}~l[t]$をやりまーす(適当)
t_star = l.argmax()
❻ 出力トークン列を伸長する
$t_o\leftarrow[t_o,t_*]$をします。
to = torch.cat((to, t_star[None, None]), dim=-1)
❸〜❻を繰り返す
生成タスクでは、以上の処理を繰り返します。いつまで繰り返すかというと、次のトークンが終了を意味する"</s>"
になるまでです。
MAX_ITER_COUNT = 30
for iter_count in range(MAX_ITER_COUNT):
od = model.decoder.forward(input_ids=to,encoder_hidden_states=se)
sd = od.last_hidden_state
l = model.lm_head(sd[0, -1, :])
t_star = l.argmax()
print(tokenizer.convert_ids_to_tokens([t_star])[0],end="")
to = torch.cat((to, t_star[None, None]), dim=-1)
sentence = tokenizer.batch_decode(to)[0]
if t_star == tokenizer.eos_token_id:
print()
print("Done.")
break
print(sentence)
実験 : 生成してみる
以上のコードを動かすと、以下のような文が得られました。
もう出たくない。気持ち良すぎてまだ出たくない。
準備として試したmodel.generate
メソッドと同じ結果となりました。今回のモデルだと(多分)生成結果は同じになるのですが、中には「ビームサーチ法」などといって、次のトークンを求めるのに複雑な処理をするアルゴリズムもあり、そういうやつとは同じ結果が得られないと考えられます。
応用 : 文を混ぜる!?!?
2つの文のエンコーダ状態$s_E$を線形結合することで、文同士を混ぜることもできます。
A : "今日入った温泉はこれまでのどこの温泉よりも格別だった。とにかく気持ちいい。気持ち良すぎて、まだ出たくない。"
B : "寒くて血が凍りそうだ...あれ、「血が凍る」の意味、ちょっと違うかな?"
これらの2つの文を混ぜると...
"寒いのに、まだ風呂に浸かってる?"
面白く無いですか??あとは
A : "もしも引っ越しをせず、この小学校に通って、この中学校に通って、いつもこの道を通って駅に行って、天気のいい日にはこの公園に散歩に来て、お腹が空いたらあのラーメン屋に寄って・・・この街で生活していたら、どんな趣味を持っただろうか。どんな性格になっていただろうか。恋人はできただろうか。"
B : "遠くから、列車の音が聞こえてくる。最終列車が来た。そして、誰を乗せも降りもしないまま明日へと去っていった。",
↓
"「君はどこへ行く?」 「君はどこへ行く?」"
ふ・・・深すぎる!!!すげえええええ
また後ほど紹介します。