0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【Pytorch】Seq2Seqのためのオペレータ?を作る

Last updated at Posted at 2020-07-05

# 動機

深層学習モデルを設計したあと、訓練と検証のためのコードを毎回手続き型で記述・・・

みたいなことを卒業研究の際に、試したいモデルが出るたびにやってました。
コピペするとはいえ、毎回毎回記述するのは本当に面倒なので、引数を変えるだけであとは訓練と検証を一括でやってくれるオペレータを作りたいと思います。

(オペレータという言い方が正しいのかはわかりません)

# コード

Operator.py
from utils import to_np, trim_seqs
from torch.autograd import Variable
import torch
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
import sentencepiece as spm

# Shared SentencePiece model, loaded once at import time. Operator uses it to
# look up the "<s>" id and to map ids back to pieces when writing the logs.
# NOTE: the model path is relative to the current working directory.
sp = spm.SentencePieceProcessor()
sp.load("./data/index.model")

class Operator:
    """Run training and evaluation loops for a seq2seq model.

    The model is used through three entry points: ``model(src, decoder_input)``,
    ``model.encode(src)`` and ``model.decode(decoder_input, hidden)``. Each
    returns an ``(output, hidden)`` pair.
    """

    def __init__(self, model, optimizer, criterion, device="cuda"):
        """Store the model, optimizer and loss function.

        Args:
            model: seq2seq module with the interface described on the class.
            optimizer: object exposing ``zero_grad()`` and ``step()``.
            criterion: loss called as ``criterion(logits_2d, labels_1d)``.
            device: where batches are moved before use. The default
                ``"cuda"`` matches the previously hard-coded ``.cuda()`` calls.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

    def run_epoch(self, epoch, train_loader, eval_loader):
        """Train for one epoch, then evaluate and print a one-line summary.

        Args:
            epoch: zero-based epoch index (printed as ``epoch + 1``).
            train_loader: iterable of batches where ``data[0]`` holds source
                ids and ``data[1]`` holds target ids. Targets are presumably
                ``<s> w1 ... wn`` (padded); confirm against the data pipeline.
            eval_loader: batches in the same format, passed to ``evaluate``.

        Returns:
            ``(train_loss, eval_loss, bleu_score)``. ``train_loss`` is NaN
            when *train_loader* yields no batches.
        """
        self.model.train()
        losses = []
        out = None
        for data in train_loader:
            input_data = data[0].to(self.device)
            target = data[1].to(self.device)
            # Teacher forcing: the decoder reads <s> w1 ... w_{n-1} and must
            # predict w1 ... w_n. This matches evaluate(), which scores
            # against target[:, 1:].
            decoder_input = target[:, :-1]
            labels = target[:, 1:]
            out, _ = self.model(input_data, decoder_input)
            loss = self.loss_compute(out, labels, True)
            losses.append(loss.item())
        train_loss = sum(losses) / len(losses) if losses else float("nan")
        eval_loss, bleuscore = self.evaluate(out, eval_loader)
        print("epochs: {}  train_loss: {}  eval_loss: {}  val_bleuscore: {}".format(epoch+1, train_loss, eval_loss, bleuscore))
        return train_loss, eval_loss, bleuscore

    def loss_compute(self, out, target, flag=False):
        """Return the loss of logits *out* against label ids *target*.

        Both tensors are flattened: *out* to ``(-1, out.size(-1))`` and
        *target* to ``(-1,)``. When *flag* is true, gradients are reset and
        one optimizer step is taken.
        """
        loss = self.criterion(out.reshape(-1, out.size(-1)), target.reshape(-1))
        if flag:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        return loss

    def evaluate(self, out, loader):
        """Greedy-decode every batch in *loader* and return ``(mean_loss, bleu)``.

        Decoding starts from ``<s>`` and runs one step per reference token
        (``target[:, 1:]``), so the logits always line up with the labels.
        Hypotheses go through ``trim_seqs``. References keep only ids > 0;
        padding is assumed to be id 0 (confirm). Both are scored with
        smoothed corpus BLEU and written out through ``generator``.

        Args:
            out: unused; kept so existing callers keep working.
            loader: batches in the same format as ``run_epoch`` expects.

        Returns:
            ``(mean_loss, bleu_score)``, or ``(nan, 0.0)`` when *loader* is
            empty. Nothing is written to the log files in that case.
        """
        self.model.eval()
        bos_id = sp.PieceToId("<s>")
        losses = []
        all_output_seqs = []
        all_target_seqs = []
        maxlen = 0
        with torch.no_grad():
            for data in loader:
                input_data = data[0].to(self.device)
                target = data[1].to(self.device)
                target_y = target[:, 1:]
                logits, sampled_idxs = self._greedy_decode(
                    input_data, target_y.size(1), bos_id)
                loss = self.loss_compute(logits, target_y)
                losses.append(loss.item())
                all_output_seqs.extend(trim_seqs(sampled_idxs))
                # The references drop the leading <s>: the hypotheses never
                # contain it, and keeping it would skew BLEU.
                all_target_seqs.extend([list(seq[seq > 0])] for seq in to_np(target_y))
                maxlen = input_data.size(1)
        if not losses:
            return float("nan"), 0.0
        bleu_score = corpus_bleu(all_target_seqs, all_output_seqs, smoothing_function=SmoothingFunction().method1)
        mean_loss = sum(losses) / len(losses)
        self.generator(all_output_seqs, all_target_seqs, maxlen)
        return mean_loss, bleu_score

    def _greedy_decode(self, input_data, steps, bos_id):
        """Greedily decode *steps* tokens per batch item, starting from *bos_id*.

        Returns ``(logits, ids)`` reshaped to ``(batch, steps, -1)`` and
        ``(batch, steps)``. An explicit reshape is used instead of
        ``squeeze()``, so a batch of size 1 keeps its batch dimension.
        """
        batch = input_data.size(0)
        _, hidden = self.model.encode(input_data)
        decoder_input = torch.full((batch, 1), bos_id, dtype=torch.long, device=input_data.device)
        step_logits = []
        step_ids = []
        for _ in range(steps):
            decoder_output, hidden = self.model.decode(decoder_input, hidden)
            _, topi = torch.topk(decoder_output, 1, dim=-1)
            step_logits.append(decoder_output)
            step_ids.append(topi)
            # Feed the predicted token back in as the next decoder input.
            decoder_input = topi.squeeze(1)
        logits = torch.stack(step_logits, dim=1).reshape(batch, steps, -1)
        ids = torch.stack(step_ids, dim=1).reshape(batch, steps)
        return logits, ids

    def generator(self, all_output_seqs, all_target_seqs, maxlen):
        """Write hypotheses to ./log/result.txt and references to ./log/target.txt.

        Each line is the concatenation of the SentencePiece pieces of one
        sequence. ``all_target_seqs`` holds reference lists in
        ``corpus_bleu`` format, and only the first reference of each entry
        is written. ``maxlen`` is accepted for compatibility but not used.
        """
        import os

        os.makedirs("./log", exist_ok=True)

        def to_line(ids):
            return "".join(sp.IdToPiece(int(tok)) for tok in ids) + "\n"

        # Explicit UTF-8: SentencePiece pieces typically contain non-ASCII
        # characters (e.g. the "▁" word-boundary marker).
        with open("./log/result.txt", "w", encoding="utf-8") as f:
            f.writelines(to_line(seq) for seq in all_output_seqs)
        with open("./log/target.txt", "w", encoding="utf-8") as f:
            f.writelines(to_line(refs[0]) for refs in all_target_seqs)

トークナイザ(サブワード分割)にはsentencepieceを使っています。
ちなみにモデルは以下のようなものを想定しています。

model.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable

class EncoderDecoder(nn.Module):
    """Combine an encoder and a decoder into a single seq2seq module.

    ``encoder(src)`` must return a pair ``(outputs, hidden)``. Its hidden
    state is handed to ``decoder(target, hidden)``, whose result is returned
    unchanged.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, data, target):
        """Encode *data*, then decode *target* from the encoder's hidden state."""
        _, hidden = self.encode(data)
        return self.decode(target, hidden)

    def decode(self, target, hx):
        """Run only the decoder on *target* with hidden state *hx*."""
        return self.decoder(target, hx)

    def encode(self, input_data):
        """Run only the encoder on *input_data*."""
        return self.encoder(input_data)

class Encoder(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards its input to *base_module*."""

    def __init__(self, base_module):
        super().__init__()
        self.base_module = base_module

    def forward(self, data):
        """Return ``base_module(data)`` unchanged."""
        wrapped = self.base_module
        return wrapped(data)

class Decoder(nn.Module):
    """Thin ``nn.Module`` wrapper around a decoder *base_module*.

    ``maxlen`` is stored as an attribute; nothing in this class reads it.
    """

    def __init__(self, base_module, maxlen):
        super().__init__()
        self.base_module, self.maxlen = base_module, maxlen

    def forward(self, data, hx):
        """Return ``base_module(data, hx)`` unchanged."""
        step = self.base_module
        return step(data, hx)

# これから
これ使って色々なモデルを訓練して問題ないか試していきたいです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?