0
1

More than 1 year has passed since last update.

BERTで文の埋め込みを取得

Last updated at Posted at 2022-12-06

はじめに

学習済みのBERTを使って,文の埋め込みを取得しようと思っています.色々調べてみたら,こういうコードがなさそうので,一応共有しとく.

今回はBERTを使って,mean-pooling したベクトルを計算する.
環境はColab.

Code

必要なライブラリをダウンロード

import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# GPU を利用する,GPUがなかった人はこのステップをスキップして
device = torch.device("cuda:0")
model = model.to(device)

ベクトルを取得

    # GPU を使わない人は.to(device)  を削除

def get_word_embedding(text:str):
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # Batch size 1
    input_ids = input_ids.to(device)

    outputs = model(input_ids)
    last_hidden_states = outputs[1]
    last_hidden_states = last_hidden_states.to(device)  
    # The last hidden-state is the first element of the output tuple
    return last_hidden_states[0].detach().to(device)

感想

私は,大規模テキストデータの文の埋め込みを取得しようと思った.最初はGPUを使わず,CPUを使ってた.そして,RAMが80GBのCPUはメモリオーバーしてしまった.GPUを使う場合,問題なく実行できた.そのため,大規模データの埋め込みを取得したい場合,GPUの利用が必要だ.

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1