
Notes on classifying Japanese documents with PyTorch + TorchText (LSTM, Attention)


Libraries

import os
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import MeCab
tagger = MeCab.Tagger('-Owakati')

import torch
import torch.nn as nn
import torchtext
import torch.optim as optim

SEED = 1234

torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Data preprocessing

df_movie = pd.read_csv('./data/movie.csv')
df_sports = pd.read_csv('./data/sports.csv')

# Write train.tsv as "title\tlabel\t" lines: movie = 1, sports = 0
file = './data/train.tsv'
with open(file, 'w') as f:
    titles = df_movie['title'].values[:500]
    for title in titles:
        text = title
        text = text.replace('\n', '')
        text = text.replace('\t', ' ')
        text = text + '\t' + '1' + '\t' + '\n'
        f.write(text)
        
    titles = df_sports['title'].values[:500]
    for title in titles:
        text = title
        text = text.replace('\n', '')
        text = text.replace('\t', ' ')
        text = text + '\t' + '0' + '\t' + '\n'
        f.write(text)

# Write test.tsv with the remaining titles in the same format
file = './data/test.tsv'
with open(file, 'w') as f:
    titles = df_movie['title'].values[500:]
    for title in titles:
        text = title
        text = text.replace('\n', '')
        text = text.replace('\t', ' ')
        text = text + '\t' + '1' + '\t' + '\n'
        f.write(text)
        
    titles = df_sports['title'].values[500:]
    for title in titles:
        text = title
        text = text.replace('\n', '')
        text = text.replace('\t', ' ')
        text = text + '\t' + '0' + '\t' + '\n'
        f.write(text)

df_sports.head()


def tokenizer(text):
  # Tokenize with MeCab (-Owakati), then strip digits, Latin characters, and symbols
  sentence = tagger.parse(text)
  sentence = re.sub(r'[0-90-9a-zA-Za-zA-Z]+', " ", sentence)
  sentence = re.sub(r'[\._-―─!@#$%^&\-‐|\\*\“()_■×+α※÷⇒—●★☆〇◎◆▼◇△□(:〜~+=)/*&^%$#@!~`){}[]…\[\]\"\'\”\’:;<>?<>〔〕〈〉?、。・,\./『』【】「」→←○《》≪≫\n\u3000]+', "", sentence)
  return sentence.split()  
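
For reference, calling the tokenizer on a sample title looks like this (the sample string is made up and the exact tokens depend on the installed MeCab dictionary, so the output below is only illustrative):

sample = '映画の興行収入が過去最高を記録'   # hypothetical sample title
print(tokenizer(sample))
# e.g. ['映画', 'の', '興行', '収入', 'が', '過去', '最高', 'を', '記録']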

max_length = 25

# Note: this uses the classic torchtext Field / TabularDataset API
# (moved to torchtext.legacy in later torchtext versions).
TEXT = torchtext.data.Field(sequential=True, tokenize=tokenizer, use_vocab=True,
                           include_lengths=True, batch_first=True, fix_length=max_length)
LABEL = torchtext.data.Field(sequential=False, use_vocab=False)

path = './data/train.tsv'
train_ds = torchtext.data.TabularDataset(path=path, format='tsv',  
        fields=[('Text', TEXT), ('Label', LABEL)])

path = './data/test.tsv'
test_ds = torchtext.data.TabularDataset(path=path, format='tsv',  
        fields=[('Text', TEXT), ('Label', LABEL)])
examples = train_ds.examples
print(len(examples))

idx = 0
print(examples[idx].Text)
print(examples[idx].Label)

TEXT.build_vocab(train_ds, min_freq=1)

print(vars(TEXT.vocab).keys())
print(len(TEXT.vocab.freqs.keys()))
print(TEXT.vocab.freqs)
print(TEXT.vocab.stoi)
print(len(TEXT.vocab))

train_dl = torchtext.data.Iterator(train_ds, batch_size=4, train=True)
test_dl = torchtext.data.Iterator(test_ds, batch_size=4, train=False, sort=False)

batch = next(iter(train_dl))
print(batch.Text)
print(batch.Label)
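
Because the Field was created with include_lengths=True, batch.Text is a tuple of (token id tensor, lengths tensor); the id tensor is what the training loops below access as batch.Text[0]. A quick check:

ids, lengths = batch.Text
print(ids.shape)    # (batch_size, fix_length), e.g. torch.Size([4, 25])
print(lengths)      # number of non-padding tokens in each example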


Model 1

class LSTM(nn.Module):   # one layer, unidirectional
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):        
        super(LSTM, self).__init__()        
        self.embedding = nn.Embedding(input_dim, embedding_dim)        
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)        
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, text):
        embedded = self.embedding(text)
        outputs, (h_n, c_n) = self.lstm(embedded)
        output = self.fc(h_n.squeeze(0))        
        return output
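
A quick shape check of this model with a dummy batch of token ids (the values are arbitrary; this is just a sanity check, not part of training):

dummy = torch.randint(0, len(TEXT.vocab), (4, max_length))
m = LSTM(len(TEXT.vocab), 100, 128, 1)
print(m(dummy).shape)   # torch.Size([4, 1]) -- one logit per example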

input_dim = len(TEXT.vocab)
embedding_dim = 100
hidden_dim = 128
output_dim = 1

model = LSTM(input_dim, embedding_dim, hidden_dim, output_dim)
print(model)
print()

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

model = model.to(device)
criterion = criterion.to(device)

Training

def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc
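
As a worked example of binary_accuracy: logits [2.0, -1.0, 0.5] round to predictions [1, 0, 1] after the sigmoid, so against targets [1, 1, 0] only the first is correct and the accuracy is 1/3:

preds = torch.tensor([2.0, -1.0, 0.5])
y = torch.tensor([1.0, 1.0, 0.0])
print(binary_accuracy(preds, y))   # tensor(0.3333)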
  
def train(model, iterator, optimizer, criterion):    
    model.train()
    
    epoch_loss = 0
    epoch_acc = 0
    for batch in iterator:
        inputs = batch.Text[0].to(device)
        targets = batch.Label.float().to(device)
      
        optimizer.zero_grad()                
        preds = model(inputs).squeeze(1)        
        loss = criterion(preds, targets)        
        acc = binary_accuracy(preds, targets)
        
        loss.backward()        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
  
def evaluate(model, iterator, criterion):
    model.eval()
    
    epoch_loss = 0
    epoch_acc = 0
  
    with torch.no_grad():
    
        for batch in iterator:
            inputs = batch.Text[0].to(device)
            targets = batch.Label.float().to(device)
            
            preds = model(inputs).squeeze(1)            
            loss = criterion(preds, targets)            
            acc = binary_accuracy(preds, targets)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)  

n_epochs = 10

train_loss_hist = []
train_acc_hist = []
valid_loss_hist = []
valid_acc_hist = []

best_valid_loss = float('inf')

for epoch in range(n_epochs):
  
    train_loss, train_acc = train(model, train_dl, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, test_dl, criterion)
    
    train_loss_hist.append(train_loss)
    train_acc_hist.append(train_acc)
    valid_loss_hist.append(valid_loss)
    valid_acc_hist.append(valid_acc)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        #torch.save(model.state_dict(), 'tut1-model.pt')
    
    if (epoch+1)%1 == 0:
        print(f'Epoch: {epoch+1:02}')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')


Results

print(best_valid_loss)

plt.figure(figsize =(5,3))
plt.plot(train_loss_hist, marker='.', label='train')
plt.plot(valid_loss_hist, marker='.', label='validation')
plt.title('Loss')
plt.grid(True)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc='best')
plt.show()

plt.figure(figsize =(5,3))
plt.plot(train_acc_hist, marker='.', label='train')
plt.plot(valid_acc_hist, marker='.', label='validation')
plt.title('Accuracy')
plt.grid(True)
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(loc='best')
plt.show()


model.eval()

# Note: df_test is not defined above; it is assumed to hold the raw lines of
# ./data/test.tsv, one "title\tlabel\t" string per row.
idx = np.random.randint(200)
text = df_test.iloc[idx, 0].split('\t')[0]
label = df_test.iloc[idx, 0].split('\t')[1]

tokens = tokenizer(text)
tokens = tokens[:max_length]        # truncate titles longer than max_length
while len(tokens) < max_length:     # pad shorter titles with <pad>
  tokens.append('<pad>')
  
inputs = []
for token in tokens:
  inputs.append(TEXT.vocab.stoi[token])

inputs = torch.LongTensor(inputs).unsqueeze(0).to(device)
preds = model(inputs).squeeze(1)
rounded_preds = torch.round(torch.sigmoid(preds)).detach().cpu().numpy().astype(np.uint8)

print('text: ', text)
print('truth: ', label)
print('prediction: ', rounded_preds.item())


Model 2

class LSTM(nn.Module):   # one layer, bidirectional
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):        
        super(LSTM, self).__init__()        
        self.embedding = nn.Embedding(input_dim, embedding_dim)        
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1,
                            batch_first=True, bidirectional=True)        
        self.fc = nn.Linear(hidden_dim*2, output_dim)
        
    def forward(self, text):
        embedded = self.embedding(text)
        output, (h_n, c_n) = self.lstm(embedded)
        # h_n[0] is the final hidden state of the forward direction,
        # h_n[1] is the final hidden state of the backward direction
        output = torch.cat((h_n[0], h_n[1]), dim=1)
        output = self.fc(output)
        return output
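
For a single-layer bidirectional LSTM with batch_first=True, h_n[0] is the forward output at the last time step and h_n[1] is the backward output at the first time step, which is what the torch.cat above concatenates. A small standalone check (dummy sizes, illustrative only):

lstm = nn.LSTM(8, 16, num_layers=1, batch_first=True, bidirectional=True)
x = torch.randn(2, 5, 8)
out, (h_n, _) = lstm(x)                          # out: (2, 5, 32), h_n: (2, 2, 16)
print(torch.allclose(out[:, -1, :16], h_n[0]))   # forward final state -> True
print(torch.allclose(out[:, 0, 16:], h_n[1]))    # backward final state -> True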

Model 3: Attention

class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):        
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(input_dim, embedding_dim)        
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1,
                            batch_first=True, bidirectional=True)        
        
    def forward(self, text):
        batch_size = text.size()[0]
        embedded = self.embedding(text)
        output, (h_n, c_n) = self.lstm(embedded)
        # Split the bidirectional output into its forward and backward halves and
        # sum them, so the encoder returns shape (batch, seq_len, hidden_dim).
        forward_output = output.view(batch_size, -1, 2, self.hidden_dim)[:, :, 0, :]
        backward_output = output.view(batch_size, -1, 2, self.hidden_dim)[:, :, 1, :]
        return forward_output + backward_output
      
class Attention(nn.Module):
  def __init__(self, hidden_dim):
    super(Attention, self).__init__()
    self.hidden_dim = hidden_dim
    self.seq = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim*2),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim*2, 1)
    )
    
  def forward(self, encoder_output):
    # encoder_output: (batch, seq_len, hidden_dim)
    batch_size = encoder_output.size()[0]
    temp = self.seq(encoder_output.view(-1, self.hidden_dim))
    # softmax over sequence positions -> weights of shape (batch, seq_len, 1)
    return nn.functional.softmax(temp.view(batch_size, -1), dim=1).unsqueeze(2)
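
The attention weights returned here form a softmax distribution over sequence positions, so each example's weights sum to 1. A quick check with dummy encoder output (shapes only, illustrative):

att = Attention(hidden_dim=128)
dummy_enc = torch.randn(4, 25, 128)   # (batch, seq_len, hidden_dim)
w = att(dummy_enc)
print(w.shape)                        # torch.Size([4, 25, 1])
print(w.sum(dim=1).squeeze(1))        # each entry is ~1.0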
  
class Classifier(nn.Module):
  def __init__(self, hidden_dim, output_dim):
    super(Classifier, self).__init__()
    self.attention = Attention(hidden_dim)
    self.linear = nn.Linear(hidden_dim, output_dim)
    
  def forward(self, encoder_output):
    scores = self.attention(encoder_output)
    logits = (encoder_output * scores).sum(dim=1)
    logits = self.linear(logits)
    return logits, scores
  
class LSTM_Attention(nn.Module):
  def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
    super(LSTM_Attention, self).__init__()
    self.encoder = Encoder(input_dim, embedding_dim, hidden_dim, output_dim)
    self.attention = Attention(hidden_dim)
    self.linear = nn.Linear(hidden_dim, output_dim)
    
  def forward(self, text):
    encoder_output = self.encoder(text)
    scores = self.attention(encoder_output)
    logits = (encoder_output * scores).sum(dim=1)
    logits = self.linear(logits)
    return logits, scores

input_dim = len(TEXT.vocab)
embedding_dim = 100
hidden_dim = 128
output_dim = 1

# encoder = Encoder(input_dim, embedding_dim, hidden_dim, output_dim)
# attention = Attention(hidden_dim)
# classifier = Classifier(hidden_dim, output_dim)
model = LSTM_Attention(input_dim, embedding_dim, hidden_dim, output_dim)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
  
# print(f'The model has {count_parameters(encoder):,} trainable parameters')
# print(f'The model has {count_parameters(attention):,} trainable parameters')
# print(f'The model has {count_parameters(classifier):,} trainable parameters')
print(f'The model has {count_parameters(model):,} trainable parameters')
criterion = nn.BCEWithLogitsLoss()
model = model.to(device)
criterion = criterion.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc
  
def train(model, iterator, optimizer, criterion):    
    model.train()
    
    epoch_loss = 0
    epoch_acc = 0
    for batch in iterator:
        inputs = batch.Text[0].to(device)
        targets = batch.Label.float().to(device)
      
        optimizer.zero_grad()                
        preds, scores = model(inputs)
        preds = preds.squeeze(1)
        loss = criterion(preds, targets)        
        acc = binary_accuracy(preds, targets)
        
        loss.backward()        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
  
def evaluate(model, iterator, criterion):
    model.eval()
    
    epoch_loss = 0
    epoch_acc = 0
  
    with torch.no_grad():
    
        for batch in iterator:
            inputs = batch.Text[0].to(device)
            targets = batch.Label.float().to(device)
            
            preds, scores = model(inputs)
            preds = preds.squeeze(1)
            loss = criterion(preds, targets)            
            acc = binary_accuracy(preds, targets)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)  

n_epochs = 10

train_loss_hist = []
train_acc_hist = []
valid_loss_hist = []
valid_acc_hist = []

best_valid_loss = float('inf')

for epoch in range(n_epochs):
  
    train_loss, train_acc = train(model, train_dl, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, test_dl, criterion)
    
    train_loss_hist.append(train_loss)
    train_acc_hist.append(train_acc)
    valid_loss_hist.append(valid_loss)
    valid_acc_hist.append(valid_acc)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        #torch.save(model.state_dict(), 'tut1-model.pt')
    
    if (epoch+1)%2 == 0:
        print(f'Epoch: {epoch+1:02}')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

batch = next(iter(train_dl))
text = batch.Text[0]
label = batch.Label

# print(text[0])
print(label[0].numpy())

s = ''
s_list = []
for i in text[0]:
  s += TEXT.vocab.itos[i]
  s_list.append(TEXT.vocab.itos[i])  
print(s)

model.eval()
preds, scores = model(text.to(device))
preds = preds.squeeze(1)
rounded_preds = torch.round(torch.sigmoid(preds))
rounded_preds = rounded_preds.detach().cpu().numpy()[0].astype(np.uint8)
atts = scores.detach().cpu().numpy()[0,:,0]
argsort_atts = atts.argsort()[::-1]
print(rounded_preds)
# print(atts.shape)
print(argsort_atts[:3])
print(s_list[argsort_atts[0]], s_list[argsort_atts[1]], s_list[argsort_atts[2]])
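
To see where the model attends, the weights can also be plotted against the tokens. A minimal sketch using the s_list and atts computed above (plain matplotlib; Japanese tick labels may need a Japanese-capable font to render):

plt.figure(figsize=(8, 3))
plt.bar(range(len(s_list)), atts)
plt.xticks(range(len(s_list)), s_list, rotation=90)
plt.ylabel('attention weight')
plt.tight_layout()
plt.show()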