9
5

More than 1 year has passed since last update.

DataFrame を Validation する pandera 入門

Posted at

はじめに

Python を用いてデータ分析を行うにあたりよく使われるライブラリとして pandas があります。
pandas は大変使い勝手の良いライブラリですが、多くの場合データを丸ごと pd.DataFrame 型で保持するため「どのような列を持っているのか」、「各列がどのような型か」、「各列の値にどのような値が入りうるのか」等がソースコードを一見しただけでは分からないことが多いです。
結果として処理がブラックボックス化してしまい、デバッグコストの増加やコードの可読性低下といった問題を生じさせることがあります。
この問題への解決策の一つとして、本記事ではデータフレームのバリデーション機能を提供するライブラリである pandera を紹介します。

pandera とは

データ処理パイプラインの可読性とロバストさを高めるために dataframe に対してデータ検証を行う機能を提供するライブラリです。

主に以下の機能を提供します。(上記ドキュメントより引用。一部抜粋。)

  • スキーマを定義することで、さまざまなデータフレームの型を検証できる
  • データフレームのカラムや値をチェックできる
  • pydantic のようなクラスベースの API でスキーマモデルを定義できる
  • pydantic, fastapi, mypy と言った Python ツールと統合できる

個人的には pydantic のようにクラスベース API でモデルを定義できる点がありがたいと感じています。

インストール

pip でインストールが可能です。

pip install pandera

使い方

DataFrameSchema によるバリデーション

公式チュートリアルより抜粋します。

import pandas as pd
import pandera as pa

# バリデーション用のデータ
df = pd.DataFrame({
    "column1": [1, 4, 0, 10, 9],
    "column2": [-1.3, -1.4, -2.9, -10.1, -20.4],
    "column3": ["value_1", "value_2", "value_3", "value_2", "value_1"],
})

# スキーマ定義
schema = pa.DataFrameSchema({
    "column1": pa.Column(int, checks=pa.Check.le(10)),
    "column2": pa.Column(float, checks=pa.Check.lt(-1.2)),
    "column3": pa.Column(str, checks=[
        pa.Check.str_startswith("value_"),
        # series の入力を受け取り boolean か boolean 型の series を返すカスタムチェックメソッドを定義
        pa.Check(lambda s: s.str.split("_", expand=True).shape[1] == 2)
    ]),
})

validated_df = schema(df)
print(validated_df)
   column1  column2  column3
0        1     -1.3  value_1
1        4     -1.4  value_2
2        0     -2.9  value_3
3       10    -10.1  value_2
4        9    -20.4  value_1

事前にスキーマを定義しておき、データフレームをスキーマに入力するとバリデーションされたデータフレームが出力されます。

SchemaModel によるバリデーション

続いて SchemaModel を利用する使い方を見てみます。こちらもチュートリアルより抜粋します。

from pandera.typing import Series

class Schema(pa.SchemaModel):

    column1: Series[int] = pa.Field(le=10)
    column2: Series[float] = pa.Field(lt=-1.2)
    column3: Series[str] = pa.Field(str_startswith="value_")

    @pa.check("column3")
    def column_3_check(cls, series: Series[str]) -> Series[bool]:
        """Check that column3 values have two elements after being split with '_'"""
        return series.str.split("_", expand=True).shape[1] == 2

Schema.validate(df)

カスタムチェックメソッドが lambda ではなくデコレータ付きメソッドで実装されていますが、書き方に大きな違いはないことが分かります。

なお pandera.Field が取る主要な引数には以下があります。

パラメータ 説明
nullable 列に null を許容するか
unique 列にユニーク制約を課すか
coerce 型を強制するか
ignore_na 型チェックの際に null を無視するか
eq 指定した値と等しいか
ge 指定した値より大きいか
gt 指定した値以上か
le 指定した値より小さいか
lt 指定した値以下か
ne 要素を持たないか
in_range 指定した最小値、最大値の範囲内か
isin 指定したリストの範囲内か
str_contains 指定した文字列を含むか
str_startswith 指定した文字列から始まるか
str_endswith 指定した文字列で終わるか
str_length 文字長が指定した最小値、最大値の範囲内か

実践的な使い方

簡単に使い方を理解したところで Titanic のデータセットを読み込み、加工する処理を試してみます。

データの読み込み

まずはスキーマ定義をしない場合を考えます。

import pandas as pd


def load_data(filepath: str) -> pd.DataFrame:
    df = pd.read_csv(filepath)
    return df


df = load_data("./train.csv")

ファイルからデータを読み込んでいるので当然と言えば当然ですが、データフレームの中身がどうなっているのかは分かりません。

続いてスキーマ定義をした場合を考えます。

from typing import Optional

import pandas as pd
import pandera as pa
from pandera.typing import Series, DataFrame


class TitanicSchema(pa.SchemaModel):
    PassengerId: Series[int] = pa.Field(nullable=False, unique=True)
    Survived: Optional[Series[int]] = pa.Field(nullable=True, isin=(0, 1))
    Pclass: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3))
    Name: Series[str] = pa.Field(nullable=False)
    Sex: Series[str] = pa.Field(nullable=False, isin=("male", "female"))
    Age: Series[float] = pa.Field(nullable=True, in_range={"min_value": 0, "max_value": 100})
    SibSp: Series[int] = pa.Field(nullable=False, ge=0, le=10)
    Parch: Series[int] = pa.Field(nullable=False, ge=0, le=10)
    Ticket: Series[str] = pa.Field(nullable=False)
    Fare: Series[float] = pa.Field(nullable=True, ge=0)
    Cabin: Series[str] = pa.Field(nullable=True)
    Embarked: Series[str] = pa.Field(nullable=True, str_length=1, isin=("S", "C", "Q"))
    
    class Config:
        strict = True


def load_dataset(filepath:str) -> DataFrame[TitanicSchema]:
    df = pd.read_csv(filepath)
    df = TitanicSchema.validate(df)
    return df


df = load_data("./train.csv")

データセットが持つカラムとそれぞれのカラムの情報がスキーマとして定義されたため、データフレームの中身がある程度分かるようになりました。
また、読み込み直後にバリデーションされているため、スキーマとして定義された内容を満たしたデータフレームであることが保証されています。

補足

上記の例では TitanicSchema クラスのメンバ変数に Config クラスを定義して strict=True を設定しています。
このように pa.SchemaModelConfig クラスを登録することでメタ情報を定義することができます。

デフォルトではSchemaModel に登録したカラムが存在しない場合はバリデーションでエラーが出ますが、登録されていないカラムを持っていてもエラーが出ません。
登録されていないカラムを持っていた場合にもエラーを出すために、strict=True を設定しています。

データの加工

続いてデータを加工してみます。加工のプロセスは下記の notebook を参考にさせていただきました。
本記事においては加工処理そのものは重要ではないので流し読みいただいて問題ありません。

先ほどと同じように、まずはスキーマ定義しない場合を考えます。

import numpy as np
import pandas as pd


def transform(df: pd.DataFrame) -> pd.DataFrame:
    df["Sex"] = df["Sex"].str.match("male").map(int)

    df["Title"] = df["Name"].str.extract(" ([A-Za-z]+)\.", expand=False)
    df["Title"] = df["Title"].replace(
        ["Lady", "Countess", "Capt", "Col", "Don", "Dr", "Major", "Rev", "Sir", "Jonkheer", "Dona"], "Rare"
    )
    df["Title"] = df["Title"].replace("Mlle", "Miss")
    df["Title"] = df["Title"].replace("Ms", "Miss")
    df["Title"] = df["Title"].replace("Mme", "Mrs")
    df["Title"] = df["Title"].map({"Mr": 1, "Miss": 2, "Mrs": 3, "Master": 4, "Rare": 5})

    guess_ages = np.zeros((2, 3))
    for i in range(2):
        for j in range(3):
            guess_df = df[(df["Sex"] == i) & (df["Pclass"] == j + 1)]["Age"].dropna()
            age_guess = guess_df.median()
            guess_ages[i, j] = int(age_guess / 0.5 + 0.5) * 0.5
    for i in range(2):
        for j in range(3):
            df.loc[(df["Age"].isnull()) & (df["Sex"] == i) & (df["Pclass"] == j + 1), "Age"] = guess_ages[i, j]
    df["Age"] = df["Age"].astype(int)

    df.loc[df["Age"] <= 16, "Age"] = 0
    df.loc[(df["Age"] > 16) & (df["Age"] <= 32), "Age"] = 1
    df.loc[(df["Age"] > 32) & (df["Age"] <= 48), "Age"] = 2
    df.loc[(df["Age"] > 48) & (df["Age"] <= 64), "Age"] = 3
    df.loc[df["Age"] > 64, "Age"] = 4

    df["FamilySize"] = df["SibSp"] + df["Parch"] + 1

    df["IsAlone"] = df["FamilySize"].map(lambda x: 1 if x == 1 else 0)

    df = df.drop(["Ticket", "Cabin", "PassengerId", "Name"], axis=1)

    return df


df = load_data("./train.csv")
df = transform(df)

加工は欠損補完、ビニング、型変換、列同士の演算、不要列の削除などさまざまな処理が含まれます。
これらの処理を経由した結果、最終的に得られるデータフレームがどのような状態になっているのかがすぐには分かり辛いと思います。

続いてスキーマ定義をした場合を考えます。

class TransformedTitanicSchema(pa.SchemaModel):
    Survived: Optional[Series[int]] = pa.Field(nullable=True, isin=(0, 1))
    Pclass: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3))
    Sex: Series[int] = pa.Field(nullable=False, isin=(0, 1))
    Age: Series[int] = pa.Field(nullable=False, isin=(0, 1, 2, 3, 4))
    SibSp: Series[int] = pa.Field(nullable=False, ge=0, le=10)
    Parch: Series[int] = pa.Field(nullable=False, ge=0, le=10)
    Fare: Series[float] = pa.Field(nullable=True, ge=0)
    Embarked: Series[str] = pa.Field(nullable=True, str_length=1, isin=("S", "C", "Q"))
    Title: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3, 4, 5))
    FamilySize: Series[int] = pa.Field(nullable=False, ge=0, le=15)
    IsAlone: Series[int] = pa.Field(nullable=False, isin=(0, 1))
    
    class Config:
        strict = True
    

def transform(df: DataFrame[TitanicSchema]) -> DataFrame[TransformedTitanicSchema]:
    df["Sex"] = df["Sex"].str.match("male").map(int)

    df["Title"] = df["Name"].str.extract(" ([A-Za-z]+)\.", expand=False)
    df["Title"] = df["Title"].replace(
        ["Lady", "Countess", "Capt", "Col", "Don", "Dr", "Major", "Rev", "Sir", "Jonkheer", "Dona"], "Rare"
    )
    df["Title"] = df["Title"].replace("Mlle", "Miss")
    df["Title"] = df["Title"].replace("Ms", "Miss")
    df["Title"] = df["Title"].replace("Mme", "Mrs")
    df["Title"] = df["Title"].map({"Mr": 1, "Miss": 2, "Mrs": 3, "Master": 4, "Rare": 5})

    guess_ages = np.zeros((2, 3))
    for i in range(2):
        for j in range(3):
            guess_df = df[(df["Sex"] == i) & (df["Pclass"] == j + 1)]["Age"].dropna()
            age_guess = guess_df.median()
            guess_ages[i, j] = int(age_guess / 0.5 + 0.5) * 0.5
    for i in range(2):
        for j in range(3):
            df.loc[(df["Age"].isnull()) & (df["Sex"] == i) & (df["Pclass"] == j + 1), "Age"] = guess_ages[i, j]
    df["Age"] = df["Age"].astype(int)

    df.loc[df["Age"] <= 16, "Age"] = 0
    df.loc[(df["Age"] > 16) & (df["Age"] <= 32), "Age"] = 1
    df.loc[(df["Age"] > 32) & (df["Age"] <= 48), "Age"] = 2
    df.loc[(df["Age"] > 48) & (df["Age"] <= 64), "Age"] = 3
    df.loc[df["Age"] > 64, "Age"] = 4

    df["FamilySize"] = df["SibSp"] + df["Parch"] + 1

    df["IsAlone"] = df["FamilySize"].map(lambda x: 1 if x == 1 else 0)

    df = df.drop(["Ticket", "Cabin", "PassengerId", "Name"], axis=1)
    
    df = TransformedTitanicSchema.validate(df)    
    
    return df 


df = load_data("./train.csv")
df = transform(df)

加工処理は変わっていませんが、処理の中身を完全に読み解かなくてもある程度どのような値が入るかが分かるようになりました。
また、加工処理終了直後にデータのバリデーションをおこなっているので、加工によって想定外の値が混入しないことが保証されるようになりました。

読み込みと加工

まとめると、下記の通りになります。
なお、上記のソースコードはメソッドの最後に SchemaModel.validate を実行して型の確認をしていましたが、これを毎回書くのは面倒です。
タイプヒントのついたメソッドにデコレータ @pa.check_types を付与することで自動的にバリデーションを実施してくれるようになります。

import numpy as np
import pandas as pd
import pandera as pa
from typing import Optional
from pandera.typing import Series, DataFrame


class TitanicSchema(pa.SchemaModel):
    PassengerId: Series[int] = pa.Field(nullable=False, unique=True)
    Survived: Optional[Series[int]] = pa.Field(nullable=True, isin=(0, 1))
    Pclass: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3))
    Name: Series[str] = pa.Field(nullable=False)
    Sex: Series[str] = pa.Field(nullable=False, isin=("male", "female"))
    Age: Series[float] = pa.Field(nullable=True, in_range={"min_value": 0, "max_value": 100})
    SibSp: Series[int] = pa.Field(nullable=False, ge=0, le=10)
    Parch: Series[int] = pa.Field(nullable=False, ge=0, le=10)
    Ticket: Series[str] = pa.Field(nullable=False)
    Fare: Series[float] = pa.Field(nullable=True, ge=0)
    Cabin: Series[str] = pa.Field(nullable=True)
    Embarked: Series[str] = pa.Field(nullable=True, str_length=1, isin=("S", "C", "Q"))

    class Config:
        strict = True


class TransformedTitanicSchema(pa.SchemaModel):
    Survived: Optional[Series[int]] = pa.Field(nullable=True, isin=(0, 1))
    Pclass: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3))
    Sex: Series[int] = pa.Field(nullable=False, isin=(0, 1))
    Age: Series[int] = pa.Field(nullable=True, isin=(0, 1, 2, 3, 4))
    SibSp: Series[int] = pa.Field(nullable=False, ge=0, le=10)
    Parch: Series[int] = pa.Field(nullable=False, ge=0, le=10)
    Fare: Series[float] = pa.Field(nullable=True, ge=0)
    Embarked: Series[str] = pa.Field(nullable=True, str_length=1, isin=("S", "C", "Q"))
    Title: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3, 4, 5))
    FamilySize: Series[int] = pa.Field(nullable=False, ge=0, le=15)
    IsAlone: Series[int] = pa.Field(nullable=False, isin=(0, 1))

    class Config:
        strict = True


@pa.check_types
def load_dataset(filepath: str) -> DataFrame[TitanicSchema]:
    df = pd.read_csv(filepath)
    return df


@pa.check_types
def transform(df: DataFrame[TitanicSchema]) -> DataFrame[TransformedTitanicSchema]:
    df["Sex"] = df["Sex"].str.match("male").map(int)

    df["Title"] = df["Name"].str.extract(" ([A-Za-z]+)\.", expand=False)
    df["Title"] = df["Title"].replace(
        ["Lady", "Countess", "Capt", "Col", "Don", "Dr", "Major", "Rev", "Sir", "Jonkheer", "Dona"], "Rare"
    )
    df["Title"] = df["Title"].replace("Mlle", "Miss")
    df["Title"] = df["Title"].replace("Ms", "Miss")
    df["Title"] = df["Title"].replace("Mme", "Mrs")
    df["Title"] = df["Title"].map({"Mr": 1, "Miss": 2, "Mrs": 3, "Master": 4, "Rare": 5})

    guess_ages = np.zeros((2, 3))
    for i in range(2):
        for j in range(3):
            guess_df = df[(df["Sex"] == i) & (df["Pclass"] == j + 1)]["Age"].dropna()
            age_guess = guess_df.median()
            guess_ages[i, j] = int(age_guess / 0.5 + 0.5) * 0.5
    for i in range(2):
        for j in range(3):
            df.loc[(df["Age"].isnull()) & (df["Sex"] == i) & (df["Pclass"] == j + 1), "Age"] = guess_ages[i, j]
    df["Age"] = df["Age"].astype(int)

    df.loc[df["Age"] <= 16, "Age"] = 0
    df.loc[(df["Age"] > 16) & (df["Age"] <= 32), "Age"] = 1
    df.loc[(df["Age"] > 32) & (df["Age"] <= 48), "Age"] = 2
    df.loc[(df["Age"] > 48) & (df["Age"] <= 64), "Age"] = 3
    df.loc[df["Age"] > 64, "Age"] = 4

    df["FamilySize"] = df["SibSp"] + df["Parch"] + 1

    df["IsAlone"] = df["FamilySize"].map(lambda x: 1 if x == 1 else 0)

    df = df.drop(["Ticket", "Cabin", "PassengerId", "Name"], axis=1)

    return df


def main():
    df = load_dataset("./train.csv")
    df = transform(df)


if __name__ == "__main__":
    main()

終わりに

本記事ではデータフレームのバリデーション機能を提供するライブラリである pandera を紹介しました。

単純にソースコードの分量だけを見ると倍近くに増えており、コストなくバリデーションや型チェックができるわけではないので、短期間で破棄される前提のソースコードに対しては使う価値はないかもしれません。複数人で長期的にメンテナンスしていく必要のあるソースコードに対しては非常に効果的であるように感じました。

機械学習モデルがプロダクション環境で動くことが比較的当たり前になってきたことから dataclass や pydantic 等のスキーマ定義、タイプヒントを開発に導入している企業が増えている印象です。もしデータフレームにタイプヒントや型チェックが付けられなくて困っているようであれば、 pandera の利用を検討してみるのも一つの手かと思います。

その際に本記事がお役に立てば光栄です。最後までお読みいただきありがとうございました。

参考資料

リポジトリ

ドキュメント

日本語の紹介記事

9
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
5