
Taking a Serious Look at the Titanic Dataset's Missing Values


Raw data almost always contains missing values. In this post I try filling them in with a mix of statistical analysis and machine learning.

The Titanic data

import pandas as pd

df = pd.read_csv("train.csv")
df.head()

(screenshot of the df.head() output omitted)

Checking the missing values

df.isnull().sum()
PassengerId      0
Survived         0
Pclass           0
Name             0
Sex              0
Age            177
SibSp            0
Parch            0
Ticket           0
Fare             0
Cabin          687
Embarked         2
dtype: int64
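The same check expressed as ratios makes the scale of each gap clearer. A small sketch on a toy frame (the post itself reads train.csv, which isn't reproduced here):

```python
import numpy as np
import pandas as pd

# Toy stand-in for train.csv, just to show the ratio check
df = pd.DataFrame({
    "Age": [22.0, np.nan, 30.0, np.nan],
    "Embarked": ["S", "C", None, "S"],
    "Fare": [7.25, 71.28, 8.05, 53.10],
})

# isnull().mean() gives the fraction of missing values per column
ratios = df.isnull().mean().sort_values(ascending=False)
print(ratios)
```

On the real data this immediately shows Cabin at roughly 77% missing, which is why it gets dropped below.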

Looking at the counts, Cabin is missing far too often to be worth imputing, so we impute Age and Embarked.

df.index = df["PassengerId"]
df = df.drop([
            'PassengerId',
            'Name',
            'Cabin',
            'Ticket'
            ],
            axis=1)

Imputing Embarked

import matplotlib.pyplot as plt
import seaborn as sns
sns.histplot(df["Embarked"])

(histogram of Embarked omitted)

S was by far the most common port, so we fill the missing values with S.

df["Embarked"] = df["Embarked"].fillna("S")
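Instead of reading the most frequent port off the plot, the same fill can be derived programmatically with Series.mode(). A sketch on a toy Series (the values are illustrative):

```python
import pandas as pd

embarked = pd.Series(["S", "C", "S", "Q", None, "S"])

# Series.mode() returns the most frequent value(s); take the first entry
filled = embarked.fillna(embarked.mode()[0])
```

This keeps working even if a different category becomes the mode after filtering.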

Imputing Age (training data)

df = pd.get_dummies(df)
df.head()

(screenshot of the one-hot-encoded df.head() output omitted)

train = df.copy()
train_train = train.dropna()
train_test = train[train["Age"].isnull()].copy()  # .copy() avoids SettingWithCopyWarning when filling Age later
from lightgbm import LGBMRegressor
x_train = train_train.drop("Age", axis=1)
x_test = train_test.drop("Age", axis=1)
y_train = train_train["Age"]
model = LGBMRegressor()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
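Before filling with these predictions, it's worth noting a much lighter baseline than a gradient-boosted regressor: the median Age within each Pclass group. This is not what the post does; it's a sketch on a toy frame for comparison:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "Pclass": [1, 1, 2, 2, 3, 3],
    "Age": [40.0, np.nan, 30.0, 28.0, np.nan, 20.0],
})

# transform("median") broadcasts each group's median back onto its rows,
# so fillna picks up the median of the row's own Pclass
df["Age"] = df["Age"].fillna(df.groupby("Pclass")["Age"].transform("median"))
```

Comparing the model's imputations against a baseline like this is a quick way to check the regressor is earning its complexity.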

We fill the missing ages with the predicted values.

pred = pd.DataFrame(y_pred)
pred.index = train_test.index
pred.columns = ["Age"]
train_test["Age"] = pred["Age"]
train_test.isnull().sum()
Survived      0
Pclass        0
Age           0
SibSp         0
Parch         0
Fare          0
Sex_female    0
Sex_male      0
Embarked_C    0
Embarked_Q    0
Embarked_S    0
dtype: int64

The missing values are gone. Now we concatenate the two halves back together.

train = pd.concat([train_test, train_train])
train.sort_index().head(10)
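A quick sanity check after a split-and-concat like this is that no rows were lost or duplicated. A sketch with toy stand-ins for train_train and train_test:

```python
import pandas as pd

# Toy stand-ins: rows that had Age, and rows whose Age was filled
train_train = pd.DataFrame({"Age": [22.0, 30.0]}, index=[1, 3])
train_test = pd.DataFrame({"Age": [28.5]}, index=[2])

train = pd.concat([train_test, train_train]).sort_index()

# Every original row should appear exactly once
assert len(train) == len(train_train) + len(train_test)
assert train.index.is_unique
```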

Imputing the test data's missing values

test = pd.read_csv("test.csv")
test.isnull().sum()
PassengerId      0
Pclass           0
Name             0
Sex              0
Age             86
SibSp            0
Parch            0
Ticket           0
Fare             1
Cabin          327
Embarked         0
dtype: int64

Here too, Cabin has too many gaps, so we impute Age and Fare.

Imputing Fare

plt.hist(test["Fare"], bins=100)

(full arrays of bin counts and edges omitted; the most populated bin, with 152 fares, starts at 5.123292)

The bin starting at 5.123292 holds the most fares, so we fill the missing Fare with 5.123292.

test["Fare"] = test["Fare"].fillna(5.123292)
test.isnull().sum()
PassengerId      0
Pclass           0
Name             0
Sex              0
Age             86
SibSp            0
Parch            0
Ticket           0
Fare             0
Cabin          327
Embarked         0
dtype: int64
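Filling with the left edge of the densest bin is one option; the median is the more conventional robust choice for a long-tailed variable like Fare. A sketch on a toy Series, not the post's method:

```python
import numpy as np
import pandas as pd

fare = pd.Series([7.25, 8.05, np.nan, 53.10, 8.05])

# The median ignores NaN and is robust to Fare's long right tail
fare_filled = fare.fillna(fare.median())
```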

Imputing Age (test data)

test.index = test["PassengerId"]
test = test.drop([
            'PassengerId',
            'Name',
            'Cabin',
            'Ticket'
            ],
            axis=1)
test = pd.get_dummies(test)
test = test.drop("Sex_female", axis=1)
test

(screenshot of the encoded test frame omitted)

test_train = test.dropna()
test_test = test[test["Age"].isnull()].copy()  # .copy() avoids SettingWithCopyWarning when filling Age later
print(len(test_test))
test_test.isnull().sum()
86
Pclass         0
Age           86
SibSp          0
Parch          0
Fare           0
Sex_male       0
Embarked_C     0
Embarked_Q     0
Embarked_S     0
dtype: int64
x_train = test_train.drop("Age", axis=1)
x_test = test_test.drop("Age", axis=1)
y_train = test_train["Age"]
model = LGBMRegressor()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)

Filling with the predicted values
age = pd.DataFrame(y_pred)
print(len(age), len(x_test))
age.index = test_test.index
age.columns = ["Age"]
test_test["Age"] = age["Age"]
test_test.isnull().sum()
Pclass        0
Age           0
SibSp         0
Parch         0
Fare          0
Sex_male      0
Embarked_C    0
Embarked_Q    0
Embarked_S    0
dtype: int64

The missing values are gone.

Combining the data

df_test = pd.concat([test_train, test_test])
df_test.isnull().sum()
Pclass        0
Age           0
SibSp         0
Parch         0
Fare          0
Sex_male      0
Embarked_C    0
Embarked_Q    0
Embarked_S    0
dtype: int64

Predicting survival

df_train = train.copy()
from sklearn.model_selection import train_test_split
from lightgbm import LGBMClassifier
models = []
x = df_train.drop(["Survived", "Sex_female", "Embarked_C"], axis=1)
y = df_train["Survived"]
for i in range(200):
    x_train, x_val, y_train, y_val = train_test_split(x, y, random_state=i, test_size=0.3)
    model = LGBMClassifier()
    model.fit(x_train, y_train)
    models.append([model, model.score(x_val, y_val)])
models = sorted(models, key=lambda x:x[1], reverse=True)

We predict with the models that scored best on their validation splits.

df_test = df_test.drop(["Embarked_C"], axis=1)
from scipy.stats import mode
model1 = models[0][0]
model2 = models[1][0]
model3 = models[2][0]
y_pred1 = model1.predict(df_test)
y_pred2 = model2.predict(df_test)
y_pred3 = model3.predict(df_test)
y_pred = []
for i in range(len(y_pred1)):
    y_pred.append(mode([y_pred1[i],y_pred2[i],y_pred3[i]])[0])
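With 0/1 labels, the majority vote over three classifiers can also be computed without a per-row loop or scipy: the majority class is simply whether the sum of votes reaches 2. A NumPy sketch with toy predictions:

```python
import numpy as np

# Toy predictions from three models
p1 = np.array([1, 0, 1, 0])
p2 = np.array([1, 1, 0, 0])
p3 = np.array([0, 1, 1, 0])

# At least two of three votes -> majority class 1
majority = ((p1 + p2 + p3) >= 2).astype(int)
```

This is equivalent to the mode for binary labels and sidesteps scipy.stats.mode's return-type differences across versions.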

Building the prediction DataFrame

df_pred = pd.DataFrame(y_pred)
df_pred.columns = ["Survived"]
df_pred.index = df_test.index
df_pred

Creating the submission file

df_pred.sort_index().to_csv("submission9.csv")
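The Kaggle Titanic submission expects exactly the columns PassengerId and Survived. One way to eyeball the file layout before uploading is to render the CSV to a string (a sketch with a two-row toy frame):

```python
import pandas as pd

# Toy stand-in for df_pred, indexed by PassengerId
df_pred = pd.DataFrame(
    {"Survived": [0, 1]},
    index=pd.Index([893, 892], name="PassengerId"),
)

# to_csv() with no path returns the CSV text, handy for a quick inspection
csv_text = df_pred.sort_index().to_csv()
print(csv_text.splitlines()[0])
```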

Summary

The score came out to 0.75358. (Damn it, 75% is way too low!)
