はじめに
学習済みのBERTを使って,文の埋め込みを取得しようと思っています.色々調べてみたら,こういうコードがなさそうので,一応共有しとく.
今回はBERTを使って,mean-pooling したベクトルを計算する.
環境はColab.
Code
必要なライブラリをダウンロード
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
# GPU を利用する,GPUがなかった人はこのステップをスキップして
device = torch.device("cuda:0")
model = model.to(device)
ベクトルを取得
# GPU を使わない人は.to(device) を削除
def get_word_embedding(text:str):
input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0) # Batch size 1
input_ids = input_ids.to(device)
outputs = model(input_ids)
last_hidden_states = outputs[1]
last_hidden_states = last_hidden_states.to(device)
# The last hidden-state is the first element of the output tuple
return last_hidden_states[0].detach().to(device)
感想
私は,大規模テキストデータの文の埋め込みを取得しようと思った.最初はGPUを使わず,CPUを使ってた.そして,RAMが80GBのCPUはメモリオーバーしてしまった.GPUを使う場合,問題なく実行できた.そのため,大規模データの埋め込みを取得したい場合,GPUの利用が必要だ.