LoginSignup
18
26

More than 5 years have passed since last update.

sklearnでナイーブベイズによるテキスト分類にチャレンジ

Last updated at Posted at 2016-03-28

実践 機械学習システム の第6章にナイーブベイズによるテキスト分類事例があったので、自分でもチャレンジしてみます。

やること

sklearnのデータセット 20newsgroupssklearn.naive_bayes.MultinomialNB 使ってカテゴリ分類します。

  1. CountVectorizer を利用して、 ドキュメントを単語出現頻度の行列に変換する
  2. MultinomialNB を利用して、ナイーブベイズ分類器を学習させる
  3. テストデータによる検証を行う

という流れになります。

実装

ストップワードの設定以外は全てデフォルトのままです。

import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer
import nltk

def stopwords():
    symbols = ["'", '"', '`', '.', ',', '-', '!', '?', ':', ';', '(', ')', '*', '--', '\\']
    stopwords = nltk.corpus.stopwords.words('english')
    return stopwords + symbols

newsgroups_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
newsgroups_test  = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))

vectorizer = CountVectorizer(stop_words=stopwords())
vectorizer.fit(newsgroups_train.data)

# Train
X = vectorizer.transform(newsgroups_train.data)
y = newsgroups_train.target
print(X.shape)

clf = MultinomialNB()
clf.fit(X, y)
print(clf.score(X,y))

# Test
X_test = vectorizer.transform(newsgroups_test.data) 
y_test = newsgroups_test.target

print(clf.score(X_test, y_test))

結果

テストデータに対する正答率は 62% でした。
(なお、トレーニングデータに対しては 81%)

sklearnを使うと、ナイーブベイズ分類器を使ったテキスト分類が手軽にできることがわかりました。
ただ、正答率が 62% なので、精度を上げるためにはTfIdf、Stemming 等、諸々の自然言語処理を適用する必要がありそうです。

追記(2016/03/30)

TfidVectorizer に変更し、さらに GridSearchCVを使って最適なパラメタを探索する方法もやってみた。
テストデータに対する正答率は 66% と少し上昇。

import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.grid_search import GridSearchCV
from sklearn.pipeline import Pipeline
import nltk

def stopwords():
    symbols = ["'", '"', '`', '.', ',', '-', '!', '?', ':', ';', '(', ')', '*', '--', '\\']
    stopwords = nltk.corpus.stopwords.words('english')
    return stopwords + symbols

newsgroups_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
newsgroups_test  = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))

# Pipeline
pipeline = Pipeline([('vectorizer', TfidfVectorizer()), ('multinomial_nb', MultinomialNB())])
params = {
    'vectorizer__max_df': [1.0, 0.99],
    'vectorizer__ngram_range': [(1,1), (1, 2)],
    'vectorizer__stop_words' : [stopwords()],
}
clf = GridSearchCV(pipeline, params)

# Train
X = newsgroups_train.data
y = newsgroups_train.target
clf.fit(X,y)
print(clf.score(X, y))

# Test
X_test = newsgroups_test.data
y_test = newsgroups_test.target
print(clf.score(X_test, y_test))
18
26
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
26