Introduction
After reading through a variety of documentation, I finally understood that mlr3 is a wrapper around other machine learning packages.
Preparation
For text mining I will use the tidytext package rather than the tm package.
install.packages("stopwords", "tidytext", "wordcloud", "SnowballC")
library(mlr3verse)
library(stopwords)
library(tidyverse)
library(tidytext)
library(wordcloud)
library(SnowballC)
4.2 Example -- Filtering SMS spam with naive Bayes
Read in the data.
url <- "http://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip"
file_name <- "smsspamcollection.zip"
download.file(url, destfile = file_name)
unzip(file_name)
sms_raw <-
  read_delim("SMSSpamCollection",
             delim = "\t",
             quote = "",
             col_names = c("type", "text")) %>%
  mutate(message = row_number())
Convert the type column from character to factor.
sms_raw$type <- factor(sms_raw$type)
str(sms_raw)
## spec_tbl_df [5,574 x 3] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
## $ type : Factor w/ 2 levels "ham","spam": 1 1 2 1 1 2 1 1 2 2 ...
## $ text : chr [1:5574] "Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat..." "Ok lar... Joking wif u oni..." "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question("| __truncated__ "U dun say so early hor... U c already then say..." ...
## $ message: int [1:5574] 1 2 3 4 5 6 7 8 9 10 ...
## - attr(*, "spec")=
## .. cols(
## .. type = col_character(),
## .. text = col_character()
## .. )
## NULL
Check the balance of ham and spam.
table(sms_raw$type)
##
## ham spam
## 4827 747
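To see shares rather than raw counts, prop.table can be applied to the same table (a quick sketch; roughly 87% ham to 13% spam):

prop.table(table(sms_raw$type))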
Next, clean the text data in sms_raw. The unnest_tokens function splits each message into words and stores them in the word column of a new tibble, and the wordStem function reduces each word to its stem. I then drop tokens that begin with a digit, drop empty strings, and remove the stopword list returned by get_stopwords with anti_join.
tidy_sms <- sms_raw %>%
  unnest_tokens(word, text) %>%
  mutate(word = wordStem(word)) %>%
  filter(!str_detect(word, "^[0-9]")) %>%
  filter(word != "") %>%
  anti_join(get_stopwords())
tidy_sms
## # A tibble: 52,890 x 3
## type message word
## <fct> <int> <chr>
## 1 ham 1 go
## 2 ham 1 jurong
## 3 ham 1 point
## 4 ham 1 crazi
## 5 ham 1 avail
## 6 ham 1 onli
## 7 ham 1 bugi
## 8 ham 1 n
## 9 ham 1 great
## 10 ham 1 world
## # ... with 52,880 more rows
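The odd-looking tokens above, such as "crazi", "avail", and "onli", are Porter stems. Running wordStem directly on the source words shows how they arise:

wordStem(c("crazy", "available", "only"))
## [1] "crazi" "avail" "onli"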
Create word clouds with the wordcloud function: one for all messages, then one each for spam and ham.
tidy_sms %>%
  count(word) %>%
  with(wordcloud(word, n, min.freq = 50, random.order = FALSE))

tidy_sms %>%
  filter(type == "spam") %>%
  count(word) %>%
  with(wordcloud(word, n, max.words = 40, scale = c(3, 0.5)))

tidy_sms %>%
  filter(type == "ham") %>%
  count(word) %>%
  with(wordcloud(word, n, max.words = 40, scale = c(3, 0.5)))
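As an aside, the wordcloud package also provides comparison.cloud, which contrasts the ham and spam vocabularies in a single figure. A sketch, assuming a word-by-type count matrix built with pivot_wider:

tidy_sms %>%
  count(type, word) %>%
  pivot_wider(names_from = type, values_from = n, values_fill = 0) %>%
  column_to_rownames("word") %>%  # rows become words, columns become ham/spam
  as.matrix() %>%
  comparison.cloud(max.words = 40)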
Remove words that appear in fewer than five messages.
tidy_sms_freq <-
  tidy_sms %>%
  group_by(word) %>%
  mutate(included = length(unique(message))) %>%
  filter(included >= 5) %>%
  ungroup() %>%
  count(word, message) %>%
  arrange(message)
head(tidy_sms_freq)
## # A tibble: 6 x 3
## word message n
## <chr> <int> <int>
## 1 avail 1 1
## 2 bugi 1 1
## 3 cine 1 1
## 4 crazi 1 1
## 5 e 1 1
## 6 go 1 1
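As a quick check, the number of distinct words that survive the frequency filter should equal the number of word columns of the DTM built next (1,564, as the output below shows):

n_distinct(tidy_sms_freq$word)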
Create a DTM (document-term matrix). Each cell initially holds a word's occurrence count in a message; I then convert every column to a factor that records whether the count is positive.
messages <- unique(tidy_sms_freq$message)

convert_counts <- function(x) {
  ifelse(x > 0, "Yes", "No")
}

sms_dtm <-
  as_tibble(
    as.matrix(cast_dtm(tidy_sms_freq, message, word, n)),
    .name_repair = "unique") %>%
  map(convert_counts) %>%
  unclass() %>%
  as.data.frame(stringsAsFactors = TRUE) %>%
  as_tibble()
head(sms_dtm)
## # A tibble: 6 x 1,564
## avail bugi cine crazi e go got great la n onli point wat
## <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct>
## 1 Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes
## 2 No No No No No No No No No No No No No
## 3 No No No No No No No No No No No No No
## 4 No No No No No No No No No No No No No
## 5 No No No No No No No No No No No No No
## 6 No No No No No No No No No No No No No
## # ... with 1,551 more variables: world <fct>, joke <fct>, lar <fct>, ok <fct>,
## # u <fct>, wif <fct>, appli <fct>, c. <fct>, comp <fct>, cup <fct>,
## # entri <fct>, final <fct>, free <fct>, mai <fct>, question <fct>,
## # rate <fct>, receiv <fct>, std <fct>, t <fct>, text <fct>, txt <fct>,
## # win <fct>, wkly <fct>, alreadi <fct>, c <fct>, dun <fct>, earli <fct>,
## # sai <fct>, around <fct>, goe <fct>, live <fct>, nah <fct>, think <fct>,
## # though <fct>, usf <fct>, back <fct>, darl <fct>, freemsg <fct>, fun <fct>,
## # hei <fct>, it. <fct>, like <fct>, now <fct>, send <fct>, still <fct>,
## # tb <fct>, word <fct>, xxx <fct>, brother <fct>, even <fct>, speak <fct>,
## # thei <fct>, treat <fct>, caller <fct>, callertun <fct>, copi <fct>,
## # friend <fct>, ha <fct>, mell <fct>, per <fct>, press <fct>, request <fct>,
## # set <fct>, call <fct>, claim <fct>, code <fct>, custom <fct>, hour <fct>,
## # network <fct>, prize <fct>, reward <fct>, select <fct>, valid <fct>,
## # valu <fct>, winner <fct>, camera <fct>, co <fct>, colour <fct>,
## # entitl <fct>, latest <fct>, mobil <fct>, month <fct>, r <fct>, updat <fct>,
## # anymor <fct>, enough <fct>, gonna <fct>, home <fct>, i.v <fct>, k <fct>,
## # soon <fct>, stuff <fct>, talk <fct>, thi <fct>, todai <fct>, tonight <fct>,
## # want <fct>, cash <fct>, chanc <fct>, cost <fct>, ...
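As a quick sanity check on the convert_counts helper, a numeric vector maps to "Yes"/"No" exactly as intended:

convert_counts(c(0, 2, 0))
## [1] "No"  "Yes" "No"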
Create the backend for the task. Cleaning may have dropped some messages entirely (every one of their words removed), so I index sms_raw$type with messages to keep the labels aligned with the rows of the DTM.
sms_backend <- tibble(message_type = sms_raw$type[messages], sms_dtm)
head(sms_backend)
## # A tibble: 6 x 1,565
## message_type avail bugi cine crazi e go got great la n onli
## <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct>
## 1 ham Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes Yes
## 2 ham No No No No No No No No No No No
## 3 spam No No No No No No No No No No No
## 4 ham No No No No No No No No No No No
## 5 ham No No No No No No No No No No No
## 6 spam No No No No No No No No No No No
## # ... with 1,553 more variables: point <fct>, wat <fct>, world <fct>,
## # joke <fct>, lar <fct>, ok <fct>, u <fct>, wif <fct>, appli <fct>, c. <fct>,
## # comp <fct>, cup <fct>, entri <fct>, final <fct>, free <fct>, mai <fct>,
## # question <fct>, rate <fct>, receiv <fct>, std <fct>, t <fct>, text <fct>,
## # txt <fct>, win <fct>, wkly <fct>, alreadi <fct>, c <fct>, dun <fct>,
## # earli <fct>, sai <fct>, around <fct>, goe <fct>, live <fct>, nah <fct>,
## # think <fct>, though <fct>, usf <fct>, back <fct>, darl <fct>,
## # freemsg <fct>, fun <fct>, hei <fct>, it. <fct>, like <fct>, now <fct>,
## # send <fct>, still <fct>, tb <fct>, word <fct>, xxx <fct>, brother <fct>,
## # even <fct>, speak <fct>, thei <fct>, treat <fct>, caller <fct>,
## # callertun <fct>, copi <fct>, friend <fct>, ha <fct>, mell <fct>, per <fct>,
## # press <fct>, request <fct>, set <fct>, call <fct>, claim <fct>, code <fct>,
## # custom <fct>, hour <fct>, network <fct>, prize <fct>, reward <fct>,
## # select <fct>, valid <fct>, valu <fct>, winner <fct>, camera <fct>,
## # co <fct>, colour <fct>, entitl <fct>, latest <fct>, mobil <fct>,
## # month <fct>, r <fct>, updat <fct>, anymor <fct>, enough <fct>, gonna <fct>,
## # home <fct>, i.v <fct>, k <fct>, soon <fct>, stuff <fct>, talk <fct>,
## # thi <fct>, todai <fct>, tonight <fct>, want <fct>, cash <fct>, ...
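Because messages that lost every word during cleaning never reach the DTM, it is worth confirming that labels and DTM rows still line up (a small sketch):

length(messages) == nrow(sms_dtm)  # should be TRUE: one label per DTM row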
Create the task and the learner, then split the rows into a training set and a test set.
task <- TaskClassif$new(id = "sms",
                        backend = sms_backend,
                        target = "message_type")
learner <- lrn("classif.naive_bayes")

train_set <- seq_len(0.75 * task$nrow)
test_set <- setdiff(seq_len(task$nrow), train_set)
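Note that this split simply takes the first 75% of rows in order. A randomized split is the more common choice; a sketch with an arbitrary seed (if you run this instead, the counts below will differ):

set.seed(42)  # arbitrary seed for reproducibility
train_set <- sample(task$nrow, floor(0.75 * task$nrow))
test_set <- setdiff(seq_len(task$nrow), train_set)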
Train on the training set.
learner$train(task = task, row_ids = train_set)
Predict on the test set. 21 messages are misclassified.
sms_test_pred <- learner$predict(task, row_ids = test_set)
sms_test_pred$confusion
## truth
## response ham spam
## ham 1198 17
## spam 4 165
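Instead of reading the errors off the confusion matrix, the prediction can also be scored directly with mlr3 measures (a sketch):

sms_test_pred$score(msr("classif.acc"))  # accuracy
sms_test_pred$score(msr("classif.ce"))   # classification error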
Now set the parameter laplace = 1, retrain, and predict again. The number of misclassified messages rises to 24.
learner <- lrn("classif.naive_bayes", laplace = 1)
learner$train(task = task, row_ids = train_set)
sms_test_pred2 <- learner$predict(task, row_ids = test_set)
sms_test_pred2$confusion
## truth
## response ham spam
## ham 1199 21
## spam 3 161
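A single fixed split makes this kind of comparison fragile; cross-validation gives a more stable estimate. A sketch with 5-fold CV in mlr3:

resampling <- rsmp("cv", folds = 5)
rr <- resample(task, learner, resampling)
rr$aggregate(msr("classif.ce"))  # mean classification error over folds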
Wrapping up
I had a hard time figuring out how to get the DTM into the backend.