128
121

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

新米データサイエンティストが覚えたての知識を使ってKaggleのタイタニックデータを分析し、投稿してみた!

Last updated at Posted at 2018-04-21

##はじめに
こんにちは、教育業界に就職した新米データサイエンティストです。
入社してから3週間、研修を始めてから2週間が経過しました。
通勤の中で、データサイエンスを学ぶのに良いコンテンツはないかと調べるのですが、よく目にするのは「Kaggleをやれ」という記事です。

データサイエンティストを目指して勉強するなら、Kaggleからはじめよう
『データサイエンティストとマシンラーニングエンジニアはKaggleやれ』というのは何故なのか

自分もKaggleには今年挑戦してみたいと考えていたのですが、まだその段階ではないと思っていました。そんな中、こんな記事を見つけました。

KaggleのCTOが教えてくれた”AIエンジニアに超オススメな8つの学習ステップ”
原文:What are the best sources to study machine learning and artificial intelligence?

こちらの記事によると、8つの学習ステップは

  1. あなたが興味を持っているデータ分析の問題を選ぶ
  2. その問題を、Quick & Dirtyに、ハックして、最初から最後まで一貫して解決する
  3. 自分の作った学習モデルを進化させて改善する
  4. 自分のソリューションを共有する
  5. さまざまな問題のセットで#1-4を繰り返す
  6. Kaggleのコンペに真剣に挑戦する
  7. 専門レベルを上げて機械学習を適用する
  8. 機械学習を他の人に教える

となっています。
ちょうど今週の研修でscikit-learnを使ってロジスティック回帰を回すということを学んだので(アルゴリズム理解はまだ)2まではできそうな気がしました。
そこで、今まで学んだことの総復習と考え、QuickではありませんがDirtyにKaggleのTitanicデータを分析してみようと思います。

** Kaggle公式ページ:Titanic: Machine Learning from Disaster
**

##データ観察、前処理
まずは、パッケージをインポートし、データを見てみます。

#パッケージインポート
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
%matplotlib inline

#データ読み込み
train_data = pd.read_csv('train.csv')
test_data = pd.read_csv('test.csv')
train_data.head()
スクリーンショット 2018-04-21 11.34.11.png

・PassengerId, Name, Ticketは個別の値っぽいので、不要そう
・Survivedが予想する目的変数
・Sex, Embarked, Cabinはダミー変数に変換した方が良い?
といったことがわかりました。

次に、.info()でNULLの有無やデータ型を見てみます。

train_data.info()
スクリーンショット 2018-04-21 12.09.02.png

・Ageの2割程がNULL
・Cabinの8割程がNULL
・Embarkedにも若干NULL有り
ということがわかりました。

Cabinに関しては欠損値が多すぎるということで今回の分析には使わないことにしました。
Ageの欠損値に関しては、平均値や中央値を当てはめるということが考えられますが、200個近く同じ値が入るのもおかしいのかなと思うので、NULLの部分を使わないことにしました。EmbarkedのNULLも同様に扱いました。

以上から、必要なデータのみ抽出します。

train_data_sub1 = train_data.drop(["PassengerId", "Name", "Ticket", "Cabin"], axis=1)
train_data_sub2 = train_data_sub1.dropna()
train_data_sub2.info()
スクリーンショット 2018-04-21 12.32.49.png

上手く処理できたようです。
それでは、次に.describe()を用いてそれぞれのカラムのデータ範囲を見てみます。

train_data_sub2.describe()
スクリーンショット 2018-04-21 12.35.45.png

・Age:最小値が0.42歳ということは赤ちゃんが乗っていて、半数以上が20歳~38歳みたいです。
・SibSpは乗っている兄弟、配偶者の数で、0人~5人
・Parchは乗っている子供、親の数で、0人~6人
・Fareは旅客運賃でほとんどの方が33($?)付近ですが、中には512($?)くらい高い運賃を払っている方もいるようです。これは、PclassやEmbarkedに関係するのかな...

ここで一つ思ったことが、SibSpやParchで一緒に乗っている人の人数がわかるので、もしかしたら同じチケット番号の旅客がいて、グループ分けできたのではないかということです。
子供がいるグループ、カップル、老人がいるグループなど分けられたら、予測の制度が高まるのではないかと思いましたが、ここはとりあえずモデルを作成し、テストデータに対して予測を行うということを優先し、先に進むことにします。

次に、データを集計、可視化して傾向を見てみます。

#Pclass別の生存率
pclass_groupby = pd.concat([train_data_sub2.groupby('Pclass')['Survived'].sum() / train_data_sub2.groupby('Pclass')['Survived'].count(), 
                            train_data_sub2.groupby('Pclass')['Survived'].count()], axis=1)
pclass_groupby.columns = ['Survived_rate', 'num_of_passenger']
pclass_groupby
スクリーンショット 2018-04-21 13.13.17.png

まず、Pclass(客室等級)ごとに生存率を見てみました。等級が高くなるほど、生存率が高いことが分かります。
次に、同様に性別でも見てみます。

#性別の生存率
sex_groupby = pd.concat([train_data_sub2.groupby('Sex')['Survived'].sum() / train_data_sub2.groupby('Sex')['Survived'].count(), 
                            train_data_sub2.groupby('Sex')['Survived'].count()], axis=1)
sex_groupby.columns = ['Survived_rate', 'num_of_passenger']
sex_groupby
スクリーンショット 2018-04-21 14.16.05.png

女性の生存率が男性の生存率に比べてとても高いことが分かりました。女性が優先して救助されたということでしょうか...?
次に、生存者と非生存者で年齢層の分布を比べてみます。

#生存, 非生存の年齢層分布
plt.figure(figsize=(15,6))
plt.hist(train_data_sub2.Age[train_data_sub2.Survived == 0], normed=True, bins=10, alpha=0.8, color='red', label='not survived')
plt.hist(train_data_sub2.Age[train_data_sub2.Survived == 1], normed=True, bins=10, alpha=0.7, color='blue', label='survived')
plt.grid(True)
plt.legend(loc='best', fontsize=15)
plt.xlabel('Age')

image.png

概ね年齢層の分布に差は有りませんが、0~8歳の割合が生存者は多いようです。実際に平均を比べてみると、

train_data_sub2.groupby('Survived', as_index=False)['Age'].mean()
スクリーンショット 2018-04-21 14.59.36.png

生存者の方が僅かに平均年齢が低いことが分かります。
次に、同様にして生存者と非生存者で運賃の分布を比べてみます。

#生存, 非生存の運賃分布
plt.figure(figsize=(15,6))
plt.hist(train_data_sub2.Fare[train_data_sub2.Survived == 0], normed=True, bins=10, alpha=0.8, color='red', label='not survived')
plt.hist(train_data_sub2.Fare[train_data_sub2.Survived == 1], normed=True, bins=20, alpha=0.7, color='blue', label='survived')
plt.grid(True)
plt.legend(loc='best', fontsize=15)
plt.xlabel('Fare')

image.png

すると、生存者は運賃を払っている人が多いと分かりました。
以上でデータを観察することに区切りをつけ、次にSexとEmbarkedをダミー変数変換しようと思います。

参考:pandasでダミー変数を作成する(get_dummies)

#ダミー変数取得
dummy_train = pd.get_dummies(train_data_sub2[['Sex', 'Embarked']])
dummy_train.head()
スクリーンショット 2018-04-21 15.23.36.png

上手くダミー変数を取得できています。次に、元のデータに結合します。

#元のデータに結合
train_data_sub3 = pd.concat([train_data_sub2.drop(['Sex', 'Embarked'], axis=1), dummy_train], axis=1)
train_data_sub3.head()
スクリーンショット 2018-04-21 15.25.39.png

無事結合できました。
最後に、今まで学習用データにしてきた返還をテストデータに対しても行います。

#テストデータも同様に処理
test_data_sub1 = test_data.drop(["PassengerId", "Name", "Ticket", "Cabin"], axis=1)
dummy_test = pd.get_dummies(test_data_sub1[['Sex', 'Embarked']])
test_data_sub2 = pd.concat([test_data_sub1.drop(['Sex', 'Embarked'], axis=1), dummy_test], axis=1)

これで準備が整いました。次にモデル作成、実際に予測をしていきます。

##モデル作成、評価

今回は冒頭でインポートしたロジスティック回帰を用いてみます。

#モデル作成
X_train = train_data_sub3.drop('Survived', axis=1)
y_train = train_data_sub3.Survived

log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)

モデルを作成し、フィッティングを行いました。
では、学習用データに対して正解率を出力してみます。

#モデル評価
bench_mark = train_data_sub3.Survived.sum() / train_data_sub3.Survived.count()

print("bench_mark : ", bench_mark)
print("training score : ", log_reg.score(X_train, y_train))
bench_mark :  0.4044943820224719
training score :  0.7991573033707865

bench_markは、全て生存と予測した場合の正解率を表しています。モデルの正解率が約80%程なので、全て生存と予測するよりはいい結果が出ていると捉えることができます。
また、どの特徴量が効果的なのか見るため、特徴量の係数を出力しました。

feature_coef = pd.concat([df(X_train.columns), df(log_reg.coef_[0, :])], axis=1)
feature_coef.columns = ['feature name', 'coefficient']
feature_coef['abs_coefficient'] = abs(feature_coef.coefficient)
feature_coef.sort_values(by='abs_coefficient', ascending=False).drop('abs_coefficient', axis=1)
スクリーンショット 2018-04-21 16.09.51.png

データ観察の段階でも見えていましたが、女性であると生存確率は上がるようです。そして、客室等級の係数が負なので、等級が高いほど生存確率が高いということも分かります。他にも、Cherbourgから乗った方は生存確率が高いようです。
一方、年齢や運賃は生存確率にあまり影響がないということも分かりました。

それでは、テストデータに対して予測をしてみます。学習用データの時にはAgeのNULLを無視しましたが、テストデータの際にはNULLのままで予測ができないこと、年齢の影響が少ないことから、平均値を入れることにします。

test_data_sub2.Age = test_data_sub2.fillna(test_data_sub2.Age.mean())
test_data_sub2.Fare = test_data_sub2.fillna(test_data_sub2.Fare.mean())

FareにもNULLが有り、同様に処理をしました。
ではテストデータに対して予測を行い、それを出力してみます。

survived_predict = log_reg.predict(test_data_sub2)
survived_predict[:5]
array([0, 1, 0, 0, 1])

このように、予測が出力できました。それでは、submission用のファイルを作成し、実際にKaggleに投稿してみます!

##Submissionファイル作成、Kaggleに投稿!

Kaggleで指定されているフォーマットに合わせて作成します。

Submission File Format
You should submit a csv file with exactly 418 entries plus a header row. Your submission will show an error if you have extra columns (beyond PassengerId and Survived) or rows.
The file should have exactly 2 columns:
 ・PassengerId (sorted in any order)
 ・Survived (contains your binary predictions: 1 for survived, 0 for deceased)

submittion_file = pd.concat([df(test_data.PassengerId), df(survived_predict)], axis=1)
submittion_file.columns = ['PassengerId', 'Survived']
submittion_file.to_csv('submittion.csv', index=False)

Kaggleに提出すると、テストデータに対する正解率が表示されます。

スクリーンショット 2018-04-21 16.37.38.png

正解率は70%程でした。
初めての投稿、感動的です。。

##終わりに

今回はまずモデルを作成して回して予測をし、Kaggleに予測結果を提出するというところまでやってみました。
次のステップは、前述の参考記事によれば、モデルを進化させる段階です。これにはやはりアルゴリズムの理解が不可欠だと思います。(今はロジスティック回帰すら理解していない)
なので、その辺りを研修と平行して進めていければいいなと思っています。
拙い記事だったと思いますが、ここまで読んでくださった方、ありがとうございました!

128
121
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
128
121

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?