LoginSignup
2
3

More than 5 years have passed since last update.

kaggleトレーニングvol.1 ~Titanic: Machine Learning from Disaster~

Posted at

はじめに

@wasabi_sugiさんkaggleの和訳記事を書かれていて、自分も前からkaggleに興味を持っていたので記事を読んでみました。
ただ、自分で手と頭を動かさないと理解も進まないと思い、コードの写経と自分なりの解説を書いていくことに。
簡単に言うとパクリです、はい。

目的

自分の勉強のためです。

方針

kaggleのコンペからランダムで課題を選び、Kernel内のコードを自分なりに理解・解説しながら写経していきます。

注意

勘違い、誤りが含まれているかもしれません。もし何かあればコメントでご指摘頂けると幸いです。

本日のコンペとKernel

今回は、みんな大好きタイタニックから、こちらのコードを写経していきます。

コンペの概要

1912年に起きたタイタニック号の沈没事件。多くの人が亡くなったが、女性、子供、上層階の客など、他の人より生き残る可能性が高いグループもあった。このコンペでは、どういう人が生き残る可能性が高いのか予測してみましょう。

コード

library(ggplot2)
library(randomForest)

ライブラリーの読み込み。
randomForestは全く使ったことがないですし、その理論もよくわかっていません。早速そこから検索。このページが自分的にわかりやすかったです。
要は、判断基準(決定木)をたくさん作って、そこから良い感じの答えを一つ導き出そう、とかそんな感じだろうか。

set.seed(1)
train <- read.csv('data/train.csv', stringsAsFactors = FALSE)
test  <- read.csv('data/test.csv', stringsAsFactors = FALSE)

set.seed()は、決められた乱数を作るためのもの(矛盾!)。
何かしらの乱数を生成する前に、set.seed(1)等としておくと、次回以降も同じ乱数が生成できるよ、というもの。

read.csvはkaggleからダウンロードしてきた2つのcsvデータの読み込みですね。

extractFeatures <- function(data){
  features <- c('Pclass',
                'Age',
                'Sex',
                'Parch',
                'SibSp',
                'Fare',
                'Embarked')
  fea <- data[, features]
  fea$Age[is.na(fea$Age)] <- -1
  fea$Fare[is.na(fea$Fare)] <- median(fea$Fare, na.rm = TRUE)
  fea$Embarked[fea$Embarked == ''] = 'S'
  fea$Sex <- as.factor(fea$Sex)
  fea$Embarked <- as.factor(fea$Embarked)
  return(fea)
}

関数の定義。関数名extractFeaturesは「特徴を抽出する」という意味。この関数では、読み込んだcsvデータの欠損値の処理方法を定義しています。

featuresオブジェクトにPclass等を入れていますが、これは先程読み込んだ2つのcsvのカラム名になっています。

fea <- data[, features]
与えられたdataに関して、featuresと一致するカラムをfeaに入れる。

fea$Age[is.na(fea$Age)] <- -1
Age列の欠損値に対して、-1を代入。

fea$Fare[is.na(fea$Fare)] <- median(fea$Fare, na.rm = TRUE)
Fare(運賃)の欠損値に対して、Fareの欠損値を除く中央値を代入。

fea$Embarked[fea$Embarked == ''] = 'S'
Embarked(確か、乗客の乗船した港)の欠損値に対して、S(=Southampton)を代入

fea$Sex <- as.factor(fea$Sex)
Sex列をFactor型にする。

fea$Embarked <- as.factor(fea$Embarked)
Embarked列をFactor型にする。

return(fea)
欠損値の補間等を済ませた値を返す。

rf <- randomForest(extractFeatures(train),
                   as.factor(train$Survived),
                   ntree = 100,
                   importance = TRUE)

randomForestは、ランダムフォレストによる、分類と回帰を行ってくれる関数。

最初の引数extractFeatures(train)
予測のためのトレーニングデータを指定しています。

2つ目as.factor(train$Survived)
トレーニングデータに含まれる目的変数を指定しています。

3つ目ntree = 100
生成する決定木の数。Rのヘルプを見ると、「小さすぎる数を指定しないように」と記述がありました。この数を大きくしたり小さくすることが、予測精度の高低に関わりそう。

4つ目importance = TRUE
「予測変数の重要度を出力しますか?」ということのようです。デフォルトはFALSE。ここは基本TRUEとしておいたほうが、結果の判断に役立ちそうです。

submission <- data.frame(PassengerId = test$PassengerId)
submission$Survived <- predict(rf, extractFeatures(test))
write.csv(submission, file = "1_random_forest_r_submission.csv", row.names=FALSE)

submission <- data.frame(PassengerId = test$PassengerId)
テスト用(予測する側)のデータから、乗客IDを抜き出す。

submission$Survived <- predict(rf, extractFeatures(test))
先程作成したモデルrfを、テストデータに当てはめ、その予測結果を出す。

write.csv(submission, file = "1_random_forest_r_submission.csv", row.names=FALSE)
予測結果をcsv出力。

imp <- importance(rf, type=1)
featureImportance <- data.frame(Feature=row.names(imp), Importance=imp[,1])

importance()は、randomForestで計算された重要度を出力する関数。
typeはNULL,1,2のいずれかを指定できて、
NULL=下記両方
1=平均精度の低下
2=ノード不純物の平均減少
 ※グーグル翻訳
の結果が得られます。正直なんのことやら...

featureImportance <- data.frame(Feature=row.names(imp), Importance=imp[,1])
重要度のグラフ化のために、新たなデータフレームの作成。

p <- ggplot(featureImportance, aes(x = reorder(Feature, Importance), y = Importance)) +
  geom_bar(stat = 'identity', fill = '#53cfff') +
  coord_flip() +
  theme_light(base_size = 20) +
  xlab('') +
  ylab('Importance') +
  ggtitle('Random Forest Feature Importance\n') +
  theme(plot.title = element_text((size = 18)))

重要度のグラフ化。
私があまり使ったことなかった関数だけ記述します。
reorder
最初の引数を、2編目の引数の値に基づいて並び替えてくれる。つまり、昇順でのグラフ作成が可能。

theme_light(base_size = 20)
プロットエリア全体の体裁を変更する。便利だ。
lightの他にも、bwlinedrawなんかもある。詳しくは、ヘルプ参照。

ggsave("2_feature_importance.png", p)

作成したグラフの保存。
以上!

疑問点、改善点

  • set.seed(1)を使う理由。
    コメントにありましたが、randomForestはランダム性を持っているため、これを行わないとこのコードを実行する度に結果が変わるから、とのこと。

  • Ageの欠損値を-1で補完している。中央値、平均値などにしたら、もう少し精度が上がりそう。

  • ntree = 100にしている。ここも200,300,400と変えたら、精度が変わるのではないか。

おわりに

まだまだわからないことだらけ。
ただ、一歩一歩進めていくことが成長の近道だとも思うので、定期的に記事書いていきます。

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3