はじめに
pandasのデータフレーム(pandas.DataFrame
)は、Pythonで機械学習やデータ分析をする多くの人に利用されています。そんなPythonにおけるデータフレームですが「不正な値が格納されてしまう」というような課題が挙げられます。それに加え、特に問題なのが、Pythonが動的型付け言語であることにより、不正な値が格納されていたとしてもエラーが発生せず、開発者が気づかないという恐れがあるという点です。筆者(@daikikatsuragawa)も、ある項目に対してあり得ない値を格納したものの、その直後に気づけないということがありました。しかし、機械学習やデータ分析において、データフレームにドメインの観点で妥当なデータが格納されていることが非常に重要であると考えられます。このような課題の改善策としてデータフレームのバリデーションを実現するためのライブラリであるpanderaを紹介します。ぜひ、本記事を読んで“pandera入門”をしていただけたら、と思います。
pandera
panderaはデータフレームのバリデーションを実現するPythonのライブラリです。pandasだけでなくdask、modin、pyspark.pandasなどにも対応しています。ただし、本記事におけるデータフレームはpandasのものと限定します。
panderaはオープンソースソフトウェア(OSS)として開発されています。
panderaはデータフレームに格納されている値について、例えば以下のバリデーションを実現します。
- 型は正しい?
- 値は妥当?
- 自然数?(正の整数?)
- 存在しうる日付?
- 文字数は正しい?
- 一意であることを期待する項目に対して重複した値はない?
他にもさまざまな機能がありますが、今回は入門ということで上記のようなバリデーションの実現方法について、基本的なものを紹介します。
panderaでバリデーションを実現する方法は2つあります。
- DataFrameSchema(
pandera.schemas.DataFrameSchema
)を利用する方法 - SchemaModel(
pandera.model.SchemaModel
)を利用する方法
本記事ではこれら2つの利用例について、ダミーデータを使いつつ紹介します。
ダミーデータ
実現したいバリデーションとそれを実現するpanderaの記述方法について、実用的なものを紹介するためには、可能な限り現実的なデータを用いて利用例を紹介することが望ましいです。それゆえ、本記事では以下のキャンペーンより用意されているQiitaの投稿のダミーデータ(QiitadelikaDummy)を利用します。
このキャンペーンを企画しているコネクトデータさんが提供している「delika」はデータ版「GitHub」を目指して開発されたデータ共有プラットフォームです。実際に分析する目的でなくとも、今回の記事のように可能な限り現実的なデータを利用したい場合に有用であり、ありがたいです。
準備
準備として、ダミーデータを読み込んで、中身を確認し、実現したいデータフレームのルール(スキーマ)を整理します。以下、実行環境はGoogle Colaboratoryとします。そして、ダウンロードしたcsvファイルはそれぞれarticles.csv、tags.csv、article_tags.csvと命名し、Google Driveのマイドライブ(My Drive)の直下に作成したQiitadelikaDummyというフォルダに格納したとします。すると、例えばarticles.csvのパスはdrive/My Drive/QiitadelikaDummy /articles.csv
になります。
from google.colab import drive
import pandas as pd
articles_df = pd.read_csv("drive/My Drive/QiitadelikaDummy/articles.csv")
tags_df = pd.read_csv("drive/My Drive/QiitadelikaDummy/tags.csv")
article_tags_df = pd.read_csv("drive/My Drive/QiitadelikaDummy/article_tags.csv")
articles_dfの中身は以下です。
articles_df.head()
article_id | created_at | likes_count | comments_count | url | users | page_views_count | |
---|---|---|---|---|---|---|---|
0 | 1 | 2017-04-06 | 332 | 7 | 771q8q4kk | cb29wVZ6y | 7469 |
1 | 2 | 2012-04-18 | 865 | 8 | t8sm68r4a | K1UwSLNXd | 9862 |
2 | 3 | 2020-05-17 | 435 | 7 | to1qkedii | YOJuk11fm | 5798 |
3 | 4 | 2020-02-29 | 847 | 4 | x5ve640er | fDf7EEckx | 9960 |
4 | 5 | 2019-09-12 | 529 | 2 | ltk7u91u4 | 1YkM1a3-l | 7127 |
また、データを確認した上で、articles_dfのチェック項目は以下とします。
- id
- int型であること
- 値は1以上であること
- created_at
- str型であること
- 日付として解釈が可能であること
- likes_count
- int型であること
- 値は0以上であること
- comments_count
- int型であること
- 値は0以上であること
- url
- str型であること
- データ内で重複した値を持たないこと
- 値は9文字であること
- users
- str型であること
- 値は9文字か
#NAME?
であること
- page_views_count
- int型であること
- 値は0以上であること
tags_dfの中身は以下です。
tags_df.head()
tag_id | tag_name | |
---|---|---|
0 | 1 | Python |
1 | 2 | JavaScript |
2 | 3 | Ruby |
3 | 4 | Rails |
4 | 5 | AWS |
また、データを確認した上で、tags_dfのチェック項目は以下とします。
- id
- int型であること
- 値は1以上であること
- name
- str型であること
article_tags_dfの中身は以下です。
article_tags_df.head()
id | article_id | tag_id | |
---|---|---|---|
0 | 1 | 1 | 144 |
1 | 2 | 1 | 7 |
2 | 3 | 1 | 121 |
3 | 4 | 2 | 53 |
4 | 5 | 2 | 132 |
また、データを確認した上で、article_tags_dfのチェック項目は以下とします。
- id
- int型であること
- 値は1以上であること
- article_id
- int型であること
- 値は1以上であること
- tag_id
- int型であること
- 値は1以上であること
また、実際にこれらのデータフレームを扱うときにマージするとします。article_tags_dfのarticle_idとarticles_dfのid、article_tags_dfのtag_idとtags_dfのidを元にマージが可能です。このとき、articles_dfに情報を足す想定で、article_tags_df、tags_dfと順にマージします。このとき、article_tags_dfのidは削除し、tags_dfのidはtag_id、nameはtag_nameとリネームするとします。
そして、マージしたデータフレームのチェック項目は以下とします。
- article_id
- int型であること
- 値は1以上であること
- created_at
- str型であること
- 日付として解釈が可能であること
- likes_count
- int型であること
- 値は0以上であること
- comments_count
- int型であること
- 値は0以上であること
- url
- str型であること
- 値は9文字であること
- users
- str型であること
- 値は9文字か
#NAME?
であること
- page_views_count
- int型であること
- 値は0以上であること
- tag_id
- int型であること
- 値は1以上であること
- tag_name
- str型であること
※articles_dfのurlのチェック項目として挙げていた「データ内で重複した値を持たないこと」については特定のarticleについて複数のレコードが生成されることから、マージしたデータフレームでは確認しません。
これらのチェック項目に基づいてデータフレームのバリデーションを実現していきます。
これ以降、紹介するスクリプトはGoogle Colaboratoryで実行しています。panderaは以下によりインストールします。
!pip install pandera
DataFrameSchemaを利用する方法
articles_dfのバリデーションは以下のような記述により実現します。
import pandera as pa
def is_date(date_str):
"""
「日付として解釈が可能であること」を確認する関数。
"""
date_format = "%Y-%m-%d"
try:
return bool(datetime.datetime.strptime(date_str, date_format))
except ValueError:
return False
UNKNOWN_USER_VALUE = "#NAME?"
articles_schema = pa.DataFrameSchema({
"id": pa.Column(int, checks=pa.Check.ge(1)),
"created_at": pa.Column(str, checks=[
pa.Check(
lambda g: is_date(g) == True,
element_wise=True)
]
),
"likes_count": pa.Column(int, checks=pa.Check.ge(0)),
"comments_count": pa.Column(int, checks=pa.Check.ge(0)),
"url": pa.Column(str, checks=[
pa.Check(
lambda s: len(s) == len(set(s)),
element_wise=False),
pa.Check(
lambda g: len(g) == 9,
element_wise=True)
]
),
"users": pa.Column(str, checks=[
pa.Check(
lambda g: (len(g) == 9) or (g == UNKNOWN_USER_VALUE),
element_wise=True)
]
),
"page_views_count": pa.Column(int, checks=pa.Check.ge(0)),
},
strict=True
)
articles_df = articles_schema(articles_df)
# 以下のような記述も可能です。
# articles_df = articles_schema.validate(articles_df)
strict=True
と指定することで定義した列名のみを持っていることをチェックします。
tags_dfのバリデーションは以下のような記述により実現します。
tags_schema = pa.DataFrameSchema({
"id": pa.Column(int, checks=pa.Check.ge(1)),
"name": pa.Column(str)
},
strict=True
)
tags_df = tags_schema(tags_df)
# 以下のような記述も可能です。
# tags_df = tags_schema.validate(tags_df)
article_tags_dfのバリデーションは以下のような記述により実現します。
article_tags_schema = pa.DataFrameSchema({
"id": pa.Column(int, checks=pa.Check.ge(1)),
"article_id": pa.Column(int, checks=pa.Check.ge(1)),
"tag_id": pa.Column(int, checks=pa.Check.ge(1)),
},
strict=True
)
article_tags_df = article_tags_schema(article_tags_df)
# 以下のような記述も可能です。
# article_tags_df = article_tags_schema.validate(article_tags_df)
それではarticles_df、tags_df、article_tags_dfをマージします。
renamed_tags_df = tags_df.rename(columns={"id": "tag_id", "name": "tag_name"})
tmp_df = pd.merge(articles_df, article_tags_df[["article_id", "tag_id"]], left_on="id", right_on="article_id")
merged_articles_df = pd.merge(tmp_df, renamed_tags_df, left_on="tag_id", right_on="tag_id")
マージしたデータフレームのバリデーションは以下のような記述により実現します。
merged_articles_schema = pa.DataFrameSchema({
"id": pa.Column(int, checks=pa.Check.ge(1)),
"created_at": pa.Column(str, checks=[
pa.Check(
lambda g: is_date(g) == True,
element_wise=True
)
]
),
"likes_count": pa.Column(int, checks=pa.Check.ge(0)),
"comments_count": pa.Column(int, checks=pa.Check.ge(0)),
"url": pa.Column(str, checks=[
pa.Check(
lambda g: len(g) == 9,
element_wise=True)
]
),
"users": pa.Column(str, checks=[
pa.Check(
lambda g: (len(g) == 9) or (g == UNKNOWN_USER_VALUE),
element_wise=True)
]
),
"page_views_count": pa.Column(int, checks=pa.Check.ge(0)),
"tag_id": pa.Column(int, checks=pa.Check.ge(1)),
"tag_name": pa.Column(str)
},
strict=True
)
merged_articles_df = merged_articles_schema(merged_articles_df)
# 以下のような記述も可能です。
# merged_articles_df = merged_articles_schema.validate(merged_articles_df)
merged_articles_df = merged_articles_df.sort_values(["article_id", "tag_id"])
merged_articles_df.head()
article_id | created_at | likes_count | comments_count | url | users | page_views_count | tag_id | tag_name | |
---|---|---|---|---|---|---|---|---|---|
17 | 1 | 2017-04-06 | 332 | 7 | 771q8q4kk | cb29wVZ6y | 7469 | 7 | 初心者 |
38 | 1 | 2017-04-06 | 332 | 7 | 771q8q4kk | cb29wVZ6y | 7469 | 121 | Elasticsearch |
0 | 1 | 2017-04-06 | 332 | 7 | 771q8q4kk | cb29wVZ6y | 7469 | 144 | Perl |
57 | 2 | 2012-04-18 | 865 | 8 | t8sm68r4a | K1UwSLNXd | 9862 | 53 | PostgreSQL |
97 | 2 | 2012-04-18 | 865 | 8 | t8sm68r4a | K1UwSLNXd | 9862 | 82 | 新人プログラマ応援 |
このような記述により、panderaによるデータフレームのバリデーションが実現されます。
参考までにバリデーションを満たさない場合についても確認します。idが1以上というチェック項目を設定したarticles_dfのバリデーションに対してidが0という不正なレコードを含むデータフレームに対してバリデーションを実施します。
import copy
invalid_df = copy.deepcopy(articles_df)
# 不正な値(0)を格納
invalid_df["id"][0] = 0
invalid_df = articles_schema.validate(invalid_df)
上記の実行は失敗します。そして、以下のように出力されます。
(省略)
SchemaError: <Schema Column(name=id, type=DataType(int64))> failed element-wise validator 0:
<Check greater_than_or_equal_to: greater_than_or_equal_to(1)>
failure cases:
index failure_case
0 0 0
SchemaError
に加え、不正なレコード(failure cases
)、不正である理由が表示されます。
SchemaModelを利用する方法
articles_dfのバリデーションは以下のような記述により実現します。
import pandera as pa
from pandera.typing import Series
def is_date(date_str):
"""
「日付として解釈が可能であること」を確認する関数。
"""
date_format = "%Y-%m-%d"
try:
return bool(datetime.datetime.strptime(date_str, date_format))
except ValueError:
return False
UNKNOWN_USER_VALUE = "#NAME?"
class ArticlesSchema(pa.SchemaModel):
id: Series[int] = pa.Field(ge=1)
created_at: Series[str]
likes_count: Series[int] = pa.Field(ge=0)
comments_count: Series[int] = pa.Field(ge=0)
url: Series[str]
users: Series[str]
page_views_count: Series[int] = pa.Field(ge=0)
@pa.check("created_at")
def check_created_at(cls, series: Series[str]) -> Series[bool]:
return series.map(is_date)
@pa.check("url")
def check_url(cls, series: Series[str]) -> Series[bool]:
return series.map(len) == 9
@pa.check("url")
def check_urls(cls, series: Series[str]) -> bool:
return len(set(series)) == len(list(series))
@pa.check("users")
def check_users(cls, series: Series[str]) -> Series[bool]:
return (series.map(len) == 9) + (series.str.match(UNKNOWN_USER_VALUE))
class Config:
name = "BaseSchema"
strict = True
articles_df = ArticlesSchema.validate(articles_df)
DataFrameSchemaで指定していたstrict=True
(定義した列名のみを持っていることをチェックする設定)について、SchemaModelの場合は、以下のように指定しています。
class Config:
name = "BaseSchema"
strict = True
tags_dfのバリデーションは以下のような記述により実現します。
class TagsSchema(pa.SchemaModel):
id: Series[int] = pa.Field(ge=1)
name: Series[str]
class Config:
name = "BaseSchema"
strict = True
tags_df = TagsSchema.validate(tags_df)
article_tags_dfのバリデーションは以下のような記述により実現します。
class ArticleTagsSchema(pa.SchemaModel):
id: Series[int] = pa.Field(ge=1)
article_id: Series[int] = pa.Field(ge=1)
tag_id: Series[int] = pa.Field(ge=1)
class Config:
name = "BaseSchema"
strict = True
article_tags_df = ArticleTagsSchema.validate(article_tags_df)
それではDataFrameSchemaの方法と同様にarticles_df、tags_df、article_tags_dfをマージします。
renamed_tags_df = tags_df.rename(columns={"id": "tag_id", "name": "tag_name"})
tmp_df = pd.merge(articles_df, article_tags_df[["article_id", "tag_id"]], left_on="id", right_on="article_id")
merged_articles_df = pd.merge(tmp_df, renamed_tags_df, left_on="tag_id", right_on="tag_id")
マージしたデータフレームのバリデーションは以下のような記述により実現します。特にarticles_dfのバリデーションのために作成したArticlesSchemaを継承することが可能です。今回はArticlesSchemaを継承して、tag_id、tag_nameを追加したMergedArticlesSchemaを作成します。
class MergedArticlesSchema(ArticlesSchema):
tag_id: Series[int] = pa.Field(ge=1)
tag_name: Series[str]
class Config:
name = "BaseSchema"
strict = True
merged_articles_df = MergedArticlesSchema.validate(merged_articles_df)
このような記述により、panderaによるデータフレームのバリデーションが実現されます。
また、以下のようにSchemaModelからDataFrameSchemaを生成することも可能です。
merged_articles_schema = MergedArticlesSchema.to_schema()
merged_articles_df = merged_articles_schema(merged_articles_df)
また、SchemaModelを定義しておくと、データフレームを扱う関数を定義した場合、以下のようにデコレーターと型ヒントによるバリデーションも可能です。
from pandera.typing import DataFrame
@pa.check_types
def merge_articles_and_tags(articles_df: DataFrame[ArticlesSchema], tags_df: DataFrame[TagsSchema], article_tags_df: DataFrame[ArticleTagsSchema]) -> DataFrame[MergedArticlesSchema]:
renamed_tags_df = tags_df.rename(columns={"id": "tag_id", "name": "tag_name"})
tmp_df = pd.merge(articles_df, article_tags_df[["article_id", "tag_id"]], left_on="id", right_on="article_id")
merged_articles_df = pd.merge(tmp_df, renamed_tags_df, left_on="tag_id", right_on="tag_id")
return merged_articles_df.sort_values(["article_id", "tag_id"])
merged_articles_df = merge_articles_and_tags(articles_df=articles_df, tags_df=tags_df, article_tags_df=article_tags_df)
まとめ
データフレームのバリデーションを実現するためのライブラリであるpanderaを紹介しました。そして、ダミーデータを使ってpanderaでバリデーションを実現する方法を2つ紹介しました。
- DataFrameSchema(
pandera.schemas.DataFrameSchema
)を利用する方法 - SchemaModel(
pandera.model.SchemaModel
)を利用する方法
これらの方法でpanderaを利用することで、ドメインの観点で妥当なデータが格納されていると信頼できるデータフレームの扱いが期待されます。
(追記)
以下の記事にリンクを掲載していただきました!ありがとうございます!
- 【毎日自動更新】データに関する記事を書こう! LGTMランキング!(2022/04/18 10:00)
- 【Python】Qiita 週間 LGTM 数ランキング【自動更新】(2022/04/18 20:00)
- 【Python】Qiita デイリー LGTM 数ランキング【自動更新】(2022/04/18 20:00)
- Qiita デイリー LGTM 数ランキング【自動更新】(2022/04/19 03:01)
付録
delika Python Clientからダウンロード&バリデーション(+型の変換)
本記事ではダウンロードしたデータを利用しましたが、delika Python Clientを使ってPythonのみ(結果的に条件付き)でデータ取得からバリデーションまでを実施してみます。また新たに得られた知見もあるためそちらも紹介します。
上記URLのドキュメントを参考に以下により、Pythonのdelika client(およびdelika内のPandas)のインストールができます。
!pip install --extra-index-url=https://docs.delika.io/python/ delika
!pip install --extra-index-url=https://docs.delika.io/python/ delika[DataFrame]
以下により認証関係の処理を実施します。
import delika
token = delika.new_token()
token.save()
client = delika.new_client(token)
Use a browser to open the page https://api.delika.io/v1/auth and paste the result JSON in the browser after you sign in.
Input JSON:
表示されるURLにアクセス(ポチッ👇)します。画面の指示に従い、ログインなどを完了させます。そして、表示される画面の下部にあるJSONをコピーして、Input JSON
に入力します。このJSONおよびアクセストークンなどに有効期限があるため注意が必要です。ここがPythonのみでできるかと思いましたが「結果的に条件付き」だった点です。
以下によりデータフレームとしてダウンロードします。
import delika.pandas
tags_df = delika.pandas.read_delika_data(account_name="qiita_delika_article_campaign", dataset_name="QiitadelikaDummy", data_name="article_tags.csv", client= client)
tags_df.head()
id | name | |
---|---|---|
0 | 1.0 | Python |
1 | 2.0 | JavaScript |
2 | 3.0 | Ruby |
3 | 4.0 | Rails |
4 | 5.0 | AWS |
ダウンロードしたデータフレームを確認したところ、idがfloat型(もしくはfloat64など関連する型)のようです。
以下によりSchemaModelによりバリデーションを試みてみます。
import pandera as pa
from pandera.typing import Series
class TagsSchema(pa.SchemaModel):
id: Series[int] = pa.Field(ge=1)
name: Series[str]
class Config:
name = "BaseSchema"
strict = True
tags_df = TagsSchema.validate(tags_df)
SchemaError
が生じました。意図している型に対してダウンロードしてきたデータフレームの型が異なるようです。
(省略)
SchemaError: expected series 'id' to have type int64, got float64
後述する方法で上記のSchemaError
は解決するのですが、nameについてもSchemaError
が生じました。こちらも同様に意図している型に対してダウンロードしてきたデータフレームの型が異なるようです。
(省略)
SchemaError: expected series 'name' to have type str, got string
これらは2点のSchemaError
は以下のような解決が可能です。pa.Field
に対してcoerce=True
を指定します。これにより変換できそうな型であれば変換されます。
import pandera as pa
from pandera.typing import Series
class TagsSchema(pa.SchemaModel):
id: Series[int] = pa.Field(ge=1, coerce=True)
name: Series[str] = pa.Field(coerce=True)
class Config:
name = "BaseSchema"
strict = True
tags_df = TagsSchema.validate(tags_df)
tags_df.head()
id | name | |
---|---|---|
0 | 1 | Python |
1 | 2 | JavaScript |
2 | 3 | Ruby |
3 | 4 | Rails |
4 | 5 | AWS |
意図通りの型に変換されているようです。coerce=True
により上記の問題は解決するのですが、前段の処理が意図通りでない可能性もあるため注意が必要です。
このようにdelikaに限らず外部(AWSなど)からデータをダウンロードしてくる場面がしばしばあるかと思います。例えば以下のように、バリデーションが通っていない怪しい状態のデータフレームを一時的にも保持せず、SchemaModel.validate()
の引数に直接入れてしまうのもいいでしょう。
import pandera as pa
from pandera.typing import Series
class TagsSchema(pa.SchemaModel):
id: Series[int] = pa.Field(ge=1, coerce=True)
name: Series[str] = pa.Field(coerce=True)
class Config:
name = "BaseSchema"
strict = True
tags_df = TagsSchema.validate(
delika.pandas.read_delika_data(
account_name="qiita_delika_article_campaign",
dataset_name="QiitadelikaDummy",
data_name="tags.csv",
client=client
)
)