こんにちにゃんです。
水色桜(みずいろさくら)です。
今回はT5をファインチューニングして、タイトル生成を行ってみようと思います。
今回作成したモデルはこちら(hugging face)で配布しています。
ぜひ用いてみてください。
記事中で何か不明な点・間違いなどありましたらコメントかTwitterまでお寄せいただけると嬉しいです。
T5とは
Text-to-Text Transfer Transformerの略称。入力と出力の両方をテキストのフォーマットに統一して、転移学習を行うことがこのモデルの特徴です。
Googleが発表した論文である「Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer」の中で、登場しました。同じくGoogleによって開発されたBERTがクラスラベルや入力の範囲など、人間が言語としてそのまま理解できないデータしか出力できないのに対し、T5は入力と出力が常にテキスト形式になるように自然言語処理タスクを再構成するため、機械翻訳や文書要約に柔軟に対応することが可能であるとのこと。
https://gigazine.net/news/20200225-google-ai-t5/ より引用
学習
以下のコードを用いて学習を行いました。
from datasets import load_dataset
dataset=load_dataset("shunk031/CAMERA", name="without-lp-images")
print(dataset["train"][0]["title_org"])
from transformers import T5Tokenizer, T5ForConditionalGeneration
import ast
from tqdm import tqdm
tqdm.pandas()
import torch
tokenizer = T5Tokenizer.from_pretrained('sonoisa/t5-base-japanese')
model = T5ForConditionalGeneration.from_pretrained('sonoisa/t5-base-japanese')
max_seq=64
train_input_ids=[]
train_attention_mask=[]
train_labels=[]
for i in range(len(dataset["train"])):
train_input_ids.append(torch.tensor(tokenizer(dataset["train"][i]["lp_meta_description"], max_length=max_seq, padding='max_length', truncation=True).input_ids))
train_attention_mask.append(torch.tensor(tokenizer(dataset["train"][i]["lp_meta_description"], max_length=max_seq, padding='max_length', truncation=True).attention_mask))
train_labels.append(torch.tensor(tokenizer(dataset["train"][i]["title_org"], max_length=max_seq, padding='max_length', truncation=True).input_ids))
valid_input_ids=[]
valid_attention_mask=[]
valid_labels=[]
for i in range(len(dataset["validation"])):
valid_input_ids.append(torch.tensor(tokenizer(dataset["validation"][i]["lp_meta_description"], max_length=max_seq, padding='max_length', truncation=True).input_ids))
valid_attention_mask.append(torch.tensor(tokenizer(dataset["validation"][i]["lp_meta_description"], max_length=max_seq, padding='max_length', truncation=True).attention_mask))
valid_labels.append(torch.tensor(tokenizer(dataset["validation"][i]["title_org"], max_length=max_seq, padding='max_length', truncation=True).input_ids))
from torch.utils.data import Dataset
import gc
del dataset
gc.collect()
# データセットの作成
class CQADataset(Dataset):
def __init__(self, input,attention,label, is_test=False):
self.input = input
self.attention=attention
self.label=label
self.is_test = is_test
def __getitem__(self, idx):
data = {'input_ids': torch.tensor(self.input[idx])}
data['attention_mask']= torch.tensor(self.attention[idx])
data['labels']= torch.tensor(self.label[idx])
return data
def __len__(self):
return len(self.input)
import numpy as np
dataset_train = CQADataset(train_input_ids, train_attention_mask, train_labels)
dataset_valid = CQADataset(valid_input_ids, valid_attention_mask, valid_labels)
from transformers import Trainer, TrainingArguments
training_config = TrainingArguments(
output_dir = 'C://Users//desktop//Python//t5_title',
num_train_epochs = 5,
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
learning_rate = 1e-5,
weight_decay = 0.1,
save_steps = 7750,
eval_steps=7750,
do_eval = True
)
trainer = Trainer(
model=model,
args=training_config,
train_dataset=dataset_train,
eval_dataset=dataset_valid,
#compute_metrics = compute_metrics
)
trainer.train()
推論の実行
以下のサンプルコードを用いることで推論を行うことができます。
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
tokenizer = T5Tokenizer.from_pretrained('sonoisa/t5-base-japanese')
model = T5ForConditionalGeneration.from_pretrained('Mizuiro-sakura/t5-CAMERA-title-generation')
text = "ニューラルネットワークとは人間の脳の神経回路の構造を数学的に表現する手法です。ニューラルネットワークはPythonによって構成されることが多いです。"
max_seq_length=256
token=tokenizer(text,
truncation=True,
max_length=max_seq_length,
padding="max_length")
output=model.generate(input_ids = torch.tensor(token['input_ids']).unsqueeze(0), attention_mask = torch.tensor(token['attention_mask']).unsqueeze(0))
output_decode=tokenizer.decode(output[0], skip_special_tokens=True)
print(output_decode)
終わりに
私自身の備忘録のためにも学習のコードなどを書き残しておこうと思い、今回の記事を書いてみました。
同様のファインチューニングをしてみたいという方の助けになれば幸いです。
では、ばいにゃん~。