LoginSignup
0
0

polarsでtrain_val_test分割する自作関数を作成

Last updated at Posted at 2024-05-19

polarsでtrain_val_test分割する自作関数を作成

コード

import lightgbm as lgm
import plotly
import polars as pl
from icecream import ic
@pl.api.register_dataframe_namespace("tvt_split")
class Shuffle_tvt_split:
    """Train / validation / test splitting, exposed as ``DataFrame.tvt_split``.

    Every public method accepts an optional ``params`` dict with these keys:

    * ``"shuffle"``   -- shuffle the rows before splitting (default ``True``)
    * ``"seed"``      -- seed passed to the shuffle (default ``0``)
    * ``"val_size"``  -- fraction of rows for the validation split (default ``0.25``)
    * ``"test_size"`` -- fraction of rows for the test split (default ``0.25``)

    Keys missing from ``params`` fall back to the defaults above.

    The rows are laid out as ``[train | val | test]``. The validation and test
    row counts are truncated with ``int()``; train receives whatever is left.

    NOTE: ``train``, ``val`` and ``test`` each shuffle the frame independently,
    so they only yield disjoint splits when the shuffle is reproducible.
    Presumably that requires a fixed integer ``"seed"`` (not ``None``) --
    confirm against the polars ``Expr.shuffle`` documentation.
    """

    # Defaults applied to any key the caller leaves out of ``params``.
    _DEFAULT_PARAMS = {"shuffle": True, "seed": 0, "val_size": 0.25, "test_size": 0.25}

    def __init__(
        self,
        df: pl.DataFrame,
    ):
        self._df = df
        self.df_length = df.height

    def _resolve_params(self, params: dict[str, bool | float] | None) -> dict:
        """Return a fresh dict of ``params`` layered over the defaults."""
        if params is None:
            return dict(self._DEFAULT_PARAMS)
        return {**self._DEFAULT_PARAMS, **params}

    def _split_counts(self, resolved: dict) -> tuple[int, int, int]:
        """Return ``(train, val, test)`` row counts for already-resolved params.

        Raises:
            ValueError: if a fraction is negative, or if the validation and
                test splits together need more rows than the frame has.
        """
        val_size = resolved["val_size"]
        test_size = resolved["test_size"]
        if val_size < 0 or test_size < 0:
            raise ValueError(
                f"val_size and test_size must be non-negative, "
                f"got val_size={val_size!r}, test_size={test_size!r}"
            )

        n_val = int(self.df_length * val_size)
        n_test = int(self.df_length * test_size)
        n_train = self.df_length - n_val - n_test
        if n_train < 0:
            raise ValueError(
                f"val_size={val_size!r} and test_size={test_size!r} need "
                f"{n_val + n_test} rows but the frame only has {self.df_length}"
            )
        return n_train, n_val, n_test

    def _ordered_rows(self, resolved: dict) -> pl.DataFrame:
        """Return the frame in split order: shuffled if requested, else as is."""
        if resolved["shuffle"]:
            # Every column is shuffled with the same seed, as in the original code.
            return self._df.select(pl.all().shuffle(seed=resolved["seed"]))
        return self._df

    def train_length(self, params: dict[str, bool | float] | None = None) -> int:
        """Return the number of rows in the train split."""
        return self._split_counts(self._resolve_params(params))[0]

    def val_length(self, params: dict[str, bool | float] | None = None) -> int:
        """Return the number of rows in the validation split."""
        return self._split_counts(self._resolve_params(params))[1]

    def test_length(self, params: dict[str, bool | float] | None = None) -> int:
        """Return the number of rows in the test split."""
        # The fraction lives under "test_size"; the previous "test_length"
        # lookup raised KeyError on every call.
        return self._split_counts(self._resolve_params(params))[2]

    def train(self, params: dict[str, bool | float] | None = None) -> pl.DataFrame:
        """Return the train split: the leading ``train_length`` rows."""
        resolved = self._resolve_params(params)
        n_train, _, _ = self._split_counts(resolved)
        # slice() avoids adding a helper index column that could clash with
        # a user column of the same name.
        return self._ordered_rows(resolved).slice(0, n_train)

    def val(self, params: dict[str, bool | float] | None = None) -> pl.DataFrame:
        """Return the validation split: the rows right after the train split."""
        resolved = self._resolve_params(params)
        n_train, n_val, _ = self._split_counts(resolved)
        return self._ordered_rows(resolved).slice(n_train, n_val)

    def test(self, params: dict[str, bool | float] | None = None) -> pl.DataFrame:
        """Return the test split: the trailing ``test_length`` rows."""
        resolved = self._resolve_params(params)
        n_train, n_val, n_test = self._split_counts(resolved)
        return self._ordered_rows(resolved).slice(n_train + n_val, n_test)

    def tvt(self, params: dict[str, bool | float] | None = None) -> dict[str, pl.DataFrame]:
        """Return all three splits as ``{"train": ..., "val": ..., "test": ...}``."""
        resolved = self._resolve_params(params)
        n_train, n_val, n_test = self._split_counts(resolved)
        # Shuffle once and cut three contiguous windows out of the same ordering.
        ordered = self._ordered_rows(resolved)
        return {
            "train": ordered.slice(0, n_train),
            "val": ordered.slice(n_train, n_val),
            "test": ordered.slice(n_train + n_val, n_test),
        }

使い方

# Usage example: split a 6-row frame without shuffling.
# With val_size and test_size both 0.25, int(6 * 0.25) == 1 row goes to each of
# val and test, and the remaining 4 rows go to train.
params = {"shuffle": False, "seed": 0, "val_size": 0.25, "test_size": 0.25}
# A single string column named "txt".
sample_df = pl.DataFrame(
    data=["aaa", "bbb", "ccc", "ddd", "eee", "fff"],
    schema=[("txt", pl.String)],
)

# The splitter is reached through the "tvt_split" namespace on the DataFrame.
# ic() prints each expression alongside its value.
ic(sample_df.tvt_split.train(params))
ic(sample_df.tvt_split.val(params))
ic(sample_df.tvt_split.test(params))
ic| sample_df.tvt_split.train(params): shape: (4, 1)
                                       ┌─────┐
                                       │ txt │
                                       │ --- │
                                       │ str │
                                       ╞═════╡
                                       │ aaa │
                                       │ bbb │
                                       │ ccc │
                                       │ ddd │
                                       └─────┘
ic| sample_df.tvt_split.val(params): shape: (1, 1)
                                     ┌─────┐
                                     │ txt │
                                     │ --- │
                                     │ str │
                                     ╞═════╡
                                     │ eee │
                                     └─────┘
ic| sample_df.tvt_split.test(params): shape: (1, 1)
                                      ┌─────┐
                                      │ txt │
                                      │ --- │
                                      │ str │
                                      ╞═════╡
                                      │ fff │
                                      └─────┘

参考

Extending the API — Polars documentation

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0