はじめに
Kaggle の Titanic で遊び始めているが, 欠損値の補完やハイパーパラメータの見直しの前に, まずデータをしっかり見ようと思い, データを眺めている. 読み込んだデータを, 例えば Survived
の値でグルーピングしてグラフを描くということをササッとやりたいのだが, なかなかうまくいかない. Pandas の "GroupBy" の理解が不十分だからだ.
ネットには, 先人たちのグラフ描画の例がたくさんあるが, 私の理解の道筋を記すことで, 初心者の役に立てるのではないか? と思って, この記事を書く.
目指すゴール
下記のようなグラフを描くこと.
このグラフは, 横軸が Ticket
の記号, 縦軸が生存 (s
), 死亡 (d
), 不明 (na
) の人数を積み上げたもので, 合計人数で降順にソートしている. 例えば, 一番左端の CA. 2343
のチケット記号は, 合計 11 人, 不明が 4 名, 残りの 7 名が死亡となっている.
こんなグラフをササッと描きたい.
データを読み込む
データを読み込んで, Ticket
のデータで, 同じ記号ごとの数を調べる.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
train_data = pd.read_csv("../train.csv")
test_data = pd.read_csv("../test.csv")
total_data = pd.concat([train_data, test_data]) # train_data と test_data を連結
ticket_freq = total_data["Ticket"].value_counts()
CA. 2343 11
CA 2144 8
1601 8
S.O.C. 14879 7
3101295 7
..
350404 1
248706 1
367655 1
W./C. 14260 1
350047 1
Name: Ticket, Length: 929, dtype: int64
CA. 2343
が 11 人, CA 2144
が 8 人, などが分かる.
グラフ用のデータを作る
groupby でグループ化
まず, total_data
をチケット記号でグルーピングする.
total_data_ticket = total_data.groupby("Ticket")
# 出力
<pandas.core.groupby.generic.DataFrameGroupBy object at 0x000001F5A14327C8>
groupby
の欠点は, データの中身を表示してくれないことだ. ここは, グループ化された と頭の中で理解して, 次へ行く.
生存情報だけ取り出す
次に, 生存情報 (Survived
) を取り出す.
total_data_ticket = total_data.groupby("Ticket")["Survived"]
total_data_ticket
# 出力
<pandas.core.groupby.generic.SeriesGroupBy object at 0x000001F5A1437B48>
ここでもデータは表示してくれない.
生存, 死亡, 不明ごとに数を数える
引き続き, value_counts()
を使って, Survived
の値ごとの数を数える. dropna=False
とすることで, N/A もカウントする.
total_data_ticket = total_data.groupby("Ticket")["Survived"].value_counts(dropna=False)
total_data_ticket
# 出力
Ticket Survived
110152 1.0 3
110413 1.0 2
0.0 1
110465 0.0 2
110469 NaN 1
..
W.E.P. 5734 NaN 1
0.0 1
W/C 14208 0.0 1
WE/P 5735 0.0 1
1.0 1
Name: Survived, Length: 1093, dtype: int64
データの形を変える
グラフを描くために, 生存, 死亡, 不明のデータが列方向に並ぶようなデータに変える. 使うのは unstack()
.
total_data_ticket = total_data.groupby("Ticket")["Survived"].value_counts(dropna=False).unstack()
total_data_ticket
# 出力
Survived NaN 0.0 1.0
Ticket
110152 NaN NaN 3.0
110413 NaN 1.0 2.0
110465 NaN 2.0 NaN
110469 1.0 NaN NaN
110489 1.0 NaN NaN
... ... ... ...
W./C. 6608 1.0 4.0 NaN
W./C. 6609 NaN 1.0 NaN
W.E.P. 5734 1.0 1.0 NaN
W/C 14208 NaN 1.0 NaN
WE/P 5735 NaN 1.0 1.0
929 rows × 3 columns
グラフを描く
N/A を数字に変える
上の出力を見ると, 値に NaN
がまだ残っている. そこで NaN
を 0 にする.
total_data_ticket.fillna(0, inplace=True)
total_data_ticket
# 出力
Survived NaN 0.0 1.0
Ticket
110152 0.0 0.0 3.0
110413 0.0 1.0 2.0
110465 0.0 2.0 0.0
110469 1.0 0.0 0.0
110489 1.0 0.0 0.0
... ... ... ...
W./C. 6608 1.0 4.0 0.0
W./C. 6609 0.0 1.0 0.0
W.E.P. 5734 1.0 1.0 0.0
W/C 14208 0.0 1.0 0.0
WE/P 5735 0.0 1.0 1.0
929 rows × 3 columns
列名を変える
列名が NaN
, 0.0
, 1.0
となっているが, これでは扱いにくいので, 列名を変える.
total_data_ticket.columns = ["nan", "d", "s"]
total_data_ticket
# 出力
nan d s
Ticket
110152 0.0 0.0 3.0
110413 0.0 1.0 2.0
110465 0.0 2.0 0.0
110469 1.0 0.0 0.0
110489 1.0 0.0 0.0
... ... ... ...
W./C. 6608 1.0 4.0 0.0
W./C. 6609 0.0 1.0 0.0
W.E.P. 5734 1.0 1.0 0.0
W/C 14208 0.0 1.0 0.0
WE/P 5735 0.0 1.0 1.0
929 rows × 3 columns
行ごとの合計人数を計算する
合計人数で降順にソートしたいので, 合計人数を計算して, 新しい列に保存する. 合計を計算するには sum()
を使うが, 列方向に計算するので sum(axis=1)
としている.
total_data_ticket["count"] = total_data_ticket.sum(axis=1)
total_data_ticket
#出力
nan d s count
Ticket
110152 0.0 0.0 3.0 3.0
110413 0.0 1.0 2.0 3.0
110465 0.0 2.0 0.0 2.0
110469 1.0 0.0 0.0 1.0
110489 1.0 0.0 0.0 1.0
... ... ... ... ...
W./C. 6608 1.0 4.0 0.0 5.0
W./C. 6609 0.0 1.0 0.0 1.0
W.E.P. 5734 1.0 1.0 0.0 2.0
W/C 14208 0.0 1.0 0.0 1.0
WE/P 5735 0.0 1.0 1.0 2.0
929 rows × 4 columns
これで, グラフを描く準備は整った.
グラフを描く
人数の領域を決めて, 降順にソートする
まずコードを示して, 順番に説明する.
total_data_ticket[total_data_ticket["count"] > 3].sort_values("count", ascending=False)[["nan", "d", "s"]].plot.bar(figsize=(15,10),stacked=True)
コード | 内容 |
---|---|
total_data_ticket[total_data_ticket["count"] > 3] |
"count" が 3 より大きいデータ |
.sort_values("count", ascending=False) |
"count" で降順にソート |
[["nan", "d", "s"]] |
左記の 3 つの列だけ取り出す ("count" はお役御免) |
.plot.bar(figsize=(15,10),stacked=True) |
棒グラフを描く. サイズを指定し, 積み上げ方式にした |
これで, 冒頭に示したグラフが書ける.
これを見ると, CA. 2343
や CA 2144
の人は Survived = 0
かな…とか想像できる.
全体のコード
最後に全体のコードを示す.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
train_data = pd.read_csv("../train.csv")
test_data = pd.read_csv("../test.csv")
total_data = pd.concat([train_data, test_data])
ticket_freq = total_data["Ticket"].value_counts()
ticket_freq
total_data_ticket = total_data.groupby("Ticket")["Survived"].value_counts(dropna=False).unstack()
total_data_ticket.fillna(0, inplace=True)
total_data_ticket.columns = ["nan", "d", "s"]
total_data_ticket["count"] = total_data_ticket.sum(axis=1)
total_data_ticket[total_data_ticket["count"] > 3].sort_values("count", ascending=False)[["nan", "d", "s"]].plot.bar(figsize=(15,10),stacked=True)
おわりに
この手法を使って, Embarked
や Cabin
, Name
の苗字や敬称など, 他の非数値データも確認していく.