LoginSignup
1
1

More than 1 year has passed since last update.

Rust研究:Precision、Recall、F値

Last updated at Posted at 2023-04-07

二値分類と多値分類の評価指標の Precision(適合率)、Recall(再現率)、F-measure(F値)の定義を再確認し、その算出を Rust で実装してみましたのでメモを残します。

お題

カメラの前を横切った動物を画像処理で検出し、それが犬・猫・豚のいずれだったかを判定するシステムがあるとします。
(このシステムの判定対象は犬・猫・豚の三つに限定され、他の動物は判定対象にならないことが保証されているものとします)

その「推測」と「実際」を観測した結果を表 1 に示します。

表1. カメラの前を横切った動物の推測と実際

観測時刻 推測した動物 実際の動物
06:34
06:54
08:36
09:22
11:45
12:08
14:44
19:17
21:30
22:49

この観測結果より、この判定システムの推測性能を評価してみたいと思います。
(実際の科学的な検証では、観測方法や標本数の適切性などから議論する必要があると思いますが、ここではそこはあえて考えないことにします)

二値分類で考えてみる

まずは評価のとっかかりとして「」の推測にしぼって考えてみます。ここでは猫と豚の推測は考えません。

カメラの前を横切ったのが犬なら「真」、横切ったのが犬ではないなら「偽」となります。

先ほどの観測結果の「推測」と「実際」はこの真か偽のいずれかに分類することができます。

このようにある事項を二つのグループに分類する作業は「二値分類」 (Binary Classification)と呼ぶそうです。

先ほどの表 1 を「カメラの前を横切った動物が犬かどうか」で書き換えると次のようになります。

表 2. カメラの前を横切った動物の推測と実際 (T=真、F=偽)

観測時刻 推測した動物 実際の動物 推測:犬 実際:犬
06:34 T T
06:54 F F
08:36 F F
09:22 T T
11:45 T F
12:08 T T
14:44 F T
19:17 F F
21:30 F F
22:49 T T

ここでそれぞれの観測結果が、True Positive (TP)、False Positive (FP)、False Negative (FN)、True Negative (TN) のいずれに該当するのかを整理します。

《参考》

推測 実際 TP/FP/FN/TN
T T True Positive(TP)
T F False Positive(FP)
F T False Negative(FN)
F F True Negative(TN)

先ほどの表にTP/FP/FN/TNを追加し、それぞれの件数を数えてみます。

表 3. 通りを犬が横切ることの予測と実際 (T=真、F=偽) および TP/FP/FN/TN の件数

観測時刻 推測した動物 実際の動物 推測:犬 実際:犬 TP FP FN TN
06:34 T T o
06:54 F F o
08:36 F F o
09:22 T T o
11:45 T F o
12:08 T T o
14:44 F T o
19:17 F F o
21:30 F F o
22:49 T T o
4 1 1 4

表の最下行の 4 つの数字は TP/FP/FN/TN のそれぞれの合計数です。

この数字を使って Accuracy/Precision/Recall/F1 を計算してみました:

Accuracy = \frac{TP+TN}{TP+FP+FN+TN} = \frac{8}{10} = 0.8 \\
Precision = \frac{TP}{TP+FP} = \frac{4}{5} = 0.8 \\
Recall = \frac{TP}{TP+FN} = \frac{4}{5} = 0.8 \\
F1 = \frac{2}{\frac{1}{Precision}+\frac{1}{Recall}} = \frac{8}{10} = 0.8 \\

このシステムが「カメラの前を横切る動物が犬であることを推測する性能」は上のように評価されました。

実装例

BinaryClassificationCounter は二値分類の事例を数えるカウンターの構造体です。
このカウンターにはその時点までの事例で出現した TP/FP/FN/TN の件数を保持するメンバー変数を持たせます。

pub const POS:bool = true; // POSITIVE (陽性)
pub const NEG:bool = false; // NEGATIVE (陰性)

use derive_getters::Getters;

// BinaryClassificationCounter は二値分類用の事例カウンター
#[derive(Debug, Getters)]
pub struct BinaryClassificationCounter {
    num_tp:i32, // True  Positive (真陽性 / 予測=陽性・実際=陽性) の件数
    num_fp:i32, // False Positive (偽陽性 / 予測=陽性・実際=陰性) の件数
    num_fn:i32, // False Negative (偽陰性 / 予測=陰性・実際=陽性) の件数
    num_tn:i32, // True  Negative (真陽性 / 予測=陰性・実際=陰性) の件数
}

impl BinaryClassificationCounter {
    pub fn new () -> Self {
        Self {
            num_tp:0,
            num_fp:0,
            num_fn:0,
            num_tn:0,
        }
    }

メソッド add は事例として与えられた推測と実際の組み合わせで TP/FP/FN/TN のいずれに該当するかを決定しその件数をカウントしていきます。

    // add は事例をカウントする。推測(predicted)と実際(actual)
    pub fn add (&mut self, predicted:bool, actual:bool) {
        match (predicted, actual) {
            (POS, POS) => { // 真陽性の事例
                self.num_tp += 1;
            },
            (POS, NEG) => { // 偽陽性の事例
                self.num_fp += 1;
            },
            (NEG, POS) => { // 偽陰性の事例
                self.num_fn += 1;
            },
            (NEG, NEG) => { // 真陰性の事例
                self.num_tn += 1;
            },
        }
    }

残りは評価指標 Accuracy / Precision / Recall / F値を計算する部分です。これらは場合によっては算出することができないことがあることから、戻り値は Option<f64> としました。

    // accuracy は「正確度」を算出する
    pub fn accuracy (&self) -> Option<f64> {
        let total = self.total();
        if total < 1 {
            return None
        }

        // 正解の件数 / 全件数
        Some(
            f64::from(self.num_tp + self.num_tn) / f64::from(total)
        )
    }

    // precision は「適合率」を算出する
    pub fn precision (&self) -> Option<f64> {
        if self.num_tp + self.num_fp < 1 {
            return None
        }

        // 陽性と予測された事例のうち実際に陽性であるものの割合
        // =真陽性の件数 / (真陽性の件数+偽陽性の件数)
        Some(
            f64::from(self.num_tp) / f64::from(self.num_tp + self.num_fp)
        )
    }

    // recall は「再現率」を算出する
    pub fn recall (&self) -> Option<f64> {
        if self.num_tp + self.num_fn < 1 {
            return None
        }
        // 実際に陽性である事例のうち陽性と予測されたものの割合
        // =真陽性の件数 / (真陽性の件数+偽陰性の件数)
        Some(
            f64::from(self.num_tp) / f64::from(self.num_tp + self.num_fn)
        )
    }

    // f1 は F-measure を算出する
    pub fn f1 (&self) -> Option<f64> {
        if let Some(precision) = self.precision() {
            if let Some(recall) = self.recall() {
                if precision + recall == 0.0 {
                    return None;
                }
                // 適合率と再現率の調和平均
                return Some(
                    2.0 * precision * recall / (precision + recall)
                )
            }
        }

        None
    }

    // total はすべての事例の件数を算出する
    fn total (&self) -> i32 {
        self.num_tp + self.num_fp + self.num_fn + self.num_tn
    }

    // print は評価指標を表示する
    pub fn print(&self) {
        print_option_value("Accuracy :", self.accuracy());
        print_option_value("Precision:", self.precision());
        print_option_value("Recall   :", self.recall());
        print_option_value("F-measure:", self.f1());
    }
}

fn print_option_value(msg:&str, option_value:Option<f64>) {
    print!("{} ", msg);
    if let Some(p) = option_value {
        println!("{:.2}", p);
    } else {
        println!("NaN");
    }
}

実行例

先ほどの観測結果を用いた実行例です。

fn process1() {
    let mut b = BinaryClassificationCounter::new();

    // 以下、事例を追加
    b.add(POS, POS);
    b.add(NEG, NEG);
    b.add(NEG, NEG);
    b.add(POS, POS);
    b.add(POS, NEG);
    b.add(POS, POS);
    b.add(NEG, POS);
    b.add(NEG, NEG);
    b.add(NEG, NEG);
    b.add(POS, POS);

    // 評価指標の表示
    b.print();
}

// 実行結果
Accuracy : 0.80
Precision: 0.80
Recall   : 0.80
F-measure: 0.80

多値分類で考える

先ほどは「犬」だけにしぼって考えてみました。
今度は犬だけでなく猫や豚への分類についても考えます。
このように3つ以上のグループに分類する作業を「多値分類」(Multi Classification) と呼ぶそうです。

表2 の一行目の観測結果をみてみます。

観測時刻 推測した動物 実際の動物 推測:犬 実際:犬
06:34 T T

犬と推測して実際も犬だった、という観測結果です。

このデータを猫や豚の推測/実際も考慮すると次のようになります。

観測時刻 推測した動物 実際の動物 推測:犬 実際:犬 推測:猫 実際:猫 推測:豚 実際:豚
06:34 T T F F F F

表の残りを埋めてみましょう。

表 4. 犬・猫・豚の推測と実際

観測時刻 推測した動物 実際の動物 推測:犬 実際:犬 推測:猫 実際:猫 推測:豚 実際:豚
06:34 T T F F F F
06:54 F F T T F F
08:36 F F F T T F
09:22 T T F F F F
11:45 T F F T F F
12:08 T T F F F F
14:44 F T T F F F
19:17 F F T T F F
21:30 F F F F T T
22:49 T T F F F F

詳細は省略しますが、犬・豚・猫の TP/FP/FN/TN の件数を数えてみました。

表 5. 犬・猫・豚の TP/FP/FN/TN の件数

TP 4 2 1
FP 1 1 1
FN 1 2 0
TN 4 5 8

猫の推測に関する Accuracy / Precision / Recall / F値を計算してみました。

Accuracy = \frac{TP+TN}{TP+FP+FN+TN} = \frac{7}{10} = 0.7 \\
Precision = \frac{TP}{TP+FP} = \frac{2}{3} = 0.67 \\
Recall = \frac{TP}{TP+FN} = \frac{2}{4} = 0.5 \\
F1 = \frac{2}{\frac{1}{Precision}+\frac{1}{Recall}} = 0.57 \\

同様に豚の推測に関する Accuracy / Precision / Recall / F値を計算してみました。

Accuracy = \frac{TP+TN}{TP+FP+FN+TN} = \frac{9}{10} = 0.9 \\
Precision = \frac{TP}{TP+FP} = \frac{1}{2} = 0.5 \\
Recall = \frac{TP}{TP+FN} = \frac{1}{1} = 1.0 \\
F1 = \frac{2}{\frac{1}{Precision}+\frac{1}{Recall}} = 0.67 \\

全体的な評価にはマクロ平均 (Macro mean) とマイクロ平均 (Micro mean) の二つの方法があります。

  • マクロ平均はそれぞれのクラス(今回は犬・猫・豚の三つ)の Precision と Recall の平均をとる。
  • マイクロ平均はすべてのクラスの TP/FP/FN/TN の件数をそれぞれ足し合わせて Precision と Recall を求める。

手計算が面倒になってきたのでここからさきの計算はプログラムにやらせます。

実装例

MultiClassificationCounter は多値分類用の事例カウンターの構造体です。
メンバ変数に想定するクラス数と各クラスの BinaryClassificationCounter のインスタンスのベクタを保持します。

// MultiClassificationCounter は多クラス分類用の事例カウンター
pub struct MultiClassificationCounter {
    num_class: usize, // 想定するクラス数
    counters: Vec<BinaryClassificationCounter>, // 各クラスのカウンター
}

impl MultiClassificationCounter {
    pub fn new (num_class: usize) -> Self {
        let mut counters: Vec<BinaryClassificationCounter> = vec![];
        let mut i = 0;
        while i < num_class {
            counters.push(BinaryClassificationCounter::new());
            i += 1;
        }

        Self {
            num_class: num_class,
            counters: counters,
        }
    }

メソッド add は事例をカウントします。

BinaryClassificationCounter の add では推測と実際の真偽値ペアを事例として与えていましたが、ここでは推測したクラスIDと実際のクラスIDのペアを事例として与えます。

推測したクラスIDと実際のクラスIDに基づいて、それぞれのクラスのカウンタに事例を与えていきます。
詳細は省きますが、先ほどの表 4. における犬・猫・豚の推測と実際を抽出していく作業に対応しています。

    // add は事例をカウントする。推測したクラスIDと実際のクラスID 
    pub fn add (&mut self, predicted_class_id:usize, actual_class_id:usize) {
        // クラスIDのチェック
        if predicted_class_id >= self.num_class || actual_class_id >= self.num_class {
            return
        }
        let preds = self.conv(predicted_class_id);
        let acts = self.conv(actual_class_id);
        for (class_id, (predicted, actual)) in preds.iter().zip(acts.iter()).enumerate() {
            self.add_per_class_id(class_id, *predicted, *actual);
        }
    }

    // add_per_class_id はクラス単位で事例をカウントする。 
    fn add_per_class_id (&mut self, class_id:usize, predicted:bool, actual:bool) {
        self.counters[class_id].add(predicted, actual)
    }

    // conv は class_id を One-Hot なベクタに変換する
    // 説明しにくいので以下に例を示す
    // ・ class_id = 1, num_class = 4 のとき返るベクタ:
    //    [false, true, false, false]
    // ・ class_id = 2, num_class = 3 のとき返るベクタ:
    //    [false, false, true]
    fn conv(&self, class_id:usize) -> Vec<bool> {
        let mut r:Vec<bool> = vec![];
        let mut i:usize = 0;
        while i<self.num_class {
            r.push(if i == class_id {
                POS
            } else {
                NEG
            });
            i += 1;
        }
        r
    }

Accuracy はその定義に従ってすべてのクラスの TP/FP/FN/TN の合計件数から算出します。

    pub fn accuracy(&self) -> Option<f64> {
        let total = self.total();
        if total < 1 {
            return None
        }

        Some (
            f64::from(self.total_tp() + self.total_tn()) / 
                f64::from(total)
        )
    }

    pub fn total_tp (&self) -> i32 {
        self.counters.iter().map(|m|m.num_tp()).sum()
    }
    pub fn total_fp (&self) -> i32 {
        self.counters.iter().map(|m|m.num_fp()).sum()
    }
    pub fn total_fn (&self) -> i32 {
        self.counters.iter().map(|m|m.num_fn()).sum()
    }
    pub fn total_tn (&self) -> i32 {
        self.counters.iter().map(|m|m.num_tn()).sum()
    }

    // total はすべての事例の件数を求める
    pub fn total (&self) -> i32 {
        // 各クラスの total を求め合計を求める
        self.counters.iter().map(|m|m.total()).sum()
    }

マクロ平均の Precision / Recall そして F 値の算出です。
F 値はマクロ平均の Precision と Recall から求めます。

    pub fn macro_precision(&self) -> Option<f64> {
        if self.num_class == 0 {
            return None
        }

        let mut sum:f64 = 0.0;
        for c in self.counters.iter() {
            if let Some(p) = c.precision() {
                sum += p;
            } else {
                return None
            }
        }
        Some (
            f64::from(sum) / f64::from(self.num_class as i32)
        )
    }

    pub fn macro_recall(&self) -> Option<f64> {
        if self.num_class == 0 {
            return None
        }

        let mut sum:f64 = 0.0;
        for c in self.counters.iter() {
            if let Some(p) = c.recall() {
                sum += p;
            } else {
                return None
            }
        }
        Some (
            f64::from(sum) / f64::from(self.num_class as i32)
        )
    }

    pub fn macro_f1 (&self) -> Option<f64> {
        if let Some(precision) = self.macro_precision() {
            if let Some(recall) = self.macro_recall() {
                if precision + recall == 0.0 {
                    return None;
                }
                // 適合率と再現率の調和平均
                return Some(
                    2.0 * precision * recall / (precision + recall)
                )
            }
        }

        None
    }

マイクロ平均の Precision / Recall そして F 値の算出です。
F 値はマイクロ平均の Precision と Recall から求めます。

    pub fn micro_precision(&self) -> Option<f64> {
        let total_tp = self.total_tp();
        let total_tp_fp = total_tp + self.total_fp();
        if total_tp_fp < 1 {
            return None
        }

        Some (
            f64::from(total_tp) / f64::from(total_tp_fp)
        )
    }

    pub fn micro_recall(&self) -> Option<f64> {
        let total_tp = self.total_tp();
        let total_tp_fn = total_tp + self.total_fn();
        if total_tp_fn < 1 {
            return None
        }

        Some (
            f64::from(total_tp) / f64::from(total_tp_fn)
        )
    }

    pub fn micro_f1(&self) -> Option<f64> {
        if let Some(precision) = self.micro_precision() {
            if let Some(recall) = self.micro_recall() {
                if precision + recall == 0.0 {
                    return None;
                }
                // 適合率と再現率の調和平均
                return Some(
                    2.0 * precision * recall / (precision + recall)
                )
            }
        }

        None
    }

    pub fn print(&self) {
        print_option_value("Accuracy       :", self.accuracy());
        print_option_value("Macro Precision:", self.macro_precision());
        print_option_value("Macro Recall   :", self.macro_recall());
        print_option_value("Macro F-measure:", self.macro_f1());
        print_option_value("Micro Precision:", self.micro_precision());
        print_option_value("Micro Recall   :", self.micro_recall());
        print_option_value("Micro F-measure:", self.micro_f1());

        self.counters.iter().enumerate().for_each(|(class_id, m)|{
            println!("------------------");
            println!("class_id = {}", class_id);
            m.print()
        })
        
    }
}

実行例

先ほどの観測結果を使った実行例です。

const DOG:usize = 0;
const CAT:usize = 1;
const PIG:usize = 2;

fn process2() {
    // 犬・猫・豚の3クラスの多値分類のカウンター
    let mut m = MultiClassificationCounter::new(3);

    m.add(DOG, DOG);
    m.add(CAT, CAT);
    m.add(PIG, CAT);
    m.add(DOG, DOG);
    m.add(DOG, CAT);
    m.add(DOG, DOG);
    m.add(CAT, DOG);
    m.add(CAT, CAT);
    m.add(PIG, PIG);
    m.add(DOG, DOG);

    // 評価指標の表示
    m.print();
}

// 実行結果
Accuracy       : 0.80
Macro Precision: 0.66
Macro Recall   : 0.77
Macro F-measure: 0.71
Micro Precision: 0.70
Micro Recall   : 0.70
Micro F-measure: 0.70
------------------
class_id = 0
Accuracy : 0.80
Precision: 0.80
Recall   : 0.80
F-measure: 0.80
------------------
class_id = 1
Accuracy : 0.70
Precision: 0.67
Recall   : 0.50
F-measure: 0.57
------------------
class_id = 2
Accuracy : 0.90
Precision: 0.50
Recall   : 1.00
F-measure: 0.67

実行結果の class_id = 0 は犬、class_id = 1 は猫、class_id = 2 は豚に対応します。

以上、二値分類と多値分類における評価指標を計算する Rust 実装について考えてみました。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1