LoginSignup
4
4

More than 1 year has passed since last update.

TwitterのツイートをBERT使って文書分類

Last updated at Posted at 2023-01-01

はじめに

最近Twitterの性能が向上したと聞いて、ツイートのカテゴリー分類が本当にうまく分類できているかどうか気になったので自分で分類してみました。
また、データサイエンティスト系の仕事に興味があったので自分でデータ収集から初めて、AIを作成して分析してみる真似事をやってみました。

実験手法

TwitterからAPI用いてツイートを取得します。
ツイートの分類する分野は、
具体的に、スポーツ・政治・経済・エンタメ・科学
のメジャー5種類にしました。
取得したデータにラベル付与を行い、BERTを用いて学習・分類を行います。

環境

OS 使用言語 Editor
Windows10 Python3.9 VScode

学習データ

Twitterから分野ごとにラベル付与を行います。
データ数は計4964個 ラベル数は5種類
学習データは以下の表のように整形します

本文 スポーツ 政治 経済 エンタメ 科学
ラグビーのユニフォームくれるの?!・・・ 1 0 0 0 0
・・・ 1 0 0 0 0
戦闘機やミサイルを買うお金(防衛費)は、・・・ 0 1 0 0 0
・・・ 0 1 0 0 0
二択で迷った企業、現職よりは仕事内容大変だけど、・・・ 0 0 1 0 0
・・・ 0 0 1 0 0
みんながやっぱ音楽系っしょになってた・・・ 0 0 0 1 0
・・・ 0 0 0 1 0
今ディスカバリーチャンネルで月開発に・・・ 0 0 0 0 1
・・・ 0 0 0 0 1
データ数 989 996 994 993 991

コード

コードの解説は以下の記事を参照してください。

BERTの事前学習モデルの用意

日本語の事前学習モデルは東北大学が開発したモデルを使用しました。
汎用性が高いのが特徴です。

全体のコード

from operator import index
import random
import tensorboard
import torch
from torch.utils.data import DataLoader
from transformers import BertJapaneseTokenizer, BertModel
import pytorch_lightning as pl
import pandas as pd
from transformers import ElectraForPreTraining, ElectraTokenizerFast

from sklearn.metrics import accuracy_score
import math

# 日本語の事前学習モデル
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

label_NUM = 5

class BertForSequenceClassificationMultiLabel(torch.nn.Module):
    
    def __init__(self, model_name, num_labels):
        super().__init__()
        # BertModelのロード
        self.bert = BertModel.from_pretrained(model_name) 
        # 線形変換を初期化しておく
        self.linear = torch.nn.Linear(
            self.bert.config.hidden_size, num_labels
        ) 

    def forward(
        self, 
        input_ids=None, 
        attention_mask=None, 
        token_type_ids=None, 
        labels=None
    ):
        # データを入力しBERTの最終層の出力を得る。
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids)
        last_hidden_state = bert_output.last_hidden_state
        
        # [PAD]以外のトークンで隠れ状態の平均をとる
        averaged_hidden_state = \
            (last_hidden_state*attention_mask.unsqueeze(-1)).sum(1) \
            / attention_mask.sum(1, keepdim=True)
        
        # 線形変換
        scores = self.linear(averaged_hidden_state) 
        
        # 出力の形式を整える。
        output = {'logits': scores}

        # labelsが入力に含まれていたら、損失を計算し出力する。
        if labels is not None: 
            loss = torch.nn.BCEWithLogitsLoss()(scores, labels.float())
            output['loss'] = loss
            
        # 属性でアクセスできるようにする。
        output = type('bert_output', (object,), output) 

        return output

#ファインチューニングするためのデータの前処理
df1 = pd.read_excel('train_data.xlsx',)
finetuning_labels_lists = df1.drop(['入力文'],axis=1).values.tolist()

dataset=[]
index = 0
for sentences in df1['入力文']:
    sentence = sentences
    finetuning_labels_list = finetuning_labels_lists[index]
    sample = {'text':sentence , 'labels':finetuning_labels_list}
    dataset.append(sample)
    index+=1

# トークナイザのロード
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)

# 各データの形式を整える
max_length = 512

dataset_for_loader = []
for sample in dataset:
    text = sample['text']
    labels = sample['labels']
    encoding = tokenizer(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True
    )
    encoding['labels'] = labels
    encoding = { k: torch.tensor(v) for k, v in encoding.items() }
    dataset_for_loader.append(encoding)

# データセットの分割
random.shuffle(dataset_for_loader) 
n = len(dataset_for_loader)
n_train = int(0.6*n)
n_val = int(0.2*n)
dataset_train = dataset_for_loader[:n_train] # 学習データ
dataset_val = dataset_for_loader[n_train:n_train+n_val] # 検証データ
dataset_test = dataset_for_loader[n_train+n_val:] # テストデータ

# データセットからデータローダを作成
dataloader_train = DataLoader(
    dataset_train, batch_size=4, shuffle=True
) 
dataloader_val = DataLoader(dataset_val, batch_size=64)
dataloader_test = DataLoader(dataset_test, batch_size=64)

class BertForSequenceClassificationMultiLabel_pl(pl.LightningModule):

    def __init__(self, model_name, num_labels, lr):
        super().__init__()
        self.save_hyperparameters() 
        self.bert_scml = BertForSequenceClassificationMultiLabel(
            model_name, num_labels=num_labels
        ) 

    def training_step(self, batch, batch_idx):
        output = self.bert_scml(**batch)
        loss = output.loss
        self.log('train_loss', loss)
        return loss
        
    def validation_step(self, batch, batch_idx):
        output = self.bert_scml(**batch)
        val_loss = output.loss
        self.log('val_loss', val_loss)

    def test_step(self, batch, batch_idx):
        labels = batch.pop('labels')
        output = self.bert_scml(**batch)
        scores = output.logits
        labels_predicted = ( scores > 0 ).int()
        num_correct = ( labels_predicted == labels ).all(-1).sum().item()
        accuracy = num_correct/scores.size(0)
        self.log('accuracy', accuracy)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model/',
)

trainer = pl.Trainer(
    gpus=1, 
    max_epochs=10,
    callbacks = [checkpoint]
)

model = BertForSequenceClassificationMultiLabel_pl(
    MODEL_NAME, 
    num_labels=label_NUM, 
    lr=1e-5
)

trainer.fit(model, dataloader_train, dataloader_val)
test = trainer.test(dataloaders=dataloader_test)
print(f'Accuracy: {test[0]["accuracy"]:.10f}')

test_df = pd.read_excel('test_data.xlsx',)
test_texts = test_df['入力文']
test_label = test_df.drop(['入力文'],axis=1).values.tolist()

# モデルのロード
best_model_path = checkpoint.best_model_path
model = BertForSequenceClassificationMultiLabel_pl.load_from_checkpoint(best_model_path)
bert_scml = model.bert_scml.cuda()

pre_label_lists = []
def classfy(texts):
    for text in texts:
        encoding = tokenizer(
        text, 
        padding = 'longest',
        return_tensors='pt'
        )   
        encoding = { k: v.cuda() for k, v in encoding.items() }
        # BERTへデータを入力し分類スコアを得る。
        with torch.no_grad():
            output = bert_scml(**encoding)
        scores = output.logits
        pre_max = max(scores.tolist()[0])
        max_index = scores.tolist()[0].index(pre_max)
        print(max_index)
        pre_label=[]
        pre_label.append(text)
        for i in range(0,label_NUM):
            if i == max_index:
                pre_label.append(1)
            else:
                pre_label.append(0)
        pre_label_lists.append(pre_label)
    
classfy(test_texts)    

df_pre_label_lists = pd.DataFrame(pre_label_lists)
print(df_pre_label_lists)
df_pre_label_lists.to_excel("output.xlsx")

テストデータについて

テストデータも学習データと同様ツイートを用います。
データ数は各分野10個ずつの50個を分類します。

作成したモデルで実際に分類してみた結果

正答率は92%となった
image.png

感想

ツイートの文章は学習するには質が悪いと思っていたので驚きの結果です。
不正解の文もありますが、ほかの分野に該当しそうな単語も含まれているので完全な分類は難しそうです。
やっぱりBERTの精度は高いなぁと感心しました。

さいごに

前回のエントリーシートの合否AI作成のリベンジとして実験を行いました。
ツイートのデータで学習を行っても、いい結果が得られて満足してます。
しかし、個人であ収集できるデータ数や質は限度があると感じたので、マーケティングなどのデータ分析に力を入れているような企業に勤めてみたいと思いました。
今後はgithubにもコードを載せたいと思います。

4
4
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4