1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

OOFの予測値でConfident Learning (cleanlab) を使おう!

Last updated at Posted at 2021-10-12

はじめに

Confident Learningという手法を使うことで、分類タスクにおける、データセットの中の間違ったラベル(noisy label)のサンプルを検出することができる。詳しい解説が既にあり、Confident Learningを実行するライブラリにcleanlabというのものあるので、ここでは機械学習で後付でcleanlabを使う小技の紹介をする。

cleanlab

cleanlabはConfident Learningを実行できるライブラリで、scikit-learn like APIのモデルクラスをラップして使うことができる。

from sklearn.base import BaseEstimator

class YourFavoriteModel(BaseEstimator): # Inherits sklearn base classifier
    def __init__(self, ):
        pass
    def fit(self, X, y, sample_weight=None):
        pass
    def predict(self, X):
        pass
    def predict_proba(self, X):
        pass
    def score(self, X, y, sample_weight=None):
        pass

# Now you can use your model with `cleanlab`. Here's one example:
from cleanlab.classification import LearningWithNoisyLabels

lnl = LearningWithNoisyLabels(clf=YourFavoriteModel())
lnl.fit(train_data, train_labels_with_errors)

ただ、これをPyTorchで実装するのはメンドクサイ。。。こちらにサンプルコードがあるが、既に学習を回している人が後からサンプルのようにcleanlab用に書き換えるのはツライ。

しかし! cross validationをしている方ならOOF(Out Of Fold)でtrainデータの予測値(正確には各ラベルの確率)を出力していると思われる。このOOFの値y_oofさえあれば、cleanlabのデータクリーニング部分は実行できるので、小技として紹介しよう。

train data に対するラベルのOOF予測確率y_oofは得られているとする。まず次のクラスを作る。

from sklearn.base import BaseEstimator

class CleanModel(BaseEstimator): # Inherits sklearn base classifier
    def __init__(self):
        pass
    def fit(self, X, y, sample_weight = None):
        pass
    def predict(self, X):
        pass
    def predict_proba(self, X):
        pass
    def score(self, X, y, sample_weight = None):
        pass

上で作ったクラスをラップして、fitを実行する。

# 確率に規格化できてない場合は規格化する
psx = softmax(y_oof, axis=1)

# モデルをラップする
clf = LearningWithNoisyLabels(
    clf=CleanModel(),
    prune_method='prune_by_class' # ここは自由に選ぶ
)
clf.fit(
    X=np.arange(len(train_df)),  # データのid
    s=train_df[y].values,  # クラスラベルの ndarray, shape (n_samples, )
    psx=psx,
    noise_matrix=None,
    inverse_noise_matrix=None,
)

noise_mask変数にノイズラベルと推測されるサンプルのインデックスが格納されています。

cleaned_train_df = train_df.iloc[~clf.noise_mask]

あとはこの新しいデータで再学習を実行すればOKです。再学習は自身の既存コードで実行すればよいです。

1
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?