2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

huggingface, BertModelの forwardに入るargumentsをよくわかっていなかったので調べてみた

Last updated at Posted at 2022-02-11

huggingface, BertModelの forwardに入るargumentsをよくわかっていなかったので調べてみた

class BertModel(BertPreTrainedModel):

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        ):

input_idsは単語のtokenだとして、残り奴らがよくわからなかったので調べながらメモする。

attention_maskとは

batch処理する際には必ず必要となるものでこのattention_maskが1となるindexのtokenにモデルは注目する。
sequence_a, sequence_bの単語数がそれぞれ異なる時、例えば、

sequence_a = "This is a short sequence."
# tokeninzed
[[101, 1188, 1110, 170, 1603, 4954, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
sequence_b = "This is a rather long sequence. It is at least longer than the sequence A."
# tokeninzed
[101, 1188, 1110, 170, 1897, 1263, 4954, 119, 1135, 1110, 1120, 1655, 2039, 1190, 1103, 4954, 138, 119, 102]]

この時、単語数と一致しないのは、前後にtokenがつくからである。
sequence_a
'[CLS] HuggingFace is based in NYC [SEP]'

tenosrにa, bを入れるためには、aはpaddingを入れる必要があるし、bはある長さ以上はtruncationしないといけなくなる。この0をpaddingに入れた部分にはbinary tensorとしてpaddingを入れた番号に対して処理しないようにindicaterを入れる、これがattention_maskである。

token_type_idsとは何か

BertModelに入力する input_ids が1つ以上のsequenceの入力となる場合に区別するためのidである。

例えば、質問回答のようなタスクには二つの異なる文章のsequencesを一つの input_idsとして入力する必要がある、
sequence_a = "HuggingFace is based in NYC"
sequence_b = "Where is HuggingFace based?"
のようなもの

この二つを区別するものとして token_type_ids は存在して、sequence_aは0、 sequence_bは1として格納された torch.Tensorである

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
sequence_a = "HuggingFace is based in NYC"
sequence_b = "Where is HuggingFace based?"

encoded_dict = tokenizer(sequence_a, sequence_b)


>>> encoded_dict['token_type_ids']
>>> [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]

のような結果になる。

この時、tokenizerは[CLS]、[SEP]に当たるtokenを2つのsequenceの間に挿入する。

>>>> decoded = tokenizer.decode(encoded_dict["input_ids"])
[CLS] HuggingFace is based in NYC [SEP] Where is HuggingFace based? [SEP]

position_idsとは

Optionalである。
RNNなどの場合にはRecurrentに処理することでtokenの時系列を考慮していたが、各tokenの位置を意識していない。tokenのリストの中で各tokenの位置を特定するために、position ID(position_ids)がモデルによって使用される。
下のようにself.position_embedding_type == "absolute"ではnn.Embeddingが用いられている。
config.max_position_embeddings => tokenizerでのmax_lenのことそれぞれ独立のconfig.hidden_sizeのEmbedded vectorとなる

コードからposition_idsに関連した処理の部分

transformers/models/bert/modeling_bert.py
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""
...
    def __init__(self, config):
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
...
    def forward(..., position_ids, ...):
       if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

head_mask

(torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional)
self-attention modulesで選択されたheadsをnullにできる。1でnot mask, 0でmask
ざっくり調べてもあまり使われてなさそう。

inputs_embeds とは

input_idsinputs_embedsのどちらかをmodelに入れる必要があり、inputs_embedsを入れる時にはinput_idsがモデルの内部の embedding lookup matrix ではなく自分で定義したassociated vectorに変換したものを用いたい時である。
より具体的には BertEmbeddings内のforward処理のコア部分を抜き出すと
nn.Embeddingを用いるか、そのほかの処理を用いるかということになる。

## __init__()内部
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
## forward()内部
token_type_embeddings = self.token_type_embeddings(token_type_ids)
if inputs_embeds is None:
    inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
    position_embeddings = self.position_embeddings(position_ids)
    embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings

参考

huggingface Glossary
transformerのversion '4.17.0.dev0'

以降のargumentsはoptionalなので調べるのはまた今度にする。
もし解説記事があれば教えていただけるとありがたいです。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?