pandasで集計結果を条件とした抽出を行います。
集計結果の受け取り方に注意が必要です。
以下のデータを例にします。
import pandas as pd
df=pd.DataFrame({'cust_id': ['A1', 'A2', 'A3', 'A4', 'A5','A6'],
'amount': [100, 500, 300, 200, 200,500]})
amountの平均をとってそれ以上のレコードだけを抽出します。
平均は以下のようにmeanという集計関数で取得できます。
df.mean()
amount 300.0
dtype: float64
しかし、以下のコードで平均以上を抽出しようとすると「ValueError: Can only compare identically-labeled Series objects」が発生して失敗します。
df[df['amount']>=df.mean()]
これはdf.mean()がSeriesオブジェクトだからです。
print(type(df.mean()))
<class 'pandas.core.series.Series'>
なぜSeriesオブジェクトなのかは複数の数値列があるデータにするとわかります。
以下のデータはamountのほかにcountという数値列を追加しました。
df=pd.DataFrame({'cust_id': ['A1', 'A2', 'A3', 'A4', 'A5','A6'],
'amount': [100, 500, 300, 200, 200,500],
'count': [1, 5, 3, 2, 2,5]})
df.mean()
amount 300.0
count 3.0
dtype: float64
amoutの平均とともにcountの平均値も戻りました。このように複数の集計結果が入りうるのでSeriesオブジェクトとして戻るのです。
ではamountの平均値だけを求めます。2つ方法があります。
まずはdf.amountでSeriesオブジェクトにしてからmean()で集計する方法です。
print(df.amount.mean())
print(type(df.amount.mean()))
df[df['amount']>=df.amount.mean()]
もう一つは、df.mean()でSeriesオブジェクトを得てから、['amount']とインデックスアクセスして単一値を得る方法です。
print(df.mean()['amount'])
print(type(df.mean()['amount']))
df[df['amount']>=df.mean()['amount']]
集計結果を条件とした抽出はこのように集計結果を単一値で得てから、ブールインデックスやqueryで検索します。上記はブールインデックスでの例でしたが、queryを使う場合は以下のように@で集計結果を参照します。
df.query("amount>=@df.amount.mean()")
ちなみにcust_idとamountだけでよければSeriesオブジェクトのまま集計と抽出をすることも可能です。
set_index('cust_id')でcust_idはインデックス化して、['amount']でSeriesオブジェクトに変換しています。Seriesオブジェクトは列がひとつしかないので、sr.mean()で平均値が単一値として求まります。
sr=df.set_index('cust_id')['amount']
sr[sr>=sr.mean()]
cust_id
A2 500
A3 300
A6 500
Name: amount, dtype: int64
データサイエンス100本ノック(構造化データ加工編)のP-035を解くときにちょっと引っかかったのでメモしました。