1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Polarsで他のカラムを維持しながらgroup_by操作を行う方法

Last updated at Posted at 2024-12-24

やりたいこと

Polarsを使って、他のカラムを維持しながらgroup_by操作を行う方法を探しました。具体的には、mDiceを計算し、各FoldごとにmDiceが最大のEpochを取得したい、という内容です。

ぱっと思いつかなかったので、備忘録として記事にまとめます。


データ確認

以下のようなデータを対象としています。ここで、mDiceを計算し、各FoldごとにmDiceが最大の行を抽出したいとします。

Epoch Fold Label1 Label2 Label3
1 1 0.8 0.75 0.85
2 1 0.85 0.78 0.9
3 1 0.88 0.8 0.92
1 2 0.76 0.74 0.78
2 2 0.84 0.8 0.85
3 2 0.9 0.85 0.88
data = {
    "Epoch": [1, 2, 3, 1, 2, 3],
    "Fold": [1,1,1,2,2,2],
    "Label1": [0.8, 0.85, 0.88, 0.76, 0.84, 0.9],
    "Label2": [0.75, 0.78, 0.8, 0.74, 0.8, 0.85],
    "Label3": [0.85, 0.9, 0.92, 0.78, 0.85, 0.88],
}
df = pl.from_dict(data)

method A

最も簡潔な方法として、filterを活用します。この方法では、overを使ってFoldごとにmDiceの最大値を計算し、その行を抽出します。

result = (
    df
    .with_columns(mDice = pl.mean_horizontal(pl.exclude("Epoch", "Fold")))
    .filter(pl.col("mDice")==pl.max("mDice").over("Fold"))
)

水平方向の平均には便利な.mean_horizontal()を使用します。

method B

group_bymDiceの最大値を計算し、それに対応する行をjoinで取得する方法です。この方法は、手続き的で構造が明確です。

tmp_df = df.with_columns(mDice = pl.mean_horizontal(pl.exclude("Epoch", "Fold")))
cols = tmp_df.columns
result = (
    tmp_df
    .group_by("Fold")
    .agg(pl.max("mDice"))
    .join(tmp_df, on=["Fold", "mDice"])
    .select(cols)
)

この方法だと列の並びが変わるので.selectでもとに戻しています。

まとめ

個人的には、method Aの方が簡潔で読みやすいと思います。ただ、method Bも明確な手順で処理を行うため、状況によってはありかも。

もし他に良い方法や改善案があれば、ぜひコメントで教えてください!


バージョン

このコードは以下のバージョンで動作確認済みです:

  • Polars: 1.15.0
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?