2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

mlr3で「Rによる機械学習」:第4章「確率的学習 -- 単純ベイズを使った分類」

Posted at

はじめに

色々ドキュメントを読んでmlr3が他の機械学習パッケージのラッパーであることをようやく理解しました.

準備

テキストマイニングにはtmパッケージではなくtidytextパッケージを使うことにします.

install.packages("stopwords", "tidytext", "wordcloud", "SnowballC")
library(mlr3verse)
library(stopwords)
library(tidyverse)
library(tidytext)
library(wordcloud)
library(SnowballC)

4.2 実例 -- 単純ベイズ法によるSMSスパムのフィルタリング

データを読み込みます.

url <- "http://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip"
file_name <- "smsspamcollection.zip"
download.file(url, destfile = file_name)
unzip(file_name)
sms_raw <- 
  read_delim("SMSSpamCollection",
             delim = "\t",
             quote = "",
             col_names = c("type", "text")) %>%
  mutate(message = row_number())

type列をcharacter型からfactor型に変換しておきます.

sms_raw$type <- factor(sms_raw$type)
str(sms_raw)
## spec_tbl_df [5,574 x 3] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ type   : Factor w/ 2 levels "ham","spam": 1 1 2 1 1 2 1 1 2 2 ...
##  $ text   : chr [1:5574] "Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat..." "Ok lar... Joking wif u oni..." "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question("| __truncated__ "U dun say so early hor... U c already then say..." ...
##  $ message: int [1:5574] 1 2 3 4 5 6 7 8 9 10 ...
##  - attr(*, "spec")=
##   .. cols(
##   ..   type = col_character(),
##   ..   text = col_character()
##   .. )
## NULL

hamspamの比率を確認します.

table(sms_raw$type)
## 
##  ham spam 
## 4827  747

sms_rawをもとにテキストデータのクリーニングをします.
unnest_tokens関数でメッセージテキストを単語ごとに分割し,新しいtibbleword列へ格納します.
wordStem関数で単語を原形に変換します.
それから数字や空白を取り除き,get_stopwords関数で出力したストップワードの一覧をanti_join関数によって取り除いています.

tidy_sms <- sms_raw %>%
  unnest_tokens(word, text) %>%
  mutate(word = wordStem(word)) %>%
  filter(!str_detect(word, "^[0-9]")) %>%
  filter(word != "") %>%
  anti_join(get_stopwords())
tidy_sms
## # A tibble: 52,890 x 3
##    type  message word  
##    <fct>   <int> <chr> 
##  1 ham         1 go    
##  2 ham         1 jurong
##  3 ham         1 point 
##  4 ham         1 crazi 
##  5 ham         1 avail 
##  6 ham         1 onli  
##  7 ham         1 bugi  
##  8 ham         1 n     
##  9 ham         1 great 
## 10 ham         1 world 
## # ... with 52,880 more rows

wordcloud関数でワードクラウドを作成します.

tidy_sms %>%
  count(word) %>%
  with(wordcloud(word, n, min.freq = 50, random.order = FALSE))

tidy_sms %>% 
  filter(type == "spam") %>%
  count(word) %>%
  with(wordcloud(word, n, max.words = 40, scale = c(3, 0.5)))

tidy_sms %>%
  filter(type == "ham") %>%
  count(word) %>%
  with(wordcloud(word, n, max.words = 40, scale = c(3, 0.5)))

6d9efbc4-0925-4dce-836f-fcafc718d521.png
8e304a99-ce1b-4dfa-86d6-d8dd5c4c124f.png
9a8f3cd5-dc8d-4cb6-912a-c919843477c4.png

5通未満のメッセージにしか含まれていないような単語を削除します.

tidy_sms_freq <- 
  tidy_sms %>%
  group_by(word) %>%
  mutate(included = length(unique(message))) %>%
  filter(included >= 5) %>%
  ungroup() %>%
  count(word, message) %>%
  arrange(message)
head(tidy_sms_freq)
## # A tibble: 6 x 3
##   word  message     n
##   <chr>   <int> <int>
## 1 avail       1     1
## 2 bugi        1     1
## 3 cine        1     1
## 4 crazi       1     1
## 5 e           1     1
## 6 go          1     1

DTM(Document-Term Matrix)を作成します.
最初にDTMのセルには各単語の出現回数を入れておき,これが正であるかどうかで値の変わるfactor型に変換します.

messages <- unique(tidy_sms_freq$message)
convert_counts <- function(x) {
  x <- ifelse(x > 0, "Yes", "No")
}

sms_dtm <- 
  as_tibble(
    as.matrix(cast_dtm(tidy_sms_freq, message, word, n)),
    .name_repair = "unique") %>%
  map(convert_counts) %>%
  unclass() %>%
  as.data.frame(stringsAsFactors = TRUE) %>%
  as_tibble()

head(sms_dtm)
## # A tibble: 6 x 1,564
##   avail bugi  cine  crazi e     go    got   great la    n     onli  point wat  
##   <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct>
## 1 Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes  
## 2 No    No    No    No    No    No    No    No    No    No    No    No    No   
## 3 No    No    No    No    No    No    No    No    No    No    No    No    No   
## 4 No    No    No    No    No    No    No    No    No    No    No    No    No   
## 5 No    No    No    No    No    No    No    No    No    No    No    No    No   
## 6 No    No    No    No    No    No    No    No    No    No    No    No    No   
## # ... with 1,551 more variables: world <fct>, joke <fct>, lar <fct>, ok <fct>,
## #   u <fct>, wif <fct>, appli <fct>, c. <fct>, comp <fct>, cup <fct>,
## #   entri <fct>, final <fct>, free <fct>, mai <fct>, question <fct>,
## #   rate <fct>, receiv <fct>, std <fct>, t <fct>, text <fct>, txt <fct>,
## #   win <fct>, wkly <fct>, alreadi <fct>, c <fct>, dun <fct>, earli <fct>,
## #   sai <fct>, around <fct>, goe <fct>, live <fct>, nah <fct>, think <fct>,
## #   though <fct>, usf <fct>, back <fct>, darl <fct>, freemsg <fct>, fun <fct>,
## #   hei <fct>, it. <fct>, like <fct>, now <fct>, send <fct>, still <fct>,
## #   tb <fct>, word <fct>, xxx <fct>, brother <fct>, even <fct>, speak <fct>,
## #   thei <fct>, treat <fct>, caller <fct>, callertun <fct>, copi <fct>,
## #   friend <fct>, ha <fct>, mell <fct>, per <fct>, press <fct>, request <fct>,
## #   set <fct>, call <fct>, claim <fct>, code <fct>, custom <fct>, hour <fct>,
## #   network <fct>, prize <fct>, reward <fct>, select <fct>, valid <fct>,
## #   valu <fct>, winner <fct>, camera <fct>, co <fct>, colour <fct>,
## #   entitl <fct>, latest <fct>, mobil <fct>, month <fct>, r <fct>, updat <fct>,
## #   anymor <fct>, enough <fct>, gonna <fct>, home <fct>, i.v <fct>, k <fct>,
## #   soon <fct>, stuff <fct>, talk <fct>, thi <fct>, todai <fct>, tonight <fct>,
## #   want <fct>, cash <fct>, chanc <fct>, cost <fct>, ...

タスクのバックエンドを作成します.

sms_backend <- tibble(message_type = sms_raw$type[messages], sms_dtm)
head(sms_backend)
## # A tibble: 6 x 1,565
##   message_type avail bugi  cine  crazi e     go    got   great la    n     onli 
##   <fct>        <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct>
## 1 ham          Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes   Yes  
## 2 ham          No    No    No    No    No    No    No    No    No    No    No   
## 3 spam         No    No    No    No    No    No    No    No    No    No    No   
## 4 ham          No    No    No    No    No    No    No    No    No    No    No   
## 5 ham          No    No    No    No    No    No    No    No    No    No    No   
## 6 spam         No    No    No    No    No    No    No    No    No    No    No   
## # ... with 1,553 more variables: point <fct>, wat <fct>, world <fct>,
## #   joke <fct>, lar <fct>, ok <fct>, u <fct>, wif <fct>, appli <fct>, c. <fct>,
## #   comp <fct>, cup <fct>, entri <fct>, final <fct>, free <fct>, mai <fct>,
## #   question <fct>, rate <fct>, receiv <fct>, std <fct>, t <fct>, text <fct>,
## #   txt <fct>, win <fct>, wkly <fct>, alreadi <fct>, c <fct>, dun <fct>,
## #   earli <fct>, sai <fct>, around <fct>, goe <fct>, live <fct>, nah <fct>,
## #   think <fct>, though <fct>, usf <fct>, back <fct>, darl <fct>,
## #   freemsg <fct>, fun <fct>, hei <fct>, it. <fct>, like <fct>, now <fct>,
## #   send <fct>, still <fct>, tb <fct>, word <fct>, xxx <fct>, brother <fct>,
## #   even <fct>, speak <fct>, thei <fct>, treat <fct>, caller <fct>,
## #   callertun <fct>, copi <fct>, friend <fct>, ha <fct>, mell <fct>, per <fct>,
## #   press <fct>, request <fct>, set <fct>, call <fct>, claim <fct>, code <fct>,
## #   custom <fct>, hour <fct>, network <fct>, prize <fct>, reward <fct>,
## #   select <fct>, valid <fct>, valu <fct>, winner <fct>, camera <fct>,
## #   co <fct>, colour <fct>, entitl <fct>, latest <fct>, mobil <fct>,
## #   month <fct>, r <fct>, updat <fct>, anymor <fct>, enough <fct>, gonna <fct>,
## #   home <fct>, i.v <fct>, k <fct>, soon <fct>, stuff <fct>, talk <fct>,
## #   thi <fct>, todai <fct>, tonight <fct>, want <fct>, cash <fct>, ...

タスクを作成し,バックエンドを訓練データとテストデータに分割します.

task <- TaskClassif$new(id = "sms",
                        backend = sms_backend,
                        target = "message_type")
learner <- lrn("classif.naive_bayes")
train_set <- seq_len(0.75 * task$nrow)
test_set <- setdiff(seq_len(task$nrow), train_set)

訓練データで学習します.

learner$train(task = task, row_ids = train_set)

テストデータで予測します.誤判定は21通です.

sms_test_pred <- learner$predict(task, row_ids = test_set)
sms_test_pred$confusion
#         truth
## response  ham spam
##     ham  1198   17
##     spam    4  165

パラメータにlaplace = 1を指定して訓練してから予測してみます.
誤判定は24通に増えてしまいました.

learner <- lrn("classif.naive_bayes", laplace = 1)
learner$train(task = task, row_ids = train_set)
sms_test_pred2 <- learner$predict(task, row_ids = test_set)
sms_test_pred2$confusion
##         truth
## response  ham spam
##     ham  1199   21
##     spam    3  161

終わりに

DTMをどうやってバックエンドに突っ込んだらいいのかわからずに苦労しました.

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?