Building a Deep Metric Learning Model with the RoBERTa Language Model


#Overview
Let's try BERT-based deep metric learning. Among deep metric learning approaches, this article builds a model called a Siamese Network.

#Setup
python = "3.6.8"
pytorch = "1.6.0"
The dataset is Quora Question Pairs.
The pretrained model is roberta-base from Hugging Face.

train.csv
row ID, ID of question 1, ID of question 2, text of question 1, text of question 2, label

Each row consists of these six fields, and the label marks the two paired questions as similar or dissimilar (1 or 0). The quick peek below confirms the layout.
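
・A minimal sketch (my addition; the header names in the comment are the standard Kaggle ones and are assumed here):

```python
import csv

# Print the header and the first data row of train.csv to confirm the six fields.
with open("./quora-question-pairs/train.csv", "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    print(next(reader))  # assumed header: id, qid1, qid2, question1, question2, is_duplicate
    print(next(reader))  # first question pair together with its 0/1 label
```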

#Training Method
・A Siamese Network is trained with the following loss function (contrastive loss).
・Y is the label
・D is the Euclidean distance between the two embeddings being compared
・m is a constant (the margin)

```math
L=\dfrac{1}{2}\left[ YD^{2}+\left( 1-Y\right) \left[ \mathrm{ReLU}\left( m-D\right) \right] ^{2}\right]
```

・Minimizing this loss pulls pairs with label 1 closer together, while pairs with label 0 are pushed apart until they are at least m away, as the sanity check below illustrates.
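
・A quick numeric sanity check (my addition, not part of the original code):

```python
def contrastive(Y, D, m=1.0):
    # L = 1/2 [ Y*D^2 + (1-Y)*ReLU(m-D)^2 ]
    return 0.5 * (Y * D**2 + (1 - Y) * max(m - D, 0.0)**2)

print(contrastive(1, 0.2))  # similar pair, small distance   -> ~0.02
print(contrastive(0, 0.2))  # dissimilar pair, too close     -> ~0.32
print(contrastive(0, 1.5))  # dissimilar pair, beyond m=1.0  -> 0.0
```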

#Code
・First, preprocess the dataset into a form that is easy to work with.
・We save pickles of a {question ID: question text} dictionary and a list of [ID1, ID2, label] triples.

```python:make_dataset.py
import csv
import pickle

with open("./quora-question-pairs/train.csv", "r", encoding="utf-8") as r:
    rows = csv.reader(r)
    lines = [line for line in rows]

lines = lines[1:]  # drop the header row

# {question ID: question text}
id_text = {}
for line in lines:
    id_text[line[1]] = line[3]
    id_text[line[2]] = line[4]

# [ID1, ID2, label] triples; append takes a single object, so pass a tuple
id1_2_label = []
for line in lines:
    id1_2_label.append((line[1], line[2], line[-1]))

with open("./quora-question-pairs/id_text.pickle", "wb") as wb:
    pickle.dump(id_text, wb)

with open("./quora-question-pairs/id1_2_label.pickle", "wb") as wb:
    pickle.dump(id1_2_label, wb)
```

・Imports

```python:main.py
import numpy as np
import pickle
from tqdm import tqdm

import torch
torch.manual_seed(40)
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import transformers
from transformers import AutoTokenizer, AutoModel, AutoConfig
```

・Load the preprocessed data

```python:main.py
with open("./quora-question-pairs/id_text.pickle", "rb") as rb:
    id_text = pickle.load(rb)

with open("./quora-question-pairs/id1_2_label.pickle", "rb") as rb:
    id1_2_label = pickle.load(rb)
```

・Prepare the pretrained model and tokenizer

```python:main.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = AutoConfig.from_pretrained('roberta-base')
config.output_hidden_states = True  # keep all layer outputs (not used below, but harmless)

tokenizer = AutoTokenizer.from_pretrained('roberta-base')
roberta = AutoModel.from_pretrained('roberta-base', config=config).to(device)
```
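
・Continuing from the setup above, a quick shape check (my addition): with roberta-base, padding='max_length' pads to 512 tokens and the hidden size is 768.

```python
enc = tokenizer("What is deep metric learning?",
                truncation=True, padding='max_length', return_tensors="pt")
print(enc["input_ids"].shape)        # torch.Size([1, 512])

with torch.no_grad():
    out = roberta(input_ids=enc["input_ids"].to(device),
                  attention_mask=enc["attention_mask"].to(device))
print(out.last_hidden_state.shape)   # torch.Size([1, 512, 768])
```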

・Prepare the dataset

```python:main.py
class QuoraDatasets(Dataset):
    def __init__(self):
        self.id_data = id_text
        self.id1_2_label = id1_2_label
        self.datanum = len(self.id1_2_label)

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        id1, id2, label = self.id1_2_label[idx]
        # Pad/truncate both questions to the model's max length
        encode1 = tokenizer(self.id_data[id1], truncation=True, padding='max_length', return_tensors="pt")
        tokens1 = encode1["input_ids"]
        attention1 = encode1["attention_mask"]

        encode2 = tokenizer(self.id_data[id2], truncation=True, padding='max_length', return_tensors="pt")
        tokens2 = encode2["input_ids"]
        attention2 = encode2["attention_mask"]

        label = float(label)
        return tokens1.squeeze(), attention1.squeeze(), \
               tokens2.squeeze(), attention2.squeeze(), \
               torch.tensor([label]).squeeze()
```
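
・A quick look at one sample (my addition) confirms the shapes the training loop will receive:

```python
ds = QuoraDatasets()
t1, a1, t2, a2, y = ds[0]
print(t1.shape, a1.shape, y)  # e.g. torch.Size([512]) torch.Size([512]) tensor(0.)
```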

・Now write the Siamese Network.
・margin is the constant m, set to 1.0 (after optimization, the embeddings are learned so that all distances end up roughly in the 0.0-1.0 range).

```python:main.py
class SiameseNet(nn.Module):
    def __init__(self, lm):
        super(SiameseNet, self).__init__()
        self.margin = 1.0   # the constant m in the loss
        self.lm = lm
        self.eps = 1e-9     # keeps sqrt() differentiable at zero distance
        self.linear = nn.Linear(768, 768)

    def contrastive_loss(self, v1, v2, label):
        # distances holds D^2 (mean squared difference over the 768 dims)
        distances = (v2 - v1).pow(2).mean(1)
        losses = 0.5 * (label.float() * distances
                        + (1 - label).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
        return losses, distances

    def forward(self, s1, a1, s2, a2, label):
        # Use the first-position (<s>) vector as the sentence embedding
        hidden1 = self.lm(input_ids=s1, attention_mask=a1).last_hidden_state[:, 0, :]
        hidden2 = self.lm(input_ids=s2, attention_mask=a2).last_hidden_state[:, 0, :]
        hidden1 = self.linear(hidden1)
        hidden2 = self.linear(hidden2)
        losses, distances = self.contrastive_loss(hidden1, hidden2, label)
        return losses.mean(), distances

model = SiameseNet(roberta).to(device)
```

・Write the training loop

```python:main.py
batch_size = 4
epoch = 20

dataset = QuoraDatasets()
length = len(dataset)
train, val = torch.utils.data.random_split(dataset, [300000, length - 300000])

trainloader = DataLoader(train, batch_size=batch_size, shuffle=True)
valloader = DataLoader(val, batch_size=batch_size, shuffle=False)  # validate on val, not train

optimizer = optim.Adam(model.parameters(), lr=0.00001)

for i in range(epoch):
    running_loss = 0
    step = 0
    model.train()
    for batch in tqdm(trainloader):
        step += 1
        s1, a1, s2, a2, label = batch
        s1 = s1.long().to(device)
        a1 = a1.long().to(device)
        s2 = s2.long().to(device)
        a2 = a2.long().to(device)
        label = label.float().to(device)

        optimizer.zero_grad()
        loss, _ = model(s1, a1, s2, a2, label)
        loss.backward()
        now_loss = loss.item()
        running_loss += now_loss
        optimizer.step()

    running_loss /= step
    print("train_loss:", running_loss)

    val_running_loss = 0
    step = 0
    model.eval()
    for batch in valloader:
        step += 1
        s1, a1, s2, a2, label = batch
        s1 = s1.long().to(device)
        a1 = a1.long().to(device)
        s2 = s2.long().to(device)
        a2 = a2.long().to(device)
        label = label.float().to(device)

        with torch.no_grad():
            loss, distances = model(s1, a1, s2, a2, label)
        now_loss = loss.item()
        val_running_loss += now_loss

    val_running_loss /= step
    print("val_loss:", val_running_loss)
```

#Results
・You should see the validation loss decrease as training proceeds.

#Notes
・Substituting the following expression for D should, I suspect, allow deep metric learning with cosine similarity instead of Euclidean distance.

・v1 and v2 are the embeddings of a data pair, and δ is a very small number.

```math
D=\exp\left\{ -\mathrm{CosSim}\left( v_{1},v_{2}\right) \right\} -\dfrac{1}{\exp\left( 1\right) }+\delta
```

```python
distances = torch.exp(-cos) - 0.367  # -1/exp(1) + δ, set to -0.367
```
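
・A minimal sketch of that substitution (my own code, not tested in the article's experiments): it computes D as above and plugs it straight into the contrastive loss in place of the Euclidean distance.

```python
def cosine_contrastive_loss(v1, v2, label, margin=1.0, delta=1e-9):
    # D = exp(-CosSim(v1, v2)) - 1/e + δ, so D is ~0 for aligned embeddings
    cos = F.cosine_similarity(v1, v2, dim=1)
    distances = torch.exp(-cos) - 0.367 + delta
    losses = 0.5 * (label.float() * distances.pow(2)
                    + (1 - label).float() * F.relu(margin - distances).pow(2))
    return losses, distances
```

・Since cos lies in [-1, 1], D ranges from roughly 0 up to e - 1/e ≈ 2.35, so the margin again acts on a bounded, interpretable scale.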

#Summary
・We implemented language-model-based deep metric learning.
・Naively thresholding the distance between paired embeddings at 0.5 gave a score about 0.6% lower than an ordinary training approach (e.g., maximum-likelihood estimation of the labels).

#Closing
・If anything here is wrong, I would appreciate a gentle comment pointing it out. (Apologies in advance if I fail to notice it.)
