やりたいこと
Polars
を使って、他のカラムを維持しながらgroup_by操作を行う方法を探しました。具体的には、mDice
を計算し、各Fold
ごとにmDice
が最大のEpoch
を取得したい、という内容です。
ぱっと思いつかなかったので、備忘録として記事にまとめます。
データ確認
以下のようなデータを対象としています。ここで、mDice
を計算し、各Fold
ごとにmDice
が最大の行を抽出したいとします。
Epoch | Fold | Label1 | Label2 | Label3 |
---|---|---|---|---|
1 | 1 | 0.8 | 0.75 | 0.85 |
2 | 1 | 0.85 | 0.78 | 0.9 |
3 | 1 | 0.88 | 0.8 | 0.92 |
1 | 2 | 0.76 | 0.74 | 0.78 |
2 | 2 | 0.84 | 0.8 | 0.85 |
3 | 2 | 0.9 | 0.85 | 0.88 |
data = {
"Epoch": [1, 2, 3, 1, 2, 3],
"Fold": [1,1,1,2,2,2],
"Label1": [0.8, 0.85, 0.88, 0.76, 0.84, 0.9],
"Label2": [0.75, 0.78, 0.8, 0.74, 0.8, 0.85],
"Label3": [0.85, 0.9, 0.92, 0.78, 0.85, 0.88],
}
df = pl.from_dict(data)
method A
最も簡潔な方法として、filter
を活用します。この方法では、over
を使ってFold
ごとにmDice
の最大値を計算し、その行を抽出します。
result = (
df
.with_columns(mDice = pl.mean_horizontal(pl.exclude("Epoch", "Fold")))
.filter(pl.col("mDice")==pl.max("mDice").over("Fold"))
)
水平方向の平均には便利な.mean_horizontal()
を使用します。
method B
group_by
でmDice
の最大値を計算し、それに対応する行をjoinで取得する方法です。この方法は、手続き的で構造が明確です。
tmp_df = df.with_columns(mDice = pl.mean_horizontal(pl.exclude("Epoch", "Fold")))
cols = tmp_df.columns
result = (
tmp_df
.group_by("Fold")
.agg(pl.max("mDice"))
.join(tmp_df, on=["Fold", "mDice"])
.select(cols)
)
この方法だと列の並びが変わるので.select
でもとに戻しています。
まとめ
個人的には、method A
の方が簡潔で読みやすいと思います。ただ、method B
も明確な手順で処理を行うため、状況によってはありかも。
もし他に良い方法や改善案があれば、ぜひコメントで教えてください!
バージョン
このコードは以下のバージョンで動作確認済みです:
-
Polars
: 1.15.0