Edited at

Rで判別分析いろいろ(11種類)

More than 3 years have passed since last update.


Background

Rにはパッケージがたくさんあって、ちょっとググっただけで、いろいろなやり方が見つかる。

でも、どの手法をどの目的で使うのか理解できていないと混乱してしまうので、まず、判別分析に絞って、目についた手法を列挙して精度を比較してみた。


Summary


使用したデータ


  • 4601通のメールをspamとnon-spamに分類してあるデータ。

  • 460通をテストデータ、残りを学習データに使った。


比較結果

手法
package
関数
正解率
チューニング

線形判別分析
MASS
lda
0.896

非線形判別分析(2次式)
MASS
qda
0.837

k-Nearest Neighbor
class
knn
0.820
k=1

ナイーブベイズ
e1071
naiveBayes
0.720

Decision tree
rpart
rpart
0.889

Nural Network 3層
nnet
nnet
0.959
size=7, decay=0.3

Nural Network LVQ1
class
lvq1
0.763
k=5, alpha=0.23

SVM
e1071
svm
0.952
gamma=0.001381068, cost=512

Bagging
adabag
bagging
0.909

Boosting
adabag
boosting
0.959

Randam Forest
randomForest
randomForest
0.952
mtry=14


結論


  • 同一の問題にたいして、判別分析の手法をいろいろ試してみた

  • 3層ニューラルネット、SVM、Boosting、Random Forestがトップ集団

  • 線形判別分析も意外とよい順位になった。


感想


  • 判別の理由を説明しやすいもの、ブラックボックス化して説明しにくいものいろいろある

  • チューニングによって結構精度が変わる。もっとちゃんとやるともっと良くなるかも

  • 問題を変えてみると、順位が変わるのかも


References...1つだけ読むなら

https://www1.doshisha.ac.jp/~mjin/R/index.html

でも、しっかりやるなら本の方がおすすめ

https://www.amazon.co.jp/Rによるデータサイエンス-データ解析の基礎から最新手法まで-金-明哲/dp/4627096011


Details


使用したデータ

https://cran.r-project.org/web/packages/kernlab/kernlab.pdf

library(kernlab)            # spamのデータセットを使うために呼び出している

num<-10*(1:(nrow(spam)/10)) # 全体の10分の1を検証に使用。等間隔に抜き出し。
data.test<-spam[num,] # 検証用データ
data.train<-spam[-num,] # 学習用データ


線形判別分析

#

# 線形判別分析
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(MASS)
M<-lda(type~.,data=data.train)

# predict
P=predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P$class)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.lda<-round(a,3)
A.lda


非線形判別分析(2次式)

#

# 非線形判別分析(2次式)
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(MASS)
M<-qda(type~.,data=data.train)

# predict
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P$class)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.qda<-round(a,3)
A.qda


k-Nearest Neighbor

#

# 非線形判別分析 k-Nearest Neighbor
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model / prediction
library(class)
P<-knn(data.train[,-58],data.test[,-58],data.train[,58],k=1)

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.knn<-a
A.knn


ナイーブベイズ

#

# ナイーブベイズ
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(e1071)
M<-naiveBayes(type~.,data.train)

# predict
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.nbayes<-round(a,3)
A.nbayes


Decision tree

#

# Decision tree
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(rpart)
M<-rpart(type~.,data=data.train)

# predict
P<-predict(M,data.test[,-58],type="class")

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.dtree<-round(a,3)
A.dtree


Nural Network 3層

#

# 3 layer nural network (nnet package)
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,] # for test
data.train<-data[-num,] # for model

# Nural Network
library(nnet)

# model
M<-nnet(type~.,size=7,decay=0.3,data=data.train)

# prediction
P<-predict(M,data.test[,-58],type="class")

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.nnet<-round(a,3)
A.nnet


Nural Network LVQ1

#

# Nural Networl (Learning Vector Quantization)
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,] # for test
data.train<-data[-num,] # for model

# model
library(class)
Minit<-lvqinit(data.train[,-58],data.train[,58],k = 5)
M<-lvq1(data.train[,-58],data.train[,58],Minit, alpha = 0.23)

# prediction
P<-lvqtest(M,data.test[,-58])

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.lvq1<-round(a,3)
A.lvq1


SVM

#

# Support Vector Machine
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,] # for test
data.train<-data[-num,] # for model

# model
library(e1071)
M <- svm(
type ~ .,
data = data.train,
gamma = 0.001381068,
cost = 512
)

# prediction
P <- predict(M, data.test)

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.svm<-round(a,3)
A.svm


Bagging

#

# Bagging
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(adabag)
M<-bagging(type~.,data=data.train)

# prediction
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P$class)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.bag<-round(a,3)
A.bag


Boosting

#

# Boosting
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(adabag)
M<-boosting(type~.,data=data.train)

# prediction
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P$class)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.boost<-round(a,3)
A.boost


Randam Forest

#

# Random Forest
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,] # for test
data.train<-data[-num,] # for model

# model
library(randomForest)
M<-randomForest(type~.,data=data.train,na.action="na.omit",mtry=14)

# prediction
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.rf<-round(a,3)
A.rf


References


  1. https://www.amazon.co.jp/Rによるデータサイエンス-データ解析の基礎から最新手法まで-金-明哲/dp/4627096011

  2. https://www1.doshisha.ac.jp/~mjin/R/index.html