今回は scikit-learn を使ったナイーブベイズ分類器によるテキスト分類を行います。
モデルは以前投稿した多項分布を使用します。
データセット
テキストのデータセットは、scikit-learn に用意されているニュースデータセットを利用します。
このデータには、ニュースのテキストデータと分類される20種類のグループが格納されています。
実装コード
インポート
各種ライブラリをインポートします。
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline
データセット準備
sklearn.datasets の fetch_20newsgroups()
を使いデータ取得を行います。
- データ取得
data = fetch_20newsgroups()
- 分類グループ名表示
以下のように20種類のグループに分類されています。
コンピューターや宇宙などのグループがあります。
print(data.target_names)
> ['alt.atheism',
'comp.graphics',
'comp.os.ms-windows.misc',
'comp.sys.ibm.pc.hardware',
'comp.sys.mac.hardware',
'comp.windows.x',
'misc.forsale',
'rec.autos',
'rec.motorcycles',
'rec.sport.baseball',
'rec.sport.hockey',
'sci.crypt',
'sci.electronics',
'sci.med',
'sci.space',
'soc.religion.christian',
'talk.politics.guns',
'talk.politics.mideast',
'talk.politics.misc',
'talk.religion.misc']
- 学習・テストデータを取得
トレーニングと評価用データを取得します。
train = fetch_20newsgroups(subset='train')
test = fetch_20newsgroups(subset='test')
- データ件数
データ件数は1万件以上あります。
print("Train Data Count =", len(train.data))
print("Test Data Count =", len(test.data))
> Train Data Count = 11314
Test Data Count = 7532
- データの中身を確認
タイトルや本文などのテキストデータが格納されています。
print(train.data[10][:500])
> From: ******.*** (Irwin Arnstein)
Subject: Re: Recommendation on Duc
Summary: What's it worth?
Distribution: usa
Expires: Sat, 1 May 1993 05:00:00 GMT
Organization: CompuTrac Inc., Richardson TX
Keywords: Ducati, GTS, How much?
Lines: 13
I have a line on a Ducati 900GTS 1978 model with 17k on the clock. Runs
very well, paint is the bronze/brown/orange faded out, leaks a bit of oil
and pops out of 1st with hard accel. The shop will fix trans and oil
leak. They sold the bike t
テキストデータのベクトル化
テキストデータはそのまま特徴量としては使えないため、
文書中に含まれる単語の重要度を評価する手法の1つである
TF-IDF(索引語頻度逆文書頻度)という手法を使い、特徴量を算出します。
この算出には sklearn に用意されている TfidfVectorizer
クラスを使います。
※パイプライン内で使用します。
パイプライン(データ前処理+モデル生成)
先ほどの TF-IDF(TfidfVectorizer) と ナイーブベイズ分類器(MultinomialNB) の処理を
scikit-learn のパイプラインを使ってまとめて処理します。
パイプラインは前処理を行ったデータをモデルに流し込んで使うなどの一連の処理を
簡単に扱う事が出来るとても便利な機能になります。
model = make_pipeline(TfidfVectorizer(), MultinomialNB())
学習実行
パイプラインにて、入力テキストデータに対して TF-IDF の特徴量算出処理が行われ、
その特徴量が MultinomialNB の入力データとして学習が実行されます。
model.fit(train.data, train.target)
評価
作成したモデルを使い、評価を実行します。
print('Train accuracy = %.3f' % model.score(train.data, train.target))
print(' Test accuracy = %.3f' % model.score(test.data, test.target))
>Train accuracy = 0.925
Test accuracy = 0.773
精度は、トレーニングデータをそのまま使った場合は、0.925
評価データを使った場合では、0.773
となりました。
まずまずの精度となりました。
予測相関
テストデータの予測ラベルと正解ラベルの一致度を相関図で確認してみます。
# 表示サイズ
plt.rcParams['figure.figsize'] = (15.0, 15.0)
from sklearn.metrics import confusion_matrix
mat = confusion_matrix(test.target, labels)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
xticklabels=train.target_names, yticklabels=train.target_names)
plt.xlabel('true label')
plt.ylabel('predicted label')
概ね正解しているようですが、soc.religion.christian
(キリスト教)のグループ分類予測が
alt.atheism
(無神論)と talk.religion.misc
(宗教:その他)と混同しているようです。
予測
次に、任意のテキストを入力データとして予測を行ってみます。
まず、予測メソッドを定義します。
def predicted_group(s, train=train, model=model):
pred = model.predict([s])
return train.target_names[pred[0]]
予測を実行します。
予測結果は以下の通りです。
predicted_group('A new Mac book was released.')
> 'comp.sys.mac.hardware'
predicted_group('I carry a gun for self-defense')
> 'talk.politics.guns'
predicted_group('At 2:32 a.m. EST, Crew Dragon undocked from the International Space Station to begin the final phase of its uncrewed Demo-1 flight test.')
> 'sci.space'
それとなく予測は出来ているようです。
以上、今回は scikit-learn を使ったナイーブベイズ分類器によるテキスト分類を行いました。