0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Dataset Cartography】DNN の分類モデルが誤分類しやすいサンプルの特定

Last updated at Posted at 2024-09-10

はじめに

  • Dataset Cartography: Mapping and Diagnosing Datasets with Training Dynamics (論文リンク) の要点をまとめたメモです.
  • 自分用の簡単なまとめであり,説明はほとんどありません.
  • 内容は時間があるときに補充します.

概要

Dataset Cartography は,DNN の分類モデルが誤分類しやすいサンプルを特定するための手法です.誤分類の原因には,以下の二つの可能性があります:

  1. サンプル自体が難しい
  2. ラベルに誤りがある

各サンプルは,以下の 2 つの指標に基づいてマッピングされます:

  • Confidence(確信度)
  • Variability(ばらつき)

これらの指標により,サンプルは以下の 3 つに大別されます:

  • easy-to-learn(学習しやすい)
  • ambiguous(曖昧)
  • hard-to-learn(学習しにくい)

マッピングするには,学習時に各エポックでの各サンプルの logit を保存するだけで十分です.

Dataset Cartographyの図解

(図は論文より)

Confidence とは

Confidence は,モデルが特定のサンプルの「真のラベル」をどれだけ高い確率で予測しているかを示す指標です.各サンプルの正解ラベルに対する確率の平均をエポックごとに計算します.

式は以下の通りです:

\hat{\mu}_i = \frac{1}{E} \sum_{e=1}^{E} p_{\theta^{(e)}}(y_i^* | x_i)
  • $\hat{\mu}_i$: $i$番目のサンプルの Confidence
  • $x_i$: $i$番目のサンプルのデータ
  • $y_i^*$: $i$番目のサンプルの正解ラベル
  • $E$: エポック数
  • $\theta^{(e)}$: エポック $e$ でのモデルのパラメータ
  • $p_{\theta^{(e)}}(y_i^* | x_i)$: エポック $e$ でのモデルの予測確率

式より,以下のことが分かります.

  • Confidence が高い = モデルが確信を持って予測している(容易)
  • Confidence が低い = モデルが自信を持てていない(難しい)

Variability とは

Variability は,エポックごとの予測確率のばらつきを示す指標です.これにより,モデルが一貫して同じ予測をしているか,エポックごとに異なる予測をしているかが分かります.

式は次の通りです:

\hat{\sigma}_i = \sqrt{\frac{1}{E} \sum_{e=1}^{E} \left( p_{\theta^{(e)}}(y_i^* | x_i) - \hat{\mu}_i \right)^2}
  • $\hat{\sigma}_i$: $i$番目のサンプルの Variability
  • $\hat{\mu}_i$: $i$番目のサンプルの Confidence
  • $p_{\theta^{(e)}}(y_i^* | x_i)$: エポック $e$ でのモデルの予測確率

式より,以下のことが分かります.

  • Variability が低い = モデルが一貫した予測を行っている(安定)
  • Variability が高い = モデルが不安定な予測をしている(不確実)

実装のイメージ

  • 学習時
    • そのままでは使用できませんが,雰囲気を掴むための参考になれば幸いです
# 各エポック開始時の処理 <-- 保存用の numpy 配列の準備
def on_train_epoch_start(self) -> None:
    total_batches = len(self.trainer.train_dataloader)
    batch_size = self.trainer.train_dataloader.batch_size

    self.train_ids = np.zeros((total_batches, batch_size), dtype=np.int32)
    self.train_logits = np.zeros((total_batches, batch_size, self.model.output_dim), dtype=np.float32)
    self.train_golds = np.zeros((total_batches, batch_size), dtype=np.int32)

# ミニバッチ単位の推論 <-- logit の保存
def training_step(self, batch, batch_idx) -> torch.Tensor:
    x, label, identifier = batch
    loss, output = self.forward(x, label)

    self.train_ids[batch_idx] = identifier.detach().cpu().numpy()
    self.train_logits[batch_idx] = output.detach().cpu().numpy()
    self.train_golds[batch_idx] = label.detach().cpu().numpy()

    return loss

# 各エポック終了時の処理 <-- logit をファイルへ保存
def on_train_epoch_end(self) -> None:
    td_dir = os.path.join(self.trainer.logger.log_dir, f"training_dynamics")
    if not os.path.exists(td_dir):
        os.makedirs(td_dir)

    epoch = self.trainer.current_epoch
    epoch_file_path = os.path.join(td_dir, f"dynamics_epoch_{epoch}.jsonl")

    td_df = pd.DataFrame(
        {
            "guid": self.train_ids.reshape(-1),
            f"logits_epoch_{epoch}": self.train_logits.reshape(-1, self.model.output_dim).tolist(),
            "gold": self.train_golds.reshape(-1),
        }
    )
    td_df.to_json(epoch_file_path, lines=True, orient="records")
  • 可視化時
    • それなりに長くなるため,Github (リンク) を参照ください.

論文の著者による Github (リンク) もあるため,気になる方はそちらも参照ください.

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?