LoginSignup
12
6

生成AIでテキスト分類やってみた

Last updated at Posted at 2023-12-16

この記事はNTTコムウェア Advent Calendar 2023 17日目の記事です。

はじめに

NTTコムウェアの森吉と申します。
普段はBtoBシステム開発を実施しておりますが、
その傍らで生成AIを利用したシステム開発・維持の効率化を目指す施策推進チームでも活動しております。

生成AIの利用が活発ですね。
今回は、生成AIで教師あり学習のテキスト分類問題にチャレンジします。
こういった分類問題は、BERTのライブラリが有名ですが、
生成AIを使った記事はあまり見ませんので、自分で検証することにしました。

本記事では下記の記事を参考にしております。

方針としては、text-embedding-ada-002モデル(以下、embedding)を使ってベクトルDBにデータを蓄積しておき、蓄積したデータに対し識別器(k-NN法)を使って精度を評価していきます。

注意点としては、BERTはfine tuningしたモデルに対して分類をかけているのに対し、
embeddingは汎用的なモデルになっているので、単純な比較はできない点にご注意願います。

データセット

データセットはlivedoor ニュースコーパスを使用しております。

記事の種類ごとに、各800件程度のデータが用意されています。

トピックニュース / Sports Watch / ITライフハック
家電チャンネル / MOVIE ENTER / 独女通信
エスマックス / livedoor HOMME / Peachy

処理の流れ

  • ファイルを加工し前処理する(元記事様を参考にしてください)

  • ニュースコーパスをベクトルDBに登録する
  • 精度を検証する

環境情報

Azure OpenAI(text-embedding-ada-002)
python 3.9.6

ライブラリ(pipなどでinstallしてください)

openai==0.27.8
SQLAlchemy==2.0.9
urllib3==1.26.9
scikit-learn==1.3.0
pandas==1.5.3
pgvector==0.2.1

ニュースコーパスをベクトルDBに登録する

前処理にて、ニュースコーパスをCSVとして出力できましたので、
こちらをベクトル検索DBに登録します。
元記事様は事前にテストデータとトレーニングデータを分離していますが、
embeddingは利用の度に費用が発生しますので、全て登録し、後から分離します。
今回はPostgreSQLの拡張機能のpgvectorを使用します。

テーブルもカラムも下記コードの中で自動的にできます。
テーブルと設定内容下記のとおりです。

newsテーブル

物理名 設定値
id integer 自動採番されるので設定なし
body text ニュース記事本体
embedding vector(1536) ニュース記事をベクトル化したベクトルデータ
category text 正解ラベル

OpenAIとPostgreSQLの接続設定を記入すると動作します。
Ratelimitやトークン数オーバーで中断した場合は、RATELIMIT_RESUMEに停止した時の最終行数を設定してください。
※長い記事はモデルの制限上読み込めませんので、この記事では読み込みをSKIPします。

import openai
from pgvector.sqlalchemy import Vector
from sqlalchemy import create_engine, insert, Integer, Text
from sqlalchemy.orm import declarative_base, mapped_column, Session
from urllib.parse import quote_plus
import csv

openai.api_type = "azure"
openai.api_base = "https://{your-resource-name}.openai.azure.com/"
openai.api_version = "2023-09-01-preview"
openai.api_key = "your-api-key"
OPENAI_MODEL_NAME = "your-model-name"

USER = "your-db-user"
PASSWORD = "your-db-password"
HOST = "your-host-address"
DB = "your-db-name"
PORT = your-db-port

# ratelimitで中断した場合の再開用
RATELIMIT_RESUME = 0

# パスワードをURLエンコード p@ssword -> p%40ssword
password = quote_plus(PASSWORD)
url = f'postgresql+psycopg://{USER}:{password}@{HOST}:{PORT}/{DB}'

# DBのvector設定
engine = create_engine(url, echo=False)
with engine.connect() as conn:
    conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
    conn.commit()

# SQLAlchemyのベース設定
Base = declarative_base()

# ベース定義
class news(Base):
    __tablename__ = 'news'

    id = mapped_column(Integer, primary_key=True)
    body = mapped_column(Text)
    embedding = mapped_column(Vector(1536))
    category = mapped_column(Text)

# dropしたい場合
# Base.metadata.drop_all(engine)

# 初めて作る場合
Base.metadata.create_all(engine)

# csvの読み込み
with open('./all_eval.csv', 'r', encoding="utf_8_sig") as f:
    reader = csv.reader(f, delimiter=',', quotechar='"', escapechar="\\")
    line_count = 0  # 行数をカウントする変数を初期化

    for body, category in reader:
        line_count += 1  # 行数をインクリメント
        if line_count < RATELIMIT_RESUME:
            continue
        print(f"読み込んだ行数: {line_count}")  # 行数を表示
        
        session = Session(engine) # dbセッション
        
        # INSERT
        session.execute(
            insert(news),
            dict(
                body=body,
                embedding=openai.Embedding.create(input=body, engine=OPENAI_MODEL_NAME)['data'][0]['embedding'],
                category=category
            )
        )
        
        # COMMIT
        session.commit()

終了するとDBが出来上がりますので、確認してください。
※3件トークン数オーバーでした。

-- 全体の数
SELECT COUNT(*) FROM news;

 count 
-------
  7364
(1 row)

-- カテゴリごとの数
SELECT COUNT(1),category FROM news GROUP BY category ORDER BY count DESC;

 count |      category       
-------|-------------------
   900 | sports-watch
   870 | dokujo-tsushin
   869 | smax
   869 | movie-enter
   869 | it-life-hack
   864 | kaden-channel
   842 | peachy
   770 | topic-news
   511 | livedoor-homme
(9 rows)

テストデータの精度を検証する

scikit-learnを用いることで簡単に実装できます。
今回用いているk-NN法以外の識別器を選択することも可能です。

from pgvector.sqlalchemy import Vector
from sqlalchemy import create_engine, select, Integer, Text
from sqlalchemy.orm import declarative_base, mapped_column
from urllib.parse import quote_plus
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report

USER = "your-db-user"
PASSWORD = "your-db-password"
HOST = "your-host-address"
DB = "your-db-name"
PORT = your-db-port

# パスワードをURLエンコード p@ssword -> p%40ssword
password = quote_plus(PASSWORD)
url = f'postgresql+psycopg://{USER}:{password}@{HOST}:{PORT}/{DB}'

# DBの設定確認
engine = create_engine(url, echo=False)

# ベース作成
Base = declarative_base()

# ベース定義
class news(Base):
    __tablename__ = 'news'

    id = mapped_column(Integer, primary_key=True)
    body = mapped_column(Text)
    embedding = mapped_column(Vector(1536))
    category = mapped_column(Text)

# dbからデータを取得
with engine.connect() as conn:
    res = conn.execute(select(news)) # テーブル名を指定
    df = pd.DataFrame(res.fetchall())

# labelを数値に変換
le = LabelEncoder()
le = le.fit(df['category'])
df['category'] = le.transform(df['category'])

# データを分割
train, test = train_test_split(df, test_size=0.2, random_state=42)

# knn法
clf = KNeighborsClassifier(n_neighbors=5)
clf.fit(train['embedding'].tolist(), train['category'])
pred = clf.predict(test['embedding'].tolist())

# 分類ごとの精度確認
print(classification_report(test['category'], pred, target_names=le.classes_))

結果

全体の精度としては、n_neighbors=5で83%でした。
n_neighborsを5以上に上げてみましたが、
5のあたりが最大でそれ以上は精度が落ちる傾向にありそうです。

予想はしていましたが、
チューニング済みのBERTを使った分類よりも精度が出ないという結果になりました。
ただ見込みはありそうな結果ですね。

# n_neighbors=1
                precision    recall  f1-score   support

dokujo-tsushin       0.73      0.74      0.73       187
  it-life-hack       0.86      0.82      0.84       193
 kaden-channel       0.76      0.77      0.77       181
livedoor-homme       0.67      0.77      0.71        91
   movie-enter       0.89      0.96      0.93       167
        peachy       0.72      0.66      0.69       175
          smax       0.89      0.91      0.90       157
  sports-watch       0.90      0.94      0.92       176
    topic-news       0.74      0.64      0.69       146

      accuracy                           0.80      1473
     macro avg       0.80      0.80      0.80      1473
  weighted avg       0.80      0.80      0.80      1473

# n_neighbors=3
                precision    recall  f1-score   support

dokujo-tsushin       0.68      0.82      0.74       187
  it-life-hack       0.81      0.82      0.81       193
 kaden-channel       0.75      0.78      0.77       181
livedoor-homme       0.71      0.70      0.71        91
   movie-enter       0.89      0.96      0.93       167
        peachy       0.80      0.62      0.70       175
          smax       0.91      0.93      0.92       157
  sports-watch       0.90      0.97      0.94       176
    topic-news       0.84      0.63      0.72       146

      accuracy                           0.81      1473
     macro avg       0.81      0.80      0.80      1473
  weighted avg       0.81      0.81      0.81      1473

# n_neighbors=5
                precision    recall  f1-score   support

dokujo-tsushin       0.74      0.83      0.78       187
  it-life-hack       0.85      0.83      0.84       193
 kaden-channel       0.76      0.80      0.78       181
livedoor-homme       0.72      0.69      0.70        91
   movie-enter       0.87      0.97      0.92       167
        peachy       0.81      0.68      0.74       175
          smax       0.92      0.95      0.93       157
  sports-watch       0.87      0.99      0.93       176
    topic-news       0.88      0.61      0.72       146

      accuracy                           0.83      1473
     macro avg       0.82      0.82      0.82      1473
  weighted avg       0.83      0.83      0.82      1473

傾向

sports-watchとsmax、movie-enterにおいて、かなりの分類性能を示しているようです。
一方で、livedoor-homme、topic-newsに対する分類は苦手傾向にあるようです。
みたところ、livedoor-hommeは記事数そのものが少ない。
topic-newsは総合記事系のジャンルな様子なので分類が難しいのかもしれないです。

おわりに

どちらが適しているのかは、使い方によるというのが結論だと思います。
BERT: finetuning させて精度に関して最適化したいときに使う
embedding: 汎化性能の高さを活かして手軽に使う
など。

他のデータセットでも試してみたいところですが、今回はここまでとします。
今回ご協力いただきました皆様に感謝申し上げます。

記載されている会社名、製品名、サービス名は、各社の商標または登録商標です。

12
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
6