LoginSignup
15
23

More than 5 years have passed since last update.

Rで判別分析いろいろ(11種類)

Last updated at Posted at 2016-07-03

Background

Rにはパッケージがたくさんあって、ちょっとググっただけで、いろいろなやり方が見つかる。
でも、どの手法をどの目的で使うのか理解できていないと混乱してしまうので、まず、判別分析に絞って、目についた手法を列挙して精度を比較してみた。

Summary

使用したデータ

  • 4601通のメールをspamとnon-spamに分類してあるデータ。
  • 460通をテストデータ、残りを学習データに使った。

比較結果

手法 package 関数 正解率 チューニング
線形判別分析 MASS lda 0.896
非線形判別分析(2次式) MASS qda 0.837
k-Nearest Neighbor class knn 0.820 k=1
ナイーブベイズ e1071 naiveBayes 0.720
Decision tree rpart rpart 0.889
Nural Network 3層 nnet nnet 0.959 size=7, decay=0.3
Nural Network LVQ1 class lvq1 0.763 k=5, alpha=0.23
SVM e1071 svm 0.952 gamma=0.001381068, cost=512
Bagging adabag bagging 0.909
Boosting adabag boosting 0.959
Randam Forest randomForest randomForest 0.952 mtry=14

結論

  • 同一の問題にたいして、判別分析の手法をいろいろ試してみた
  • 3層ニューラルネット、SVM、Boosting、Random Forestがトップ集団
  • 線形判別分析も意外とよい順位になった。

感想

  • 判別の理由を説明しやすいもの、ブラックボックス化して説明しにくいものいろいろある
  • チューニングによって結構精度が変わる。もっとちゃんとやるともっと良くなるかも
  • 問題を変えてみると、順位が変わるのかも

References...1つだけ読むなら

https://www1.doshisha.ac.jp/~mjin/R/index.html
でも、しっかりやるなら本の方がおすすめ
https://www.amazon.co.jp/Rによるデータサイエンス-データ解析の基礎から最新手法まで-金-明哲/dp/4627096011

Details

使用したデータ

library(kernlab)            # spamのデータセットを使うために呼び出している
num<-10*(1:(nrow(spam)/10)) # 全体の10分の1を検証に使用。等間隔に抜き出し。
data.test<-spam[num,]       # 検証用データ
data.train<-spam[-num,]     # 学習用データ

線形判別分析

#
# 線形判別分析
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(MASS)
M<-lda(type~.,data=data.train)

# predict
P=predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P$class)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.lda<-round(a,3)
A.lda

非線形判別分析(2次式)

#
# 非線形判別分析(2次式)
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(MASS)
M<-qda(type~.,data=data.train)

# predict
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P$class)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.qda<-round(a,3)
A.qda

k-Nearest Neighbor

#
# 非線形判別分析 k-Nearest Neighbor
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model / prediction
library(class)
P<-knn(data.train[,-58],data.test[,-58],data.train[,58],k=1)

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.knn<-a
A.knn

ナイーブベイズ

#
# ナイーブベイズ
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(e1071)
M<-naiveBayes(type~.,data.train)

# predict
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.nbayes<-round(a,3)
A.nbayes

Decision tree

#
# Decision tree
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(rpart)
M<-rpart(type~.,data=data.train)

# predict
P<-predict(M,data.test[,-58],type="class")

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.dtree<-round(a,3)
A.dtree

Nural Network 3層

#
# 3 layer nural network (nnet package)
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]               # for test
data.train<-data[-num,]             # for model

# Nural Network
library(nnet)

# model
M<-nnet(type~.,size=7,decay=0.3,data=data.train)

# prediction
P<-predict(M,data.test[,-58],type="class")

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.nnet<-round(a,3)
A.nnet

Nural Network LVQ1

#
# Nural Networl (Learning Vector Quantization)
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]               # for test
data.train<-data[-num,]             # for model

# model
library(class)
Minit<-lvqinit(data.train[,-58],data.train[,58],k = 5)
M<-lvq1(data.train[,-58],data.train[,58],Minit, alpha = 0.23)

# prediction
P<-lvqtest(M,data.test[,-58])

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.lvq1<-round(a,3)
A.lvq1

SVM

#
# Support Vector Machine
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]               # for test
data.train<-data[-num,]             # for model

# model
library(e1071)
M <- svm(
  type ~ .,
  data  = data.train,
  gamma = 0.001381068,
  cost  = 512
)

# prediction
P <- predict(M, data.test)

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1]  + t[2,2] + t[1,2] + t[2,1])
A.svm<-round(a,3)
A.svm

Bagging

#
# Bagging
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model
library(adabag)
M<-bagging(type~.,data=data.train)

# prediction
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P$class)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.bag<-round(a,3)
A.bag

Boosting

#
# Boosting
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]
data.train<-data[-num,]

# model 
library(adabag)
M<-boosting(type~.,data=data.train)

# prediction
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P$class)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.boost<-round(a,3)
A.boost

Randam Forest

#
# Random Forest
#

# dataset
library(kernlab)
data(spam)
dim(spam)
data<-spam
num<-10*(1:(nrow(data)/10))
data.test<-data[num,]               # for test
data.train<-data[-num,]             # for model

# model
library(randomForest)
M<-randomForest(type~.,data=data.train,na.action="na.omit",mtry=14)

# prediction
P<-predict(M,data.test[,-58])

# result
t<-table(data.test[,58],P)
a<-(t[1,1] + t[2,2]) / (t[1,1] + t[2,2] + t[1,2] + t[2,1])
A.rf<-round(a,3)
A.rf

References

  1. https://www.amazon.co.jp/Rによるデータサイエンス-データ解析の基礎から最新手法まで-金-明哲/dp/4627096011
  2. https://www1.doshisha.ac.jp/~mjin/R/index.html
15
23
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
23