LoginSignup
4
8

More than 3 years have passed since last update.

日本語BERTを用いた記事分類

Last updated at Posted at 2021-03-06

なんの記事?

日本語BERTを用いて、Livedoorコーパスの文書分類タスクを解くモデルをさっと作ってみたので、その紹介です。

また、このモデルを用いてこの後に、記事推薦も実装予定です。
(追記:2021-03-14 続編を書きました。)

ソースコードも添付しています。よければ併せて御覧ください。

また、最新のAllenNLPおよびTransformersも併せて利用しています。

AllenNLPについて

AllenNLPはPytorchベースの自然言語処理のフレームワークです。今回は文書分類モデルをAllenNLPを利用しつつ、作ってみたいと思います。

作製したモデル

image.png

非常にシンプルなモデルです。

データ読み出し部分

タイトルと記事本文を特殊トークンで結合しています。この手法はこの論文この論文などでもよく見られる手法です。

    @overrides
    def text_to_instance(self, data=None) -> Instance:
        tokenized = [Token('[CLS]')]
        tokenized += [Token(split_token) for split_token in self.custom_tokenizer_class.tokenize(
                                          txt=data['title'])][:self.config.max_title_length]
        tokenized += [Token('[unused0]')]
        tokenized += [Token(split_token) for split_token in self.custom_tokenizer_class.tokenize(
                                          txt=data['caption'])][:self.config.max_caption_length]
        tokenized += [Token('[SEP]')]
        context_field = TextField(tokenized, self.token_indexers)
        fields = {"context": context_field}

        fields['label'] = LabelField(data['class'])

        return Instance(fields)

LabelField にテキスト分類ラベルを格納すると、AllenNLP側でよしなにindex付けとラベルの対応を行ってくれます。便利ですね。

今回は記事のタイトルと、タイトルの後ろに付随する代表センテンスを特徴として利用してみました。

[unused0]トークンをタイトルと記事の接合に使用しています。

モデル部分

class TitleAndCaptionClassifier(Model):
    def __init__(self, args,
                 mention_encoder: Seq2VecEncoder,
                 num_label: int,
                 vocab):
        super().__init__(vocab)
        self.args = args
        self.mention_encoder = mention_encoder
        self.accuracy = CategoricalAccuracy()
        self.loss = nn.CrossEntropyLoss()
        self.linear_for_classify = nn.Linear(self.mention_encoder.get_output_dim(), num_label)

    def forward(self, context, label):
        emb = self.mention_encoder(context)
        scores = self.linear_for_classify(emb)
        probs = softmax(scores, dim=1)
        loss = self.loss(scores, label)
        output = {'loss': loss}
        output['logits'] = scores
        output['probs'] = probs
        self.accuracy(probs, label)

        output['encoded_embeddings'] = emb

        return output

    @overrides
    def get_metrics(self, reset: bool = False):
        return {"accuracy": self.accuracy.get_metric(reset)}

分類モデルを作製し、クロスエントロピー損失を取る、非常にシンプルなモデルです。

output['emcoded_embeddings'] はモデル自体の予測とは別に後で使用するために、モデルのforwardで吐き出させるようにしています。

エンコーダ部分


class Pooler_for_mention(Seq2VecEncoder):
    def __init__(self, args, word_embedder):
        super(Pooler_for_mention, self).__init__()
        self.args = args
        self.huggingface_nameloader()
        self.bertpooler_sec2vec = BertPooler(pretrained_model=self.bert_weight_filepath)
        self.word_embedder = word_embedder
        self.word_embedding_dropout = nn.Dropout(self.args.word_embedding_dropout)

    def huggingface_nameloader(self):
        if self.args.bert_name == 'japanese-bert':
            self.bert_weight_filepath = 'cl-tohoku/bert-base-japanese'
        else:
            self.bert_weight_filepath = 'dummy'
            print('Currently not supported', self.args.bert_name)
            exit()

    def forward(self, contextualized_mention):
        mask_sent = get_text_field_mask(contextualized_mention)
        mention_emb = self.word_embedder(contextualized_mention)
        mention_emb = self.word_embedding_dropout(mention_emb)
        mention_emb = self.bertpooler_sec2vec(mention_emb, mask_sent)

        return mention_emb

    @overrides
    def get_output_dim(self):
        return 768

おなじみの、[CLS]トークンに相当する埋め込みを取得するだけのエンコーダになります。

これらを組み合わせて最初の図のモデルになります。

実験結果

epochを完全に回していませんが、5epochで dev acc. ~ 85%, test acc. ~ 83% でした。
学習がうまく出来ていますね。

ソースコード

こちらに載せました。

記事埋め込みの吐き出し

次回の記事の為に、少しここで先に準備をしておきましょう。
各記事の埋め込みを訓練後のモデルから吐き出す為に、ラベル無しデータがModelを通過できるよう
モデルを書き換えます。

    def forward(self, context,
                mention_uniq_id: torch.Tensor = None,
                label: torch.Tensor = None):
        emb = self.mention_encoder(context)
        scores = self.linear_for_classify(emb)
        probs = softmax(scores, dim=1)
        output = {}
        if label is not None:
            loss = self.loss(scores, label)
            self.accuracy(probs, label)
            output['loss'] = loss
            output['logits'] = scores
            output['probs'] = probs
            output['mention_uniq_id'] = mention_uniq_id

        output['encoded_embeddings'] = emb
        return output

ラベルを持たないただの記事タイトルのみがモデルが入ってきた場合も、この書き換えにより、エラーを出さずに埋め込みを吐き出すことが可能になります。

吐き出し部分については、AllenNLPのチュートリアルを参考に実装しました。これにより、任意のテキスト(ここでは記事のタイトルや検索クエリ)に対して、今回訓練したモデルからの出力埋め込みを手に入れることが出来ます。


from typing import Dict, Iterable, List, Tuple
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

from allennlp.data import (
    DataLoader,
    DatasetReader,
    Instance,
    Vocabulary,
    TextFieldTensors,
)
from allennlp.data.data_loaders import SimpleDataLoader
from allennlp.models import Model
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
import pdb
from allennlp.predictors import Predictor
from allennlp.common.util import JsonDict
from allennlp.data.samplers import BucketBatchSampler

class EmbeddingEncoder(Predictor):
    def predict(self, sentence: str) -> JsonDict:
        # This method is implemented in the base class.
        return self.predict_json({"context": sentence})

    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        context = json_dict["context"]
        return self._dataset_reader.text_to_instance(mention_uniq_id=None,
                                                     data={'title': context})

実際に吐き出している main.py の部分

    embedding_encoder = EmbeddingEncoder(model, dsr)
    res = embedding_encoder.predict('test emb')

次回はこのコードと、今回得られたモデルを用いて記事推薦モデルを実践する予定です。

4
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
8