
Fine-tuning ALBERT on a classification task with huggingface/transformers

Posted at 2023-11-07

Background

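  • I want to use ALBERT as the pretrained model
  • As a first attempt, build it by following the fine-tuning page of the huggingface/transformers docs
    • Use the Yelp review data (yelp_review_full) as the dataset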

Implementation

from datasets import load_dataset
from transformers import AutoTokenizer, AlbertForSequenceClassification, TrainingArguments, Trainer
import numpy as np
import evaluate
import torch

# Load the data: 650,000 training examples and 50,000 test examples
dataset = load_dataset("yelp_review_full")
# Load the classification variant of the model; there are 5 classes (1 to 5 stars)
model = AlbertForSequenceClassification.from_pretrained("albert-base-v1", num_labels=5)
tokenizer = AutoTokenizer.from_pretrained("albert-base-v1")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=model.config.max_position_embeddings)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
# To save time, use only 1,000 training examples and 100 evaluation examples
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(100))

# Use accuracy as the evaluation metric
metric = evaluate.load("accuracy")

# Evaluation: take the argmax over the logits and compare it with the labels
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

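# Save a checkpoint and run evaluation at the end of every epoch
training_args = TrainingArguments(
    output_dir="test_trainer",
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers versions
    save_strategy="epoch"
)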

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
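With save_strategy="epoch", the Trainer writes one checkpoint per epoch under output_dir, named checkpoint-<global step>. The following is a minimal sketch of that step arithmetic, assuming the TrainingArguments defaults of per_device_train_batch_size=8 and num_train_epochs=3, a single GPU, no gradient accumulation, and that the last partial batch of each epoch is kept. `checkpoint_steps` is a hypothetical helper for illustration, not part of transformers.

```python
import math


def checkpoint_steps(num_examples: int, batch_size: int = 8, num_epochs: int = 3) -> list:
    """Global steps at which an epoch-based save strategy would write a checkpoint.

    Assumes one device, no gradient accumulation, and that the final
    partial batch of each epoch still counts as one optimizer step.
    """
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return [steps_per_epoch * epoch for epoch in range(1, num_epochs + 1)]


# 1,000 training examples, default batch size 8, default 3 epochs
print(checkpoint_steps(1000))  # [125, 250, 375]
```

For the 1,000-example subset this gives 125 steps per epoch, so the last checkpoint lands at step 375, which is the checkpoint-375 directory loaded in the inference code.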

Inference code

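device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Final checkpoint written by trainer.train() above
# (newer transformers versions save model.safetensors here instead of pytorch_model.bin)
PATH = "/content/test_trainer/checkpoint-375/pytorch_model.bin"
# map_location lets weights saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load(PATH, map_location=device))
model.to(device)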

# Prepare one sentence that should clearly be rated 1 star and one that should clearly be rated 5 stars
tokenized_list = tokenizer(["I never watch this again. it's just waste of time.", "I love this moview. this is the best one ever."], padding='max_length', truncation=True)

# Predict
model.eval()  # disable dropout for inference
with torch.no_grad():
    outputs = model(
        input_ids=torch.tensor(tokenized_list['input_ids'][0])[None].to(device),
        attention_mask=torch.tensor(tokenized_list['attention_mask'][0])[None].to(device)
    )
    logits = outputs.logits
    print("1st: ", logits)

    outputs = model(
        input_ids=torch.tensor(tokenized_list['input_ids'][1])[None].to(device),
        attention_mask=torch.tensor(tokenized_list['attention_mask'][1])[None].to(device)
    )
    logits = outputs.logits
    print("2nd: ", logits)

Output. Looks good: the first sentence scores highest on label 0 (1 star) and the second on label 4 (5 stars).

1st:  tensor([[ 1.8693,  0.5729, -0.0059, -0.9203, -1.3416]], device='cuda:0')
2nd:  tensor([[-1.9428, -0.8501, -0.8619,  1.4056,  2.5438]], device='cuda:0')
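The printed tensors are raw logits, not probabilities. Below is a dependency-free sketch of converting them into probabilities with softmax and reading off a star rating, assuming that label index i means i + 1 stars (the label order of yelp_review_full). The helper names are hypothetical; inside the PyTorch code, `torch.softmax(outputs.logits, dim=-1)` and `outputs.logits.argmax(dim=-1)` would do the same job.

```python
import math


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max before exp for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predicted_stars(logits):
    """Index of the largest logit, shifted so that label 0 means 1 star."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return best + 1


# Logits copied from the output above
first = [1.8693, 0.5729, -0.0059, -0.9203, -1.3416]
second = [-1.9428, -0.8501, -0.8619, 1.4056, 2.5438]

for name, logits in [("1st", first), ("2nd", second)]:
    probs = softmax(logits)
    # predicted stars: 1 for the first sentence, 5 for the second
    print(name, predicted_stars(logits), [round(p, 3) for p in probs])
```

Softmax is monotonic, so the argmax of the probabilities and of the logits always agree; the probabilities are only needed when you want a confidence value alongside the predicted class.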