0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

(Kaggle)決定木とランダムフォレストを使ったモデルを用いてTitanic生存者を予測した

Last updated at Posted at 2020-02-21

#1 はじめに
 機械学習を学ぶ上でのチュートリアルとして皆さん必ず通る道であろうタイタニック号生存者の予測について、私が行った方法を備忘として記録します。

使用したバージョンについて

  • Python 3.7.6
  • numpy 1.18.1
  • pandas  1.0.1
  • matplotlib 3.1.3
  • seaborn 0.10.0

使用したデータはKaggle登録後、こちらからダウンロードしました。
https://www.kaggle.com/c/titanic

#2 プログラムについて

#Importしたライブラリ等


import pandas as pd
import numpy as np
df = pd.read_csv('train.csv') 

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline 

データフレームdfとして定義しました。
ここで、読み込んだデータのうち最初の5行を見てみます。


df.head()

001.png

 Survivedが1で生存、0で死亡を表します。その他の要因が何を指しているかは公式サイトをご覧ください。

#ヒストグラム
次にヒストグラムを見てみます。


df.hist(figsize=(12,12))
plt.show()

002.png

 歳は20~30代が多く、Pclass(等級)は3(一番安い)が多いことが分かります。

#相関係数を確認


plt.figure(figsize = (15,15))
sns.heatmap(df.corr(),annot = True)

003.png

 Survivedに対して相関係数が高い指標は0.26:Fare、-0.34:Pclassとなります。Pclass等級が決まればFare:運賃も自ずと決まると思いますが、二つは違う指標なんですね。。

#欠損値の取り扱い


df.isnull().sum()

PassengerId 0
Survived 0
Pclass 0
Name 0
Sex 0
Age 177
SibSp 0
Parch 0
Ticket 0
Fare 0
Cabin 687
Embarked 2
dtype: int64

 Age:年齢とCabin:乗り組み番号が入っていない量が多いことが分かります。欠損値の取り扱いについては、下記を参考にしました。
https://qiita.com/0NE_shoT_/items/8db6d909e8b48adcb203

 今回は、年齢に関しては中央値を代入することとしました。また、Embarked:乗船地は最も多いSを代入しています。
 その他の欠損している値は削除しました。


from sklearn.model_selection import  train_test_split
#欠損値処理
df['Fare'] = df['Fare'].fillna(df['Fare'].median())
df['Age'] = df['Age'].fillna(df['Age'].median())
df['Embarked'] = df['Embarked'].fillna('S')

#カテゴリ変数の変換
df['Sex'] = df['Sex'].apply(lambda x: 1 if x == 'male' else 0)
df['Embarked'] = df['Embarked'].map( {'S': 0 , 'C':1 , 'Q':2}).astype(int)

#不要なcolumnを削除
df = df.drop(['Cabin','Name','PassengerId','Ticket'],axis =1)

#学習データとテストデータの分類


train_X = df.drop('Survived',axis = 1)
train_y = df.Survived
(train_X , test_X , train_y , test_y) = train_test_split(train_X, train_y , test_size = 0.3 , random_state = 0)

 今回は、テストデータを3割(test_size=0.3)としています。

#機械学習による予測(決定木)


from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(criterion='gini', random_state = 0)
clf = clf.fit(train_X , train_y)
pred = clf.predict(test_X)

#正解率の算出
from sklearn.metrics import (roc_curve , auc ,accuracy_score)
pred = clf.predict(test_X)
fpr, tpr, thresholds = roc_curve(test_y , pred,pos_label = 1)
auc(fpr,tpr)
accuracy_score(pred,test_y)

 決定木パラメータの取り扱いについてはこちらを参考にしました。
http://data-analysis-stats.jp/2019/01/14/%E6%B1%BA%E5%AE%9A%E6%9C%A8%E5%88%86%E6%9E%90%E3%81%AE%E3%83%91%E3%83%A9%E3%83%A1%E3%83%BC%E3%82%BF%E8%A7%A3%E8%AA%AC/

criterion='gini', 0.7798507462686567
criterion='entropy', 0.7910447761194029

若干entropyの方が正解率が上がりました。参考urlによると、

使い分けのポイントとしては、ジニ係数の方が、連続データを得意としており、エントロピーはカテゴリーデータを得意としていると言われています。ジニ係数は、誤分類を最小化するのに対して、エントロピーは探索的に基準値を探していきます。

とのことです。今回は性別、乗船地等カテゴリーデータが多かったことからエントロピーが適切なのでしょう。

こちらも決定木に関して詳しく載っており参考になりました。
https://qiita.com/3000manJPY/items/ef7495960f472ec14377

#機械学習による予測(ランダムフォレスト)


from sklearn.ensemble import RandomForestClassifier
clf = RandomForestClassifier(n_estimators = 10,max_depth=5,random_state = 0) #ランダムフォレストのインスタンスを生成する。
clf = clf.fit(train_X , train_y) #教師ラベルと教師データを用いてfitメソッドでモデルを学習
pred = clf.predict(test_X)
fpr, tpr , thresholds = roc_curve(test_y,pred,pos_label = 1)
auc(fpr,tpr)
accuracy_score(pred,test_y)

0.8283582089552238

 ランダムフォレストには、分類用(Classifier)と回帰分析用(Regressor)があります。今回は、生存か死亡かを分類することが目的なため分類用を用いて実行します。

 RandomForestClassifierクラスの学習パラメータ、引数はこちらを見ながら動かしてみました。
https://data-science.gr.jp/implementation/iml_sklearn_random_forest.html

 n_estimatorsを上げていくと正解率がサチレートしていきました。n_estimatorsやニューラルネットワークで使用するepochsは大きな値で計算しても取り越し苦労に終わることが多いと言われています。その内容、対策については下記をご覧ください。
https://amalog.hateblo.jp/entry/hyper-parameter-search

#予測結果の出力

 先ほどの結果から決定木よりランダムフォレストのほうがより正解率が高いことが分かりました。
 従って、今回はランダムフォレストにより予測することとしました。


fin = pd.read_csv('test.csv')
fin.head()

passsengerid = fin['PassengerId']
fin.isnull().sum()
fin['Fare'] = fin['Fare'].fillna(fin['Fare'].median()) 
fin['Age'] = fin['Age'].fillna(fin['Age'].median())
fin['Embarked'] = fin['Embarked'].fillna('S')

#カテゴリ変数の変換
fin['Sex'] = fin['Sex'].apply(lambda x: 1 if x == 'male' else 0)
fin['Embarked'] = fin['Embarked'].map( {'S': 0 , 'C':1 , 'Q':2}).astype(int)

#不要なcolumnを削除
fin= fin.drop(['Cabin','Name','Ticket','PassengerId'],axis =1)
#ランダムフォレストで予測
predictions = clf.predict(fin)

submission = pd.DataFrame({'PassengerId':passsengerid, 'Survived':predictions})
submission.to_csv('submission.csv' , index = False)

004.png

4900位/16000位となりました。
もっと精進したいと思います。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?