LoginSignup
0
2

More than 1 year has passed since last update.

LUKEをJGLUEのJSTSとJNLI向けにファインチューニングする方法

Last updated at Posted at 2023-02-12

こんにちにゃんです。
水色桜(みずいろさくら)です。
今回はJGLUEの中でもJSTSとJNLIをLUKEで解くためのコードについて、書いていきたいと思います。(私自身の備忘録も兼ねて)
今回作成したモデルはこちら(hugging face)で公開しています。
ぜひ使ってみてください。
もし記事中で不明な点や間違いなどありましたら、コメントもしくはTwitter(@Mizuiro__sakura)までお寄せいただけると嬉しいです。

環境

pandas 1.4.4
numpy 1.23.4
torch 1.12.1
transformer 4.24.0
Python 3.9.13
sentencepiece 0.1.97

LUKE

image.png

2020年4月当時、5つのタスクで世界最高精度を達成した新しい言語モデル。
日本語バージョンのLUKE-baseモデルは執筆現在(2022年12月)も4つのタスク(MARC-ja, JSTS, JNLI, JCommonsenseQA)で最高スコアを有しています。LUKEはRoBERTaを元として構成され、entity-aware self-attentionという独自のメカニズムを用いています。LUKEに関して詳しくは下記記事をご覧ください。

image.png

JGLUE

yahoo JAPAN株式会社さんと早稲田大学河原研究室の共同研究で構築された、日本語言語理解ベンチマーク。言語モデルの包括的な評価のために作成されました。各データはクラウドソーシングにより収集されています。評価データの種類はMARC-ja(二値分類タスク)、JSTS(文章の類似度計算)、JNLI(文章の関係性判別)、JSQuAD(抜き出し型のQAタスク)、JCommonsenseQA(常識を問う5択の選択問題)があります。

ファインチューニングに用いたコード(JSTS)

JSTS用にファインチューニングするためのコードを以下に示します。

jsts_finetuned.py
jsts_finetuned.py
import numpy as np
MODEL_NAME = "studio-ousia/luke-japanese-base"

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
input_ids = []
attention_masks = []

import json
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
dataset_t = [json.loads(line)
        for line in open("jsts_train.json", 'r', encoding='utf-8')]
dataset_v = [json.loads(line)
        for line in open("jsts_valid.json", 'r', encoding='utf-8')]
print(len(dataset_t))
print(len(dataset_v))

train_df = pd.DataFrame.from_dict(dataset_t)
val_df = pd.DataFrame.from_dict(dataset_v)
all_df = pd.concat([train_df, val_df], ignore_index=True, axis=0)


def preprocess_function(examples):
    max_seq_length = 128
    first_sentences = [[sen] for sen in examples['sentence1']]
    second_sentences =[[sen] for sen in examples['sentence2']]
    # トークナイズ(エンコーディング・形態素解析)する
    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])
    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        truncation=True,
        max_length=max_seq_length,
        padding="max_length"
    )
    return tokenized_examples

data = preprocess_function(all_df)

print(len(data['input_ids']))


import torch
tensor_input_ids = torch.tensor(data["input_ids"])
tensor_attention_masks = torch.tensor(data["attention_mask"])
# 正解データの取得
labels = all_df["label"].to_list()
tensor_labels = torch.tensor(labels)

import numpy as np
from torch.utils.data import Dataset, Subset

# データセットの作成
class CQADataset(Dataset):
  def __init__(self, input,attention,label, is_test=False):
    self.input = input
    self.attention=attention
    self.label=label
    self.is_test = is_test

  def __getitem__(self, idx):
    data = {'input_ids': torch.tensor(self.input[idx])}
    data['attention_mask']= torch.tensor(self.attention[idx])
    data['label']= [1-torch.tensor(self.label[idx])/5, torch.tensor(self.label[idx])/5]
    return data

  def __len__(self):
    return len(self.input)

dataset = CQADataset(tensor_input_ids,tensor_attention_masks,tensor_labels)
indices = np.arange(len(dataset))

# 訓練データ、評価データに分割
train_dataset = Subset(dataset, indices[:int(len(train_df)*1.0)])
val_dataset = Subset(dataset, indices[-len(val_df):])

from transformers import Trainer, TrainingArguments
model=AutoModelForSequenceClassification.from_pretrained("studio-ousia/luke-japanese-base")

# トレーニングの設定を記述する
training_config = TrainingArguments(
  output_dir = 'C://Users//desktop//Python//jsts',
  num_train_epochs = 3, 
  per_device_train_batch_size = 64,
  per_device_eval_batch_size = 64,
  weight_decay = 0.1,
  do_eval = True,
  save_steps = 570,
)

# トレーニング
trainer = Trainer(
    model = model,                         
    args = training_config,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = val_dataset
)

trainer.train()
print('finished')

ファインチューニングに用いたコード(JNLI)

JNLI用にファインチューニングするためのコードを以下に示します。

jnli_finetuned.py
jnli_finetuned.py
import numpy as np
MODEL_NAME = "studio-ousia/luke-japanese-base"

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
input_ids = []
attention_masks = []

import json
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
dataset_t = [json.loads(line)
        for line in open("jnli_train.json", 'r', encoding='utf-8')]
dataset_v = [json.loads(line)
        for line in open("jnli_valid.json", 'r', encoding='utf-8')]
print(len(dataset_t))
print(len(dataset_v))

train_df = pd.DataFrame.from_dict(dataset_t)
val_df = pd.DataFrame.from_dict(dataset_v)
all_df = pd.concat([train_df, val_df], ignore_index=True, axis=0)


def preprocess_function(examples):
    max_seq_length = 128
    first_sentences = [[sen] for sen in examples['sentence1']]
    second_sentences =[[sen] for sen in examples['sentence2']]
    # トークナイズ(エンコーディング・形態素解析)する
    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])
    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        truncation=True,
        max_length=max_seq_length,
        padding="max_length"
    )
    return tokenized_examples

data = preprocess_function(all_df)

print(len(data['input_ids']))


import torch
tensor_input_ids = torch.tensor(data["input_ids"])
tensor_attention_masks = torch.tensor(data["attention_mask"])
# 正解データの取得
labels = all_df["label"].to_list()
labels_label=[]
for i in range(len(labels)):
    if labels[i]=='contradiction':
        labels_label.append(0)
    elif labels[i]=='neutral':
        labels_label.append(1)
    elif labels[i]=='entailment':
        labels_label.append(2)
labels_label=torch.tensor(labels_label)

import numpy as np
from torch.utils.data import Dataset, Subset

# データセットの作成
class CQADataset(Dataset):
  def __init__(self, input,attention,label, is_test=False):
    self.input = input
    self.attention=attention
    self.label=label
    self.is_test = is_test

  def __getitem__(self, idx):
    data = {'input_ids': torch.tensor(self.input[idx])}
    data['attention_mask']= torch.tensor(self.attention[idx])
    data['labels']= [self.label[idx]]
    return data

  def __len__(self):
    return len(self.input)

dataset = CQADataset(tensor_input_ids, tensor_attention_masks, labels_label)
indices = np.arange(len(dataset))

# 訓練データ、評価データに分割
train_dataset = Subset(dataset, indices[:int(len(train_df)*1.0)])
val_dataset = Subset(dataset, indices[-len(val_df):])

from transformers import Trainer, TrainingArguments
model=AutoModelForSequenceClassification.from_pretrained("studio-ousia/luke-japanese-base", num_labels=3)

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    #2値分類ならaverage='binary'とする
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

# トレーニングの設定を記述する
training_config = TrainingArguments(
  output_dir = 'C://Users//desktop//Python//jnli',
  num_train_epochs = 5, 
  per_device_train_batch_size = 256,
  per_device_eval_batch_size = 256,
  weight_decay = 0.1,
  save_strategy = "epoch", #いつ保存するか?
  evaluation_strategy = "epoch", #いつ評価するか?
)

# トレーニング
trainer = Trainer(
    model = model,                         
    args = training_config,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = val_dataset,
    compute_metrics = compute_metrics, #compute_metrics
)

trainer.train()
print('finished')

躓いた点

ValueError: not enough values to unpack (expected 2, got 1)

解決方法:.unsqueeze(0)をエラーが出ている部分に追記する

終わりに

今回はJSTSとJNLI用にファンチューニングするためのコードについて書いてきました。
JGLUEが整備されたのが最近ということもあり、ほかに解説している記事がないため、記事にしてみました。
ぜひ参考にしてみてください。
では、ばいにゃん~。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2