Help us understand the problem. What is going on with this article?

Django REST frameworkでBERTを用いたネガポジ推論結果を返すAPIの作成

この投稿は 「Django Advent Calendar 2019 - Qiita」 の20日目の記事です。

sinyです。

この記事では、Django REST frameworkBERTモデルを用いたネガポジ推論結果を返すRestAPIの作成にチャレンジした内容についてまとめました。

※DRFやBERTは今回初めて実装したので、誤っている部分や「もっとこうしたほうがよいよ!」といった点があればご指摘いただけると幸いです。

はじめに

この記事では、Djangoが主目的のためBERTモデル実装に関わる部分の説明は行いません。
この記事内で利用している日本語データセットを用いたBERTモデルの実装、学習については自然言語処理 Advent Calendar 2019 25日目(BERTを用いたネガポジ分類機の作成)」をご覧ください。

まず最初に、今回想定しているDRF環境の全体処理の概要図です。

drf_bert.png

実施していること自体はシンプルで、入力データとして与えられた文章がネガティブかポジティブかをBERTモデルで推論(2値分類)し、その結果をクライアント側に返すREST APIです。

※概要図ではAzureのデータサイエンス用環境(DSVM)を使っていますが、今回はローカルの開発サーバを使って構築していきます。

【APIのデモ動画】
BERT_DRF.gif

この記事で実装するモジュールはこちらのgitリポジトリにありますので、適宜ダウンロードしてお使いください。

また、BERTに関しては書籍「つくりながら学ぶ! PyTorchによる発展ディープラーニング」 を参考に作成しました。
上記書籍では英語データをベースとしたBERTモデルのネガポジ分類になっていたため、書籍をベースに日本語データセットでネガポジ分類できるように改良しています。

目次

  1. 前提
  2. 環境構築
  3. Django REST frameworkの作成
  4. REST APIを使って推論
  5. 簡易ツール
  6. まとめ
  7. 参考書籍

1.前提

この記事の内容は以下の環境で動作を確認しています。

ローカル環境

項目   意味
OS  Windows10のUbuntu
BERTモデル        京都大学が公開しているpytorch-pretrained-BERTモデルをベースにファインチューニングを行う。
形態素解析 Juman++ (v2.0.0-rc2) or (v2.0.0-rc3)
Django    2.2.5
djangorestframework    3.10.3

2.環境構築

今回はWindows10のUbuntu環境で動作するDRFを構築していきます。
まず最初にcondaで仮想環境の作成と必要なモジュールのインストールを行います。

各種モジュールのインストール

conda create -n drf python=3.6
conda activate drf
conda install pytorch=0.4 torchvision -c pytorch
conda install pytorch=0.4 torchvision cudatoolkit -c pytorch
conda install pandas scikit-learn django

condaでうまく入らないものはpipでインストールします。

pip install mojimoji
pip install attrdict
pip install torchtext
pip install pyknp
pip install djangorestframework

Juman++のインストール

今回利用するBERT日本語Pretrainedモデルは、入力テキストにJuman++ (v2.0.0-rc2)で形態素解析を行っていますので、本記事でも形態素解析ツールをJuman++に合わせます。

Juman++の導入手順は別記事でまとめていますので、以下を参照ください。

[JUMAN++の導入手順まとめ]https://sinyblog.com/deaplearning/juman/

Juman++導入後にローカル環境で以下の通り形態素解析出る状態になっていればOKです。

# JUMANの動作確認                  
from pyknp import Juman                 
text = "自然言語処理について学習中です。"                   
juman = Juman()                 
result =juman.analysis(text)                    
result = [mrph.midasi for mrph in result.mrph_list()]                   
print(text)
自然言語処理について学習中です。        
print(result)
['自然', '言語', '処理', 'に', 'ついて', '学習', '中', 'です', '。']            

3. Django REST frameworkの作成

プロジェクト作成

まずdjangoのプロジェクトを作成します。(プロジェクト名:drf)

django-admin startproject drf

アプリケーション作成

続いてアプリケーションを作成します(アプリケーション名:appv1)

cd drf
python manage.py startapp appv1

settings.pyのカスタマイズ

settings.pyのINSTALLED_APPSにrest_frameworkとアプリケーション(appv1)を追加します。

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'rest_framework',                #add
    'appv1.apps.Appv1Config',        #add

]

BERT関連モジュールの配置

アプリケーションフォルダ(appv1)直下に以下のフォルダを作成して指定の通りモジュールを配置します。

フォルダ名   配置モジュール  用途
vocab   vocab.txt  BERTの語録辞書ファイル
weights     bert_config.json BERT設定ファイル
weights            pytorch_model.bin                京都大学公開のHPからダウンロードしたファイル(学習済みモデル)              
weights            bert_fine_tuning_chABSA.pth  BERTファインチューニング学習済みモデル
data **.tsvファイル4つ 学習データ、テストデータ等

※学習済みモデルファイル(bert_fine_tuning_chABSA.pth)を利用してとりあえず動かしてみたい方はこちらから学習済みモデルファイルをダウンロードして配置してください。
※vocab.txt、bert_config.json、pytorch_model.binは京都大学公開のHPからダウンロードできます。

アプリケーション(appv1)直下に以下のファイルを配置します。

ファイル名   意味
config.py  各種設定ファイル
dataloader.py torchtextデータローダー生成用ファイル
predict.py 推論用
tokenizer.py   BERT単語分割関連のシェル
bert.py   BERTモデルの定義

Djangoのシェルモードを起動して、推論時に利用する語録辞書データファイルを生成しておきます。

python manage.py shell
from appv1.config import *
from appv1.predict import create_vocab_text
TEXT = create_vocab_text()

上記を実行するとappv1/data/text.pklが生成されます。

全体のディレクトリ構成は以下の通りです。

├─drf
│  │  db.sqlite3
│  │  manage.py
│  │
│  ├─appv1
│  │  │  admin.py
│  │  │  apps.py
│  │  │  bert.py           # BERTモデルの定義
│  │  │  config.py         # 各種設定ファイル
│  │  │  dataloader.py     # torchtextデータローダー生成用ファイル
│  │  │  models.py
│  │  │  predict.py        # 推論用
│  │  │  serializers.py    # シリアライザ
│  │  │  tests.py
│  │  │  tokenizer.py      # BERT単語分割関連のシェル
│  │  │  views.py
│  │  ├─data

│  │  │      test_dumy.tsv  # ダミーデータ
│  │  │      train_dumy.tsv # ダミーデータ
│  │  │      text.pkl    # 推論時に使う語録データ
│  │  │
│  │  ├─vocab
│  │  │      vocab.txt   # Bert語録データ
│  │  │
│  │  ├─weights
│  │  │      bert_config.json
│  │  │      bert_fine_tuning_chABSA.pth  # ファインチューニング済みBertモデル
│  │  │      pytorch_model.bin
│  │  │
│  ├─drf
│  │  │  settings.py
│  │  │  urls.py
│  │  │  wsgi.py

推論の動作確認

django環境上でBERT学習済みモデルを使って推論処理が正常に動くことを確認しておきます。

Djangoのシェルモードを起動後に以下のコマンドを実行して、サンプルの文章データを与えてネガティブorポジティブ判定が行えることを確認します。

※各コマンドで実施している処理はコメントを参照。

python manage.py shell
-----------------------------------------------------------------------------
#ここからシェルモード
In [1]: from appv1.config import *
In [2]: from appv1.predict import predict2, create_vocab_text, build_bert_model
In [3]: from appv1.bert import get_config, BertModel,BertForchABSA, set_learned_params
In [4]: import torch
In [5]: config = get_config(file_path=BERT_CONFIG)  # BERTコンフィグ設定のロード
In [6]: net_bert = BertModel(config)   #BERTモデルの生成
In [7]: net_trained = BertForchABSA(net_bert)   #BERTモデルにネガポジ用分類機を結合
In [8]: net_trained.load_state_dict(torch.load(MODEL_FILE, map_location='cpu'))  #学習済み重みをロー
   ...: 
Out[8]: IncompatibleKeys(missing_keys=[], unexpected_keys=[])
In [9]: net_trained.eval()   #推論モードにセット
Out[9]:
BertForchABSA(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32006, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (selfattn): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )

 ~~出力結果が多いため一部主略~

    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (cls): Linear(in_features=768, out_features=2, bias=True)
)

In [10]: input_text = "損益面におきましては、経常収益は、貸出金利息や有価証券売却益の減少により、前期
    ...: 比72億73百万円減少の674億13百万円となりました"
In [11]: result = predict2(input_text, net_trained).numpy()[0]  #推論の実行(戻り値はネガティブorポジ
    ...: ティブ)
['[UNK]', '面', 'に', 'おき', 'まして', 'は', '、', '[UNK]', '収益', 'は', '、', '貸出', '金', '利息', 'や', '有価', '証券', '売却', '益', 'の', '減少', 'に', 'より', '、', '前期', '比', '[UNK]', '円', '減少', 'の', '[UNK]', '円', 'と', 'なり', 'ました']
[2, 1, 534, 8, 7779, 26207, 9, 6, 1, 7919, 9, 6, 15123, 306, 28611, 34, 27042, 4190, 3305, 8995, 5, 1586, 8, 52, 6, 4523, 2460, 1, 387, 1586, 5, 1, 387, 12, 105, 4561, 3]
In [12]: print(result)
0

上記の通りエラーなくresult変数に戻り値としてネガティブ(0) or ポジティブ(1)が返ってくれば動作OKです。

serializerの作成

続いて、DRFのシリアライザーを作成します。
なお、DRFにおけるシリアライズ、デシリアライズは以下のような処理です。

処理   意味
シリアライズ     JSON文字列などをDjangoモデルオブジェクトに変換すること 
デシリアライズ モデルオブジェクトからJSON形式などに変換すること

今回のケースではDjangoのモデルは扱わないためモデルオブジェクトへの変換は行いません。
代わりに、「ネガポジ判定したい入力文章を受け取り、BERTの学習済みモデルにインプットして推論結果をアウトプットする」というシリアライザを作成します。

DRFのシリアライザーには大きく以下の3種類がありますが、今回はモデルに依存しないカスタマイズな処理を実装したいので、rest_frameworkのSerializerクラスを利用します。

serializerの処理 意味
ModelSerializer        単一のモデルオブジェクトを利用 
Serializer 単一のリソースを扱う、またはモデルに依存しないカスタマイズな処理を実装する
ListSerializer 複数リソースを扱う

appv1直下にserializers.pyというファイルを作成し、以下のコードを追記します。

appv1\serializers.py

from rest_framework import serializers
from appv1.config import *  
from appv1.predict import predict2
from appv1.bert import get_config, BertModel,BertForchABSA  
import torch        


class BertPredictSerializer(serializers.Serializer):
    """BERTのネガポジ分類結果を得るシリアライザ"""

    input_text = serializers.CharField()
    neg_pos = serializers.SerializerMethodField()

    def get_neg_pos(self, obj):
        config = get_config(file_path=BERT_CONFIG)  #bertコンフィグファイルのロード
        net_bert = BertModel(config)  #BERTモデルの生成
        net_trained = BertForchABSA(net_bert) # #BERTモデルにネガポジ用分類機を結合
        net_trained.load_state_dict(torch.load(MODEL_FILE, map_location='cpu'))  #学習済みの重みをロード
        net_trained.eval()  #推論モードにセット
        label = predict2(obj['input_text'], net_trained).numpy()[0]  #推論結果を取得
        return label

まず、rest_frameworkのserializers.Serializerを承継してBertPredictSerializerというシリアライザを作成します。

class BertPredictSerializer(serializers.Serializer):
    """BERTのネガポジ分類結果を得るシリアライザ"""

    input_text = serializers.CharField()
    neg_pos = serializers.SerializerMethodField()

入力として文字列型のCharFieldを定義(input_text)しています。
アウトプットは動的な値(推論結果)になるため、メソッドの結果によってフィールドの値をきめることができるserializers.SerializerMethodField()を使いneg_posというフィールドを定義しています。

SerializerMethodField()を利用する場合は、適用されるメソッド名をget_ + フィールド名として定義します。
今回のケースではget_neg_posメソッドを定義します。

このようなシリアライザを定義すると、BertPredictSerializerを使ったAPIビューに入力文章をPOST送信した際にget_neg_posメソッドが実行されて処理結果がJSON文字列として返ってきます。

今回は入力文章データをBERT学習モデルに与えネガポジ判定結果を得たいので、get_neg_posメソッド内に以下の処理を記載しています。

    def get_neg_pos(self, obj):
        config = get_config(file_path=BERT_CONFIG)  #bertコンフィグファイルのロード
        net_bert = BertModel(config)  #BERTモデルの生成
        net_trained = BertForchABSA(net_bert) # #BERTモデルにネガポジ用分類機を結合
        net_trained.load_state_dict(torch.load(MODEL_FILE, map_location='cpu'))  #学習済みの重みをロード
        net_trained.eval()  #推論モードにセット
        label = predict2(obj['input_text'], net_trained).numpy()[0]  #推論結果を取得
        return label

実施していることはコメントを見ていただければ分かるかと思いますが、以下補足しておきます。

  • get_configメソッド:bert_config.jsonから読み込み、JSONの辞書変数をオブジェクト変数に変換
  • BertModelクラス:BERTモデルを生成するクラス
  • BertForchABSAクラス:BERTモデル(BertModel)にchABSAのネガ・ポジを判定する部分をつなげたモデル
  • predict2メソッド:入力された文章がネガティブかポジティブかを推論して0 or 1を返すメソッド

上記4つのメソッドクラスはbert.py内に定義しています。
net_trained.load_state_dictで学習済みモデルのパラメータをロードします。
predict2メソッドはpredict.py内に定義しており、第1引数に入力文章第2引数に学習済みモデルインスタンスを渡すと推論結果(0 or 1)を返すメソッドになっています。

以上でシリアライザの作成は完了です。

ビューの作成

DRFのビューはクラスベースビュー、関数ベースのどちらで実装してもOKです。

今回はGenericAPIViewというクラスベースビューを使ってみました。

from django.shortcuts import render
from rest_framework import generics, status
from rest_framework.response import Response
from appv1.serializers import BertPredictSerializer


class BertPredictAPIView(generics.GenericAPIView):
    """BERTネガポジ分類予測クラス"""
    serializer_class = BertPredictSerializer

    def post(self, request, *args, **kwargs):     

        serializer = self.get_serializer(request.data)
        return Response(serializer.data, status=status.HTTP_200_OK)

GenericAPIViewを承継してBertPredictAPIViewを定義しています。
serializer_class属性にserializer.py内で定義したシリアライズクラス(BertPredictSerialzer)を指定します。
次にBertPredictAPIView内にpostメソッドを定義し、内部でGenericAPIViewのget_serializerメソッドを使ってValidationに使用されるシリアライザインスタンスを取得します。
引数にはrequest.dataを指定します。

serializer = self.get_serializer(request.data)

request.dataは「POST」、「PUT」、および「PATCH」メソッドで機能し、任意のデータを処理してくれます。(request.POSTと似たような機能)
最後にResponseの引数にserializer.dataを指定してあげると推論結果を返すビューが完成します。

urls.pyの定義

最後にdrf\urlspy内にURLの定義を追加します。

from django.contrib import admin
from django.urls import path
from appv1 import views  #add


urlpatterns = [
    path('admin/', admin.site.urls),
    path('api/v1/predict/', views.BertPredictAPIView.as_view()),   #add
]

views.pyに定義したBertPredictAPIViewをapi/v1/predictというURLパターンとして定義しています。
これでRestAPIのエンドポイントとしてhttp://127.0.0.1/api/v1/predictを定義したことになります。

ここで一旦djagnoのマイグレーションを実行して、runserverを実行します。

python manage.py migrate
python manage.py runserver

その後、http://127.0.0.1:8000/api/v1/predict/にアクセスすると以下のような画面が表示されます。

drf_bert_02.png

画面下の「input text」にネガポジを判定したい文章を入力してPOSTボタンを押すとBERTのファインチューニング学習済みモデルを使って推論を行い結果をリターンしてくれます。

input textに「㈱東急コミュニティーにおいて管理ストックがマンション、ビルともに拡大し増収増益となりました」という文章を入力しPOSTボタンを押した結果が以下の画面です。

drf_bert_03.png

戻り値として、input_text(入力した文章)とネガポジを表すneg_pos(0 or 1)が返ってきます。

上記例では、推論結果としてポジティブ(=1)が返ってきています。

4.REST APIを使って推論

次に作成したREST APIをローカルクライアントからCALLして推論結果を受け取ってみます。
今回は簡易的にクライアントとサーバを同一PC内で実行します。

python manage.py runserverで開発サーバを起動した状態にしたら、別ウィンドウでDOS画面等を立ち上げpythonコマンドを実行してpythonコードを実行できる状態にして以下のコードを実行します。

import urllib.request
import urllib.parse
import json
def predict(input_text):
    URL = "http://127.0.0.1:8000/api/v1/predict/"
    values = {
        "format": "json",
        "input_text": input_text,
        }
    data = urllib.parse.urlencode({'input_text': input_text}).encode('utf-8')
    request = urllib.request.Request(URL, data)
    response = urllib.request.urlopen(request)
    result= json.loads(response.read())
    return result

input_text = "損益面におきましては、経常収益は、貸出金利息や有価証券売却益の減少により、前期比72億73百万円減少の674億13百万円となりました"
result = predict(input_text)
print(result)
{'input_text': '損益面におきましては、経常収益は、貸出金利息や有価証券売却益の減少により、前期比72億73百万円減少の674億13百万円となりました', 'neg_pos': 0}
print(result['input_text'])
損益面におきましては、経常収益は、貸出金利息や有価証券売却益の減少により、前期比7273百万円減少の67413百万円となりました
print(result['neg_pos'])
0

エンドポイントとして"http://127.0.0.1:8000/api/v1/predict/"を指定します。
辞書型変数valuesにformatとしてjson、入力データとしてinput_textを与えます。

あとはurllib.request.Requestを使ってエンドポイントに対してエンコードした文章データと一緒にPOSTメソッドで投げると、views.pyで定義したBertPredictAPIViewクラスが呼び出されシリアライズ→推論実行の流れで処理が動いていきます。
python側で辞書型データとして扱えるようにするためjson.loadsで処理結果を変換してあげます。

すると以下のように辞書型のデータに変換されるので、キー名(input_text, neg_pos)でほしい情報にアクセスできます。

{'input_text': '損益面におきましては、経常収益は、貸出金利息や有価証券売却益の減少により、前期比72億73百万円減少の674億13百万円となりました', 'neg_pos': 0}

5.簡易ツール

最後に大量のインプットデータを自動でREST APIに投げて推論結果を出力する簡易コマンドを作成します。
なお、以下で利用するcsvファイルとプログラムはgitリポジトリのdjango-drf-dl/drf/tools/配下にあります。

以下のような予測したいtest.csvファイルを用意します。(列名がINPUT)

bert_testdata.png

test.csvと同フォルダ内に以下のコードを記載したpredict.pyを配置します。

import pandas as pd
import numpy as np
import urllib.request
import urllib.parse
import json

def predict(input_text):
    URL = "http://127.0.0.1:8000/api/v1/predict/"
    values = {
        "format": "json",
        "input_text": input_text,
            }
    data = urllib.parse.urlencode({'input_text': input_text}).encode('utf-8')
    request = urllib.request.Request(URL, data)
    response = urllib.request.urlopen(request)
    result= json.loads(response.read())
    return result['neg_pos'][1] 

if __name__ == '__main__':
    print("Start if __name__ == '__main__'")
    print('load csv file ....')
    df = pd.read_csv("test.csv", engine="python", encoding="utf-8-sig")
    df["PREDICT"] = np.nan   #予測列を追加
    print('Getting prediction results ....')
    for index, row in df.iterrows():
        df.at[index, "PREDICT"] = predict(row['INPUT'])
    print('save results to csv file')
    df.to_csv("predicted_test .csv", encoding="utf-8-sig", index=False)
    print('Processing terminated normally.')

DOS画面などを起動して以下のコマンドを実行するとtest.csvを1行ずつ読み込みREST APIに投げて推論結果を受け取ります。

python predict.py
----------------------------------
Start if __name__ == '__main__'
load csv file ....
Getting prediction results ....
save results to csv file
Processing terminated normally.

最後のデータまで完了すると推論結果付のpredicted_test.csvが同フォルダ内に生成されます。

bert_predict.png

6.まとめ

今回はローカル環境内にDRFとBERTの2値分類モデルを用いて簡単なネガポジを判定するREST APIを作成してみました。
今後は、Azure基盤上でREST APIを構築したり、2値分類ではなく多分類やFAQといったタスクにも応用してみたいと考えています。

7.参考書籍

本記事で作成したDRFは、現場で使える Django REST Framework の教科書 内のコンテンツをベースに少し応用して作成しました。
今回初めてこの書籍でDRFを学習しましたが、大変参考になる書籍ですのでこれからDRFを学習したいという方にはお勧めです。

明日は、ssh22さん さんの「Django Advent Calendar 2019 - Qiita」21日目の記事です。よろしくお願いします!

ysiny
元々インフラ畑で仕事をしていましたが、アプリケーション開発に興味を持ちpython、django、ディープラーニングの勉強を始めて約2年程です。 ※本職ではないため完全独学の範囲
https://sinyblog.com/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした