0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Scikit-learn v1.4で強化された新機能 謎のmetadata routingを探る

Last updated at Posted at 2024-02-17

機械学習やってる人なら必ず使うsklearn、リリースハイライトの記事はとても面白いので毎回楽しみに読んでいます。
v1.3でmetadata routingという仕組みが導入されました。出てきた頃はサンプルコードも少なく、日本語の情報も少ないのでよくわからんが大体こういうものかなと浅い理解で済ませていたのですが、よく調べると自分が理解していたものと全然違いました。公式ドキュメントのサンプルコードはやや難しいので、極限まで簡単な例で記事にしてみます。

import sklearn
print(sklearn.__version__)

# 1.4.1.post1

まず適当なデータを用意して線形回帰を行うと、こんな感じ。

from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
x = [[1], [2], [3], [4], [5]]
y = [1.1, 1.9, 3.3, 3.7, 5]

lr = LinearRegression()
lr.fit(x, y)
pred = lr.predict(x)
plt.scatter(x, y)
plt.plot(x, pred)
plt.show()

image.png

で、ここでsample_weightを与えるとする。
このsample_weightはサンプルごとの重要度のようなもので、ウエイトの高いデータに対してより当てはまるようになる。

以下のように3番目のデータに高いウエイトを与えてやるとその点に強く当てはまる。

lr = LinearRegression()
lr.fit(x, y, sample_weight=[1, 1, 100, 1, 1])
pred = lr.predict(x)
plt.scatter(x, y)
plt.plot(x, pred)
plt.show()

image.png

この時勘違いしやすいのはこのsample_weightは回帰のハイパーパラメーターではないということ。パラメーターはLinearRegression()でインスタンスを作るときに与えるので、fit時に与えるこれはデータである。
データではあるが、説明変数でも目的変数でもない、回帰のアルゴリズムで使っているわけではない(回帰する前に前処理で使ってる https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/linear_model/_base.py#L582-#L607 )データに対するデータということになるので、メタデータと呼んでいるようだ。

で、v1.3で導入されv1.4で強化されたmetadata routingを使うと以下のように、この学習機ではfitの時にsample_weightのデータを使いますよ、みたいな宣言ができるようになった。
この例だと何の意味もないが、とりあえず以下のような書き方が出来る。

sklearn.set_config(enable_metadata_routing=True)  # メタデータルーティングを有効化

lr = LinearRegression()
lr.set_fit_request(sample_weight=True)  # 学習機側でメタデータの利用を宣言

lr.fit(x, y, sample_weight=[1, 1, 100, 1, 1])

pred = lr.predict(x)
plt.scatter(x, y)
plt.plot(x, pred)
plt.show()  # 上と同じ図が表示

じゃあ実際どういう時に使うのかというと、PipelineとかGridsearchみたいなfitやscoreといったメソッドを連鎖させて実行させるプロセスで使う。
例えば以下のように欠損値補完してから推定するパイプライン学習機を構築した場合、fit時にメタデータを素直に渡すとエラーが出る。
(とはいえ、この渡し方じゃダメだからこう書けというかなり親切なエラーメッセージになっている)

from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
sklearn.set_config(enable_metadata_routing=False)  # 一旦falseにしておく

x = [[1], [2], [3], [4], [5], [None]]
y = [1.1, 1.9, 3.3, 3.7, 5, 3.0]

pipe = make_pipeline(SimpleImputer(), LinearRegression())

pipe.fit(x, y, sample_weight=[1, 1, 100, 1, 1, 100])
pipe.predict(x)

出力

ValueError                                Traceback (most recent call last)
Cell In[13], line 1
----> 1 pipe.fit(x, y, sample_weight=[1, 1, 100, 1, 1, 100])
      2 pipe.predict(x)

ValueError: Pipeline.fit does not accept the sample_weight parameter. You can pass parameters to specific steps of your pipeline using the stepname__parameter format, e.g. `Pipeline.fit(X, y, logisticregression__sample_weight=sample_weight)`.

というわけで、メタデータルーティングを有効化し、fit時にsample_weightが必要な学習機はset_fit_requestを書くと、pipeline側がよしなにデータをルーティングしてくれるようになったよ、ということだ。

sklearn.set_config(enable_metadata_routing=True)

pipe = make_pipeline(
    SimpleImputer(), 
    LinearRegression().set_fit_request(sample_weight=True))

pipe.fit(x, y, sample_weight=[1, 1, 100, 1, 1, 100])
pred = pipe.predict(x)
plt.scatter(x,y)
plt.plot(x,pred)
plt.show()  # 図が表示される

繰り返しにはなるが、インスタンス生成時に必要な引数をリクエストするわけではないので、ハイパーパラメーターの取り回しができるわけではない。現時点ではほぼsample_weightをルーティングする専用機能のようです。
私は工学系で説明変数に含まないそれよりもメタなデータを持つデータってものに遭遇したことが無いのでいまいち使い道をイメージできないのですが、Kaggleに取り組んでる人や金融、社会学ドメインのデータ分析される方は使ったりしてるのでしょうか?

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?