LoginSignup
33
27

BERTで自殺ツイートを検出する試み ~④BERTでの判別~

Last updated at Posted at 2023-03-05

はじめに

みなさん、こんにちは。これまで、自殺ツイートを検出するまでのデータセット作成編やBERTを試す前にナイーブベイズを試すなど試行錯誤してきましたが、やっと本題のBERTでの自殺ツイート検出に至ることができたので、記事にまとめました。
▼これまでの記事はこちら

何故こんなことをやろうかと思ったかというと、前の記事にも書いたのですが、学生時代に友人を自殺で亡くし、友人の苦悩に気づけなかった自分に後悔があり、技術の力で悩んでいる人を救いたいという想いから初心者ながらBERTを動かしてみました。自分自身もそうですが、自分の弱みや悩みを人に伝えるということはなかなか難しく、本当に追い詰められたときに発したSOSではもう手遅れということが多々あります。そうなる前に、社会が困っている人に手を差し伸べて、社会が人を助ける仕組みを作りたい と考え、まずは行動してみようと思い始めたのがきっかけです。
こんなものを作りたい.png

本記事の内容
・BERTのファインチューニングにより自殺に関連するツイートかどうかを判別するAIを作成
・自殺が多いと言われている3月から4月にかけてのTweetを5,000件収集し、自身で「自殺かどうか」判別しデータセット作成&モデル構築&性能評価を行いました。
・結果、正解率90.5%、AUC 0.884となかなかの精度のモデルができました。
・自殺かどうかを検出することで、カウンセリングを紹介したり、追い詰められているけども周りに言えない人々を救うことができるのでは・・・と思っています。

完成物

自殺ツイートの場合

ツイートを入力してください
"生きてても意味ないし。早く死にたい。誰か殺してくれないかな。"
判定結果
自殺関連ツイートです。一人で悩まないでください。カウンセリングを紹介します。
prob:0.9997

自殺に無関係なツイートの場合

ツイートを入力してください
"ゲーム課金しすぎた。マジ死にたいんですけどwww"
判定結果
自殺ツイートではありません。何か不安なことがあれば相談してくださいね。
prob:0.9999

結構、きちんと判別できました!

実行環境

Google Colaboratory
Python 3.7.13
Pytorch 1.11.0+cu113
Numpy 1.21.6
transformers 4.20.1
fugashi 1.1.2
unidic-lite 1.0.8

こちらの記事を参考にさせて頂きました。おかげさまで初学者でもやりたいことに辿り着くことができました。感謝の気持ちでいっぱいです。

データセット作成

データセット作成、ラベル付け、困ったことなどはこちらの記事でまとめています。

BERTでの自殺ツイート判別モデルの作成

BERT(Bidirectional Encoder Representations from Transformers)は、Googleが2018年に発表した自然言語処理における言語モデルの一つです。BERTは、大量のテキストデータを使用して、言語処理のタスクを学習することができます。BERTは、文章の中の単語の意味を理解するために、前後の単語を考慮する双方向性を持っています。また、事前学習と呼ばれる学習手法を用いて、大量のテキストデータからパラメータを学習し、その学習済みのパラメータを別のタスクに転移することができます。BERTをファインチューニングすることで初学者でもあらゆるタスクに活用することができます。

実装

データセット読み込み・下準備

from google.colab import drive
drive.mount('/content/drive')

# データセット読込
import pandas as pd
dataset = pd.read_csv('./drive/MyDrive/suicide_detection_bert/dataset/train.csv')

# 2クラスに分割
dataset_pos = dataset[dataset['suicide']==1] # 自殺関連
dataset_neg = dataset[dataset['suicide']==0] # 自殺関連でない
#自殺=1 ツイート内容 
dataset_pos.head(5)

自殺ラベル=1のツイートは、本当に深刻そうな悩みがうかがえます。
tweet.png

#自殺=0 ツイート内容 
dataset_neg.head(5)

こちらはツイートに「死にたい」というワードが含まれるものの、自殺との関連性が低い内容となっています。
tweet2.png

続いて、必要なツールのインストール、ライブラリのインポートを行います。

!pip install -q transformers # Hugging FaceのAPI
!pip install -q fugashi # 形態素解析ツール
!pip install -q unidic-lite # 形態素解析用の辞書
# モジュールのインポート
import os
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score

from transformers import AutoTokenizer, BertForSequenceClassification
from transformers import TrainingArguments, Trainer

次にデータを分割します。

# 訓練とテストに分割
dataset_pos_train, dataset_pos_test = train_test_split(dataset_pos, train_size=0.5, random_state=0)
dataset_neg_train, dataset_neg_test = train_test_split(dataset_neg, train_size=0.5, random_state=0)

# それぞれのサンプル数を確認
print('positive:')
print(len(dataset_pos_train), len(dataset_pos_test))
print('negative:')
print(len(dataset_neg_train), len(dataset_neg_test))
出力結果
positive:
249 249
negative:
2500 2501

自殺ツイートが全体の10%と、かなり不均衡ではありますが、ひとまずこれでも判別できるのか試してみます。

def extract_dataset(pos, neg, random_state=0):
    '''
    2クラスのデータを結合し、
    データセットとなるテキストとラベルを準備する
    --- Inputs ---
    pos, neg : pandas.DataFrame
        データセットから選択した各クラスのデータ
    --- Returns ---
    texts : list
        テキストのリスト
    labels : list
        pos = 1, neg = 0としてtextsの位置に対応づけたリスト
    '''
    # データフレームを結合し、ラベルの列を追加
    posneg = pd.concat([pos, neg])
    posneg['label'] = [1 for _ in range(len(pos))] + [0 for _ in range(len(neg))]

    # シャッフルしておく(transformers側でされているのか不明なため)
    # 適当に乱数を発生させ、その大小で並べ替える
    np.random.seed(random_state)
    posneg['random_number'] = np.random.rand(len(posneg))
    posneg_sorted = posneg.sort_values('random_number')

    # 値の取り出し
    texts = posneg_sorted['tweet'].to_list()
    labels = posneg_sorted['label'].to_list()
    return texts, labels

texts_train, labels_train = extract_dataset(dataset_pos_train, dataset_neg_train)
texts_test, labels_test = extract_dataset(dataset_pos_test, dataset_neg_test)
print(len(texts_train), len(texts_test))
出力結果
2749 2750

訓練データは2749、テストデータは2750となりました。

GPUセットアップ、データのトークナイズ

続いて、BERTがテキストを読める形に設定します。トークン最大長は512まで設定できますが、重くなりすぎるので128にしました。

# GPUを利用する
device = "cuda:0"

# 東北大BERT
model_name = "cl-tohoku/bert-large-japanese"

# トークナイザ。モデルに合ったものが自動で選択される
tokenizer = AutoTokenizer.from_pretrained(model_name)

# tokenize
token_maxlen = 128 # トークン最大長

tokenized_train = tokenizer(texts_train, return_tensors='pt', padding=True, truncation=True, max_length=token_maxlen).to(device)
tokenized_test = tokenizer(texts_test, return_tensors='pt', padding=True, truncation=True, max_length=token_maxlen).to(device)

# ログ記録用にテストデータを縮小したものも作っておく
tokenized_val = tokenizer(texts_test[:500], return_tensors='pt', padding=True, truncation=True, max_length=token_maxlen).to(device)
class TweetDataset(torch.utils.data.Dataset):
    '''
    ツイートのデータセットクラス
    --- Attributes ---
    encodings : tokenizerによる処理後のデータ
    labels : 0 or 1のラベルデータ
    '''
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

# それぞれデータセットにする
train_dataset = TweetDataset(tokenized_train, labels_train)
test_dataset = TweetDataset(tokenized_test, labels_test)
val_dataset = TweetDataset(tokenized_val, labels_test[:500])

事前学習済みモデルのロードとファインチューニング

事前学習済みモデルを読み込みます。

# 学習済みモデル(東北大BERT)をロード
model = BertForSequenceClassification.from_pretrained(model_name).to(device)
#学習中のvalidationにAccuracyなどの評価関数を表示させたいので、compute_metrics()を作成してTrainerクラスに渡せるようにしておく。
softmax = torch.nn.Softmax(1)

def compute_metrics(pred):
    '''
    学習中に実行される評価関数
    '''

    # 各metricsが扱えるベクトルに変換
    labels = pred.label_ids # 正解ラベル
    outputs = softmax(torch.Tensor(pred.predictions))
    probabilities = outputs[:,1] # 予測値(0~1の連続値)
    # preds = outputs.argmax(1) # 予測値(0 or 1のバイナリ)
    preds = (probabilities > 0.2).int()  # 閾値0.2 閾値を下げてRecallを上げる
    # metricsを計算
    acc = accuracy_score(labels, preds)
    pre = precision_score(labels, preds)
    rec = recall_score(labels, preds)
    auc = roc_auc_score(labels, probabilities)
    f1 = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0)
    # 辞書として返す
    return {
        'accuracy': acc,
        'precision': pre,
        'recall': rec,
        'auc': auc
    }

いよいよモデルを訓練します。

# 訓練時のパラメータ
training_args = TrainingArguments(
    output_dir='./results',          # 結果の出力先
    num_train_epochs=3,              # エポック数
    per_device_train_batch_size=4,   # バッチサイズ(訓練時)
    per_device_eval_batch_size=4,    # バッチサイズ(検証時)
    weight_decay=0.01,               # 重み減衰率
    save_total_limit=1,              # チェックポイントを保存する数(?)
    learning_rate=5e-6,              # 学習率
    dataloader_pin_memory=False,
    evaluation_strategy="steps",
    logging_steps=200,
    logging_dir='./logs'
)
# モデルの訓練を行う
trainer = Trainer(
    model=model, 
    args=training_args, 
    train_dataset=train_dataset, 
    eval_dataset=test_dataset, 
    compute_metrics=compute_metrics
)

trainer.train()

こんな結果となりました。BERTは過学習しやすいのが悩ましいです。データセットの正解データを再翻訳で5倍に増やしたりもしましたが、バリエーションが増えず、元データの方が結果が良くなりました。
結果.png

最終的な評価は・・・

# 最終的な評価
trainer.evaluate(eval_dataset=test_dataset)
出力結果
***** Running Evaluation *****
  Num examples = 2750
  Batch size = 4

{'eval_accuracy': 0.9050909090909091,
 'eval_precision': 0.47580645161290325,
 'eval_recall': 0.4738955823293173,
 'eval_auc': 0.8847898591567389,
 }

正解率0.905、aucが0.884となかなか良い精度なのではないでしょうか!?

判定結果

自殺ツイートの場合
判定③.png
死にたいと明記されていなくても、「飛び降りたい」という文脈を考慮して自殺ツイートと正しく判定されています。

自殺に関係ないツイートの場合
判定④.png

こちらは、「死んだ」というネガティブワードが入っていても、自殺ツイートではないと正しく判定されています。BERT恐るべし。

正しく判定されなかったツイート
リスカ.png
「リスカ」は未知語のためか、自殺に関連性が高くても正しく判定されませんでした。

他のモデルとの精度比較

3つの手法を比較することで、BERTの素晴らしさを実感することができました。
最初からBERTだったら、BERTの恩恵に気づくことが出来なかっただろうなと思います。

AUC
ルールベース 0.205
ナイーブベイズ 0.680
BERT 0.884

まとめ

自殺ツイート検出のために、自作データセット作成から、古典的手法を試したり、BERTでファインチューニングしたりと、実際に自殺ツイートを判別することができました。

一連の経験を通して、初学者でも想いと行動力さえあれば、やりたいことを実現できるということを体現することができました。

私が本格的に自然言語処理を独学で学んだのは今年に入ってからですが、自然言語処理は学べば学ぶほど奥深く、やりたいことを実現するために必要な周辺知識に絞って習得し、目標から逆算して手を動かしていくということが非常に有効だと感じました。

これまでずっと、社会課題の解決のために技術を習得して、役に立つものを創りたいと思ってばかりで行動できていませんでしたが、ようやく形にすることができました。
この経験を通して、自分がやってみたいことはテクノロジーによる社会課題の解決だということが明確になり、今後もツイートを検出して終わりではなく、本当に悩んでいる人、困っている人を助けられるように、これからも自分にできることを探し実現していきます。

ブログへ移行中⇒https://yurufuwa-ai-engineer.com/

33
27
6

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
33
27