12
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

T5をファインチューニングしてタイトル生成を行ってみた

Posted at

こんにちにゃんです。
水色桜(みずいろさくら)です。
今回はT5をファインチューニングして、タイトル生成を行ってみようと思います。
今回作成したモデルはこちら(hugging face)で配布しています。
ぜひ用いてみてください。
記事中で何か不明な点・間違いなどありましたらコメントかTwitterまでお寄せいただけると嬉しいです。

T5とは

Text-to-Text Transfer Transformerの略称。入力と出力の両方をテキストのフォーマットに統一して、転移学習を行うことがこのモデルの特徴です。

Googleが発表した論文である「Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer」の中で、登場しました。同じくGoogleによって開発されたBERTがクラスラベルや入力の範囲など、人間が言語としてそのまま理解できないデータしか出力できないのに対し、T5は入力と出力が常にテキスト形式になるように自然言語処理タスクを再構成するため、機械翻訳や文書要約に柔軟に対応することが可能であるとのこと。

https://gigazine.net/news/20200225-google-ai-t5/ より引用

image.png

学習

以下のコードを用いて学習を行いました。

train.py
from datasets import load_dataset
dataset=load_dataset("shunk031/CAMERA", name="without-lp-images")
print(dataset["train"][0]["title_org"])

from transformers import T5Tokenizer, T5ForConditionalGeneration
import ast
from tqdm import tqdm
tqdm.pandas()
import torch

tokenizer = T5Tokenizer.from_pretrained('sonoisa/t5-base-japanese')
model = T5ForConditionalGeneration.from_pretrained('sonoisa/t5-base-japanese')
max_seq=64
train_input_ids=[]
train_attention_mask=[]
train_labels=[]
for i in range(len(dataset["train"])):
    train_input_ids.append(torch.tensor(tokenizer(dataset["train"][i]["lp_meta_description"], max_length=max_seq, padding='max_length', truncation=True).input_ids))
    train_attention_mask.append(torch.tensor(tokenizer(dataset["train"][i]["lp_meta_description"], max_length=max_seq, padding='max_length', truncation=True).attention_mask))
    train_labels.append(torch.tensor(tokenizer(dataset["train"][i]["title_org"], max_length=max_seq, padding='max_length', truncation=True).input_ids))

valid_input_ids=[]
valid_attention_mask=[]
valid_labels=[]
for i in range(len(dataset["validation"])):
    valid_input_ids.append(torch.tensor(tokenizer(dataset["validation"][i]["lp_meta_description"], max_length=max_seq, padding='max_length', truncation=True).input_ids))
    valid_attention_mask.append(torch.tensor(tokenizer(dataset["validation"][i]["lp_meta_description"], max_length=max_seq, padding='max_length', truncation=True).attention_mask))
    valid_labels.append(torch.tensor(tokenizer(dataset["validation"][i]["title_org"], max_length=max_seq, padding='max_length', truncation=True).input_ids))

from torch.utils.data import Dataset
import gc
del dataset
gc.collect()
# データセットの作成
class CQADataset(Dataset):
  def __init__(self, input,attention,label, is_test=False):
    self.input = input
    self.attention=attention
    self.label=label
    self.is_test = is_test

  def __getitem__(self, idx):
    data = {'input_ids': torch.tensor(self.input[idx])}
    data['attention_mask']= torch.tensor(self.attention[idx])
    data['labels']= torch.tensor(self.label[idx])
    return data

  def __len__(self):
    return len(self.input)
import numpy as np
dataset_train = CQADataset(train_input_ids, train_attention_mask, train_labels)
dataset_valid = CQADataset(valid_input_ids, valid_attention_mask, valid_labels)

from transformers import Trainer, TrainingArguments

training_config = TrainingArguments(
  output_dir = 'C://Users//desktop//Python//t5_title',
  num_train_epochs = 5, 
  per_device_train_batch_size = 8,
  per_device_eval_batch_size = 8,
  learning_rate = 1e-5,
  weight_decay = 0.1,
  save_steps = 7750,
  eval_steps=7750,
  do_eval = True
  
)

trainer = Trainer(
    model=model,
    args=training_config,
    train_dataset=dataset_train,
    eval_dataset=dataset_valid,
    #compute_metrics = compute_metrics
)

trainer.train()

推論の実行

以下のサンプルコードを用いることで推論を行うことができます。

sample.py
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

tokenizer = T5Tokenizer.from_pretrained('sonoisa/t5-base-japanese')
model = T5ForConditionalGeneration.from_pretrained('Mizuiro-sakura/t5-CAMERA-title-generation')

text = "ニューラルネットワークとは人間の脳の神経回路の構造を数学的に表現する手法です。ニューラルネットワークはPythonによって構成されることが多いです。"
max_seq_length=256
token=tokenizer(text,
        truncation=True,
        max_length=max_seq_length,
        padding="max_length")

output=model.generate(input_ids = torch.tensor(token['input_ids']).unsqueeze(0), attention_mask = torch.tensor(token['attention_mask']).unsqueeze(0))
output_decode=tokenizer.decode(output[0], skip_special_tokens=True)

print(output_decode)

終わりに

私自身の備忘録のためにも学習のコードなどを書き残しておこうと思い、今回の記事を書いてみました。
同様のファインチューニングをしてみたいという方の助けになれば幸いです。
では、ばいにゃん~。

12
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?