4
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

BERTにラベルの埋め込みを追加する方法

Last updated at Posted at 2023-03-02

はじめに

BERTを使う時にテキストデータと一緒にラベルデータを使いたいときはありませんか?
例えば以下のようなテキストデータから書いた人の年齢を当てるタスクを解くとします.
image.png
この時にモデルに対してそれぞれの文節がどのような特徴を持つかを明示する,つまり以下のような文節の種類を表すラベルデータを一緒にモデルに入力出来れば精度が上がりそうな気がしますね.(実際に上がるかはかなりタスク依存)
image.png
huggingfaceのBERTではこれが簡単に実装できるようになっているため,今回の記事ではその方法を紹介していきたいと思います.

例以外にも要約タスクの共参照関係や係り受けの関係性を明示するなどいろいろなタスクに用いられます.

手法

インポート

import torch
import torch.nn as nn
from transformers import BertModel
from transformers import BertJapaneseTokenizer

データの用意

簡単のためにバッチなどは一旦置いておいて,テキストデータとラベルデータを簡単に用意します.
ここでラベルは以下のように定義します.

ラベルID ラベル 説明
0 挨拶 こんにちは,おはようなどの挨拶
1 事実説明 現在位置を言うなど事実をただ述べているだけ
2 自己開示 自身の趣味や職業などの開示
sentences = [
    ("こんにちは[SEP]私は小学生の頃にマリカDSをプレイしていました。[SEP]今は新卒2年目の会社員です。"),
    ("初めまして![SEP]今日はいい天気ですね。[SEP]天気が良い日は散歩がしたくなります。"),
]

# 各文節に対応するラベルデータ
labels =[
    [0, 2, 2],
    [0, 1, 2],
]

# 教師データ(年齢)
ages = [25, 43]

埋め込みの獲得

テキストデータ

huggingface transformersでBERTを使っている方にはおなじみのBertJapaneseTokenizerを使ってトークナイズします.

tokenizer = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-whole-word-masking")
tokenized_sentences = tokenizer.encode_plus(sentences, return_tensors="pt")

通常はここでBERTに入力してしまうのですが,今回はラベル埋め込みを加算する関係上,自前で埋め込みを作ります.
ここではわかりやすさのために事前学習済みモデルcl-tohoku/bert-base-japanese-whole-word-maskingの語彙サイズ(32000)と埋め込みサイズ(768)を設定していますが,BertConfig等を使ってconfig.vocab_sizeのように指定してあげるとよりスマートです.
bert_embedingにトークナイズしたテキストデータを与えてテキストデータの埋め込みを作ります.

bert_embeding = nn.Embeding(32000, 768)
text_embeds = bert_embeding(tokenized_sentences["input_ids"])

pytorchのEmbedingsが分からない方はこちらの分かりやすい記事がおすすめです.

ラベル

続いてラベルの埋め込みです.こちらは元からラベルIDの列として定義しているので先程と同じ手順で埋め込みを計算します.
新しくnn.Embedingsのインスタンスを作成してあげないとテキストデータの埋め込みと競合してしまうので,新しく定義し直します.
ただしこちらの語彙サイズは 3とします.これは用いるラベルの種類の数です.埋め込みサイズは先程と同様に768とします.

label_embeding = nn.Embeding(3, 768)
label_embeds = label_embeding(tokenized_sentences)

埋め込みの加算

単純ですがここが重要な部分です.
今回やりたい「BERTに対してテキストデータと対応するラベルデータを一緒に入力する」というのはテキストデータの埋め込みとラベルの埋め込みを単純に加算することで実現できます.

inputs_embeds = text_embeds + label_embeds

BERTへの入力

ここまでで入力データを埋め込みに変換できたので,最後にモデルに入力して予測を行いたいと思います.
BertForSequenceClassificationnum_labels=1とすると回帰予測が出来るので今回はこれを使います.
そしてinputs_embeds引数に先程作った埋め込みのinputs_embedsを渡します.
これでテキストデータとラベルデータを考慮した予測を行うことが出来ます.
ただし,今回使っている東北大BERTはそんな事前学習は一切されていないので,性能が上がるどころか下がってしまうと思います.(未検証)
そのため,テキストデータとラベルデータを考慮したモデルを作りたい場合はファインチューニングをすることを推奨します.

model = BertForSequenceClassification.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking', num_labels=1)
outputs = model(inputs_embeds=inputs_embeds, labels=ages)

BertForSequenceClassificationの使い方については,タスクは違いますが以下の記事がおすすめです.

まとめ

ここまでありがとうございました.
テキストデータとラベルデータの両方を考慮させたい!となった時に「BERTのアーキテクチャから自作しないと...」と思ったら便利な機能があったので記事にしてみました.
いいねを貰えると励みになるのでこの記事がいいなと思ったらぜひいいねをお願いします.
ミスなどがありましたらコメントください.

4
6
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?