25
32

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

機械学習 〜 テキスト分類(ナイーブベイズ分類器) 〜

Posted at

今回は scikit-learn を使ったナイーブベイズ分類器によるテキスト分類を行います。
モデルは以前投稿した多項分布を使用します。

データセット

テキストのデータセットは、scikit-learn に用意されているニュースデータセットを利用します。
このデータには、ニュースのテキストデータと分類される20種類のグループが格納されています。

実装コード

インポート

各種ライブラリをインポートします。

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import  fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

データセット準備

sklearn.datasets の fetch_20newsgroups() を使いデータ取得を行います。

  • データ取得
data = fetch_20newsgroups()
  • 分類グループ名表示

以下のように20種類のグループに分類されています。
コンピューターや宇宙などのグループがあります。

print(data.target_names)

> ['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']
  • 学習・テストデータを取得

トレーニングと評価用データを取得します。

train = fetch_20newsgroups(subset='train')
test = fetch_20newsgroups(subset='test')
  • データ件数

データ件数は1万件以上あります。

print("Train Data Count =", len(train.data))
print("Test Data Count =", len(test.data))

> Train Data Count = 11314
  Test Data Count = 7532
  • データの中身を確認

タイトルや本文などのテキストデータが格納されています。

print(train.data[10][:500])

> From: ******.*** (Irwin Arnstein)
Subject: Re: Recommendation on Duc
Summary: What's it worth?
Distribution: usa
Expires: Sat, 1 May 1993 05:00:00 GMT
Organization: CompuTrac Inc., Richardson TX
Keywords: Ducati, GTS, How much? 
Lines: 13

I have a line on a Ducati 900GTS 1978 model with 17k on the clock.  Runs
very well, paint is the bronze/brown/orange faded out, leaks a bit of oil
and pops out of 1st with hard accel.  The shop will fix trans and oil 
leak.  They sold the bike t

テキストデータのベクトル化

テキストデータはそのまま特徴量としては使えないため、
文書中に含まれる単語の重要度を評価する手法の1つである
TF-IDF(索引語頻度逆文書頻度)という手法を使い、特徴量を算出します。

この算出には sklearn に用意されている TfidfVectorizer クラスを使います。

※パイプライン内で使用します。

パイプライン(データ前処理+モデル生成)

先ほどの TF-IDF(TfidfVectorizer) と ナイーブベイズ分類器(MultinomialNB) の処理を
scikit-learn のパイプラインを使ってまとめて処理します。
パイプラインは前処理を行ったデータをモデルに流し込んで使うなどの一連の処理を
簡単に扱う事が出来るとても便利な機能になります。

model = make_pipeline(TfidfVectorizer(), MultinomialNB())

学習実行

パイプラインにて、入力テキストデータに対して TF-IDF の特徴量算出処理が行われ、
その特徴量が MultinomialNB の入力データとして学習が実行されます。

model.fit(train.data, train.target)

評価

作成したモデルを使い、評価を実行します。

print('Train accuracy = %.3f' % model.score(train.data, train.target))
print(' Test accuracy = %.3f' % model.score(test.data, test.target))

>Train accuracy = 0.925
 Test accuracy = 0.773

精度は、トレーニングデータをそのまま使った場合は、0.925
評価データを使った場合では、0.773 となりました。
まずまずの精度となりました。

予測相関

テストデータの予測ラベルと正解ラベルの一致度を相関図で確認してみます。

# 表示サイズ
plt.rcParams['figure.figsize'] = (15.0, 15.0)

from sklearn.metrics import confusion_matrix
mat = confusion_matrix(test.target, labels)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
            xticklabels=train.target_names, yticklabels=train.target_names)
plt.xlabel('true label')
plt.ylabel('predicted label')

2019_03_text_vector_s3.png

概ね正解しているようですが、soc.religion.christian(キリスト教)のグループ分類予測が
alt.atheism(無神論)と talk.religion.misc(宗教:その他)と混同しているようです。

予測

次に、任意のテキストを入力データとして予測を行ってみます。
まず、予測メソッドを定義します。

def predicted_group(s, train=train, model=model):
    pred = model.predict([s])
    return train.target_names[pred[0]]

予測を実行します。
予測結果は以下の通りです。

predicted_group('A new Mac book was released.')
> 'comp.sys.mac.hardware'
predicted_group('I carry a gun for self-defense')
> 'talk.politics.guns'
predicted_group('At 2:32 a.m. EST, Crew Dragon undocked from the International Space Station to begin the final phase of its uncrewed Demo-1 flight test.')
> 'sci.space'

それとなく予測は出来ているようです。


以上、今回は scikit-learn を使ったナイーブベイズ分類器によるテキスト分類を行いました。


25
32
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
25
32

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?