14
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

BERT+FastAPIを用いた感情分類器をコンテナ化する

Last updated at Posted at 2021-03-18

なんの記事?

BERTを用いた日本語感情分類器を作製しました。

実装モデル

過去の記事でも書いたような、非常に単純なテキスト分類モデルです。

今回は3値分類のモデル(Positive/Negative/Neutral)を作製しました。

probs で1文が各クラスに相当する確率が返却されます。


'''
Model classes
'''
import torch
import pdb
import torch.nn as nn
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.models import Model
from overrides import overrides
from allennlp.training.metrics import CategoricalAccuracy, BooleanAccuracy
from torch.nn.functional import softmax

class PosNegClassifier(Model):
    def __init__(self, args,
                 mention_encoder: Seq2VecEncoder,
                 num_label: int,
                 vocab):
        super().__init__(vocab)
        self.args = args
        self.mention_encoder = mention_encoder
        self.accuracy = BooleanAccuracy()
        self.BCEWloss = nn.BCEWithLogitsLoss()
        self.accuracy = CategoricalAccuracy()
        self.loss = nn.CrossEntropyLoss()
        self.linear_for_classify = nn.Linear(self.mention_encoder.get_output_dim(), num_label)

    def forward(self, context,
                mention_uniq_id: torch.Tensor = None,
                label: torch.Tensor = None):
        emb = self.mention_encoder(context)
        scores = self.linear_for_classify(emb)
        probs = softmax(scores, dim=1)
        output = {}
        if label is not None:
            loss = self.loss(scores, label)
            self.accuracy(probs, label)
            output['loss'] = loss
            output['logits'] = scores
            output['mention_uniq_id'] = mention_uniq_id

        output['encoded_embeddings'] = emb
        output['probs'] = probs
        return output

    @overrides
    def get_metrics(self, reset: bool = False):
        return {"accuracy": self.accuracy.get_metric(reset)}

FastAPIを用いたコンテナ化

今回はこのモデルをコンテナの中に送り、コンテナを叩いてレスポンスが返却されるようにします。

予めFastAPIを起動する為に必要なDockerfileを作製します。

今回はcondaベースで作製しました。

Dockerfile

FROM continuumio/anaconda3:2019.03
RUN apt-get update && apt-get install -y \
    libfreetype6-dev \
    libjpeg62-turbo-dev \
    git \
    build-essential
RUN pip install --upgrade pip && \
    pip install autopep8
ARG project_dir=/projects/
WORKDIR $project_dir
ADD requirements.txt .
RUN pip install -r requirements.txt
COPY . $project_dir
CMD ["uvicorn","app:app","--reload", "--host", "0.0.0.0" ,"--port","8000", "--log-level", "trace"]

モデル自体はgpu上で訓練しますが、その後cpuに移しているので、cuda関連のパッケージは不要です。

このイメージをビルドして、コンテナを建てます。
$ docker build -t jsa:latest .
$ docker run -d -itd -p 8000:8000 jsa

ここまででコンテナを建てることが出来ました。

実際の解析結果

コンテナに向かってcurlしてみます。

$ curl -X 'POST' 'http://localhost:8000/sentiment/' -H 'accept: application/json' \
       -H 'Content-Type: application/json' \
       -d '{
            "sentence": "今日はいい天気"
           }'

>> {"probs":
        {
         "neutral":0.8089876174926758,
         "negative":0.015650086104869843,
         "positive":0.17536230385303497
         }
    }

それぞれのクラスの確率が返却されました。

課題点

何個か挙げます。

コンテナの軽量化

今回はconda+gpu上でモデルを訓練し、それをコンテナに送り込んでいます。加えてallennlp周辺のパッケージまで含めて、約3GB程度のイメージとなっています。

学習データのクロール

今回使用したデータは携帯電話に関するtwitter上の反応にアノテーションを行ったものです。多様なPositive/Negative/Neutral文に対応するためには、携帯ドメイン以外のツイートのクロールが考えられます。

ソースコード

14
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?