概要
AutoMLでバッチ処理をするとBigQueryのデータセットに予測値を吐き出すことができる。
二値の分類モデルを作った時にサクッと色々なテストデータで指標を確認したかったのでUDF化しました。
データの準備
ちなみに、AutoMLバッチ予測した結果をBigQueryに出力すると以下のような感じになります。
predicted_XXXX.tables.score | predicted_XXXX.tables.value |
---|
それを実際の判定とスコアのテーブルに加工して用意します。
行 | actual | score |
---|---|---|
1 | false | 0.47 |
2 | false | 0.27 |
3 | true | 0.68 |
4 | true | 0.93 |
5 | false | 0.71 |
とりあえず完成系
CREATE TEMP FUNCTION THRESHOLDS(score FLOAT64, c INT64) AS ((
SELECT ARRAY_AGG(STRUCT(score >= i/c AS predict, i/c AS threshold))
FROM UNNEST(GENERATE_ARRAY(1, c-1)) AS i
));
CREATE TEMP FUNCTION INSPECT(actual BOOL, score FLOAT64, c INT64) AS ((
SELECT ARRAY_AGG(STRUCT(
ROUND(threshold, 2) AS threshold,
actual,
predict,
CASE
WHEN actual AND predict THEN "TP"
WHEN NOT actual AND predict THEN "FP"
WHEN actual AND NOT predict THEN "FN"
WHEN NOT actual AND NOT predict THEN "TN"
END AS class
))
FROM UNNEST(THRESHOLDS(score, c))
));
WITH inspect AS (
SELECT score, INSPECT(actual, score, 5) AS ins
FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
(false, 0.47),
(false, 0.27),
(true, 0.68),
(true, 0.93),
(false, 0.71)
])
)
SELECT
threshold,
COUNTIF(class = "TP") AS TP,
COUNTIF(class = "TN") AS TN,
COUNTIF(class = "FP") AS FP,
COUNTIF(class = "FN") AS FN,
FROM inspect, UNNEST(ins)
GROUP BY threshold
以上を実行すると
行 | threshold | TP | TN | FP | FN |
---|---|---|---|---|---|
1 | 0.2 | 2 | 0 | 3 | 0 |
2 | 0.4 | 2 | 1 | 2 | 0 |
3 | 0.6 | 2 | 2 | 1 | 0 |
4 | 0.8 | 1 | 3 | 0 | 1 |
という結果が得られる。
解説
INSPECT(actual BOOL, score FLOAT64, c INT64)
引数はデータごとの実際の判定、予測スコア、閾値分割数となっている。
上記の例はINSPECT(actual, score, 5)
この部分で5
を指定しており、0~1を5等分した結果が得られるように作ってある。
WITH inspect AS (
SELECT score, INSPECT(actual, score, 5) AS ins
FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
(false, 0.47),
(false, 0.27),
(true, 0.68),
(true, 0.93),
(false, 0.71)
])
)
SELECT *
FROM inspect
inspect
の部分だけ表示すると
行 | score | ins.threshold | ins.actual | ins.predict | ins.class |
---|---|---|---|---|---|
1 | 0.47 | 0.2 | false | true | FP |
0.4 | false | true | FP | ||
0.6 | false | false | TN | ||
0.8 | false | false | TN | ||
2 | 0.27 | 0.2 | false | true | FP |
0.4 | false | false | TN | ||
... | ... | ... | ... |
と各データごとに閾値を変えた判定結果が入っている。
のでins
というARRAY
オブジェクトを展開して、閾値でGROUP BYしカウントすると
SELECT
threshold,
COUNTIF(class = "TP") AS TP,
COUNTIF(class = "TN") AS TN,
COUNTIF(class = "FP") AS FP,
COUNTIF(class = "FN") AS FN,
FROM inspect, UNNEST(ins)
GROUP BY threshold
閾値ごとの混同行列が得られる。
ちなみに、ROUND(threshold, 2)
としているのでINSPECT(actual, score, 100)
までは動く。
THRESHOLDS(score FLOAT64, c INT64)
scoreをcの数だけ判定するためだけの関数。
CREATE TEMP FUNCTION THRESHOLDS(score FLOAT64, c INT64) AS ((
SELECT ARRAY_AGG(STRUCT(score >= i/c AS predict, i/c AS threshold))
FROM UNNEST(GENERATE_ARRAY(1, c-1)) AS i
));
SELECT THRESHOLDS(score, 5) AS ts
FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
(false, 0.47),
(false, 0.27),
(true, 0.68),
(true, 0.93),
(false, 0.71)
])
行 | ts.predict | ts.threshold |
---|---|---|
1 | true | 0.2 |
true | 0.4 | |
false | 0.6 | |
false | 0.8 | |
... | ... | ... |
おまけ
最初の例だとINSPECTした結果をわざわざ集計しなければならなくてめんどくさい。
のでGROUP BYまでUDFでやってしまうバージョンも作成した。
CREATE TEMP FUNCTION THRESHOLDS(score FLOAT64, c INT64) AS ((
SELECT ARRAY_AGG(STRUCT(score >= i/c AS predict, i/c AS threshold))
FROM UNNEST(GENERATE_ARRAY(1, c-1)) AS i
));
CREATE TEMP FUNCTION INSPECT(actual BOOL, score FLOAT64, c INT64) AS ((
SELECT ARRAY_AGG(STRUCT(
ROUND(threshold, 2) AS threshold,
actual,
predict,
CASE
WHEN actual AND predict THEN "TP"
WHEN NOT actual AND predict THEN "FP"
WHEN actual AND NOT predict THEN "FN"
WHEN NOT actual AND NOT predict THEN "TN"
END AS class
))
FROM UNNEST(THRESHOLDS(score, c))
));
CREATE TEMP FUNCTION INSPECTS(datas ARRAY<STRUCT<actual BOOL, score FLOAT64>>, c INT64) AS ((
WITH inspect AS (
SELECT INSPECT(actual, score, c) AS ins
FROM UNNEST(datas)
),
confusion_matrix AS (
SELECT
threshold,
COUNTIF(class = "TP") AS TP,
COUNTIF(class = "TN") AS TN,
COUNTIF(class = "FP") AS FP,
COUNTIF(class = "FN") AS FN
FROM inspect, UNNEST(ins)
GROUP BY threshold
)
SELECT ARRAY_AGG(STRUCT(threshold, TP, TN, FP, FN)) FROM confusion_matrix
));
WITH data AS (
SELECT *
FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
(false, 0.47),
(false, 0.27),
(true, 0.68),
(true, 0.93),
(false, 0.71)
])
)
SELECT *
FROM UNNEST(INSPECTS((SELECT ARRAY_AGG(STRUCT(actual, score)) AS datas FROM data), 5))
INSPECTS((SELECT ARRAY_AGG(STRUCT(actual, score)) AS datas FROM data), 5)
の第一引数にactual, scoreの構造体の配列自体を渡してしまい、関数内でGROUP BYしている。
注意点として、こちらはテーブル全てを構造体の配列にしてしまうことに等しいのでデータ数が多すぎると機能しないかもしれない。
おまけ2
混同行列だけみてても仕方ないのでprecisionなどを計算する関数も作った
CREATE TEMP FUNCTION INDICATOR(TP INT64, TN INT64, FP INT64, FN INT64) AS (STRUCT(
SAFE_DIVIDE(TP+FP, TN+FP+FN+TP) AS positive,
SAFE_DIVIDE(TN+TP, TN+FP+FN+TP) AS accuracy,
SAFE_DIVIDE(TP, TP+FP) AS precision,
SAFE_DIVIDE(TP, FN+TP) AS recall,
SAFE_DIVIDE(FP, TN+FP) AS fallout,
SAFE_DIVIDE(2 * SAFE_DIVIDE(TP, TP+FP) * SAFE_DIVIDE(TP, FN+TP), SAFE_DIVIDE(TP, TP+FP) + SAFE_DIVIDE(TP, FN+TP)) AS f1
));
-- 関数定義省略
WITH inspect AS (
SELECT score, INSPECT(actual, score, 5) AS ins
FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
(false, 0.47),
(false, 0.27),
(true, 0.68),
(true, 0.93),
(false, 0.71)
])
)
, confusion_matrix AS (
SELECT
threshold,
COUNTIF(class = "TP") AS TP,
COUNTIF(class = "TN") AS TN,
COUNTIF(class = "FP") AS FP,
COUNTIF(class = "FN") AS FN,
FROM inspect, UNNEST(ins)
GROUP BY threshold
)
SELECT INDICATOR(TP, TN, FP, FN).*
FROM confusion_matrix
実行結果
行 | positive | accuracy | precision | recall | fallout | f1 |
---|---|---|---|---|---|---|
1 | 1.0 | 0.4 | 0.4 | 1.0 | 1.0 | 0.5714285714285715 |
2 | 0.8 | 0.6 | 0.5 | 1.0 | 0.6666666666666666 | 0.6666666666666666 |
3 | 0.6 | 0.8 | 0.6666666666666666 | 1.0 | 0.3333333333333333 | 0.8 |
4 | 0.2 | 0.8 | 1.0 | 0.5 | 0.0 | 0.6666666666666666 |
おまけ3(AUC)
ここまできたらROCやPRのAUCを求めてしまえる。ここまでSQLでやるのかは疑問だが。
CREATE TEMP FUNCTION AUC(arr ARRAY<STRUCT<v1 FLOAT64, v2 FLOAT64>>) AS ((
-- AUCを短冊積分で計算する
SELECT SUM(v)
FROM (
SELECT v2 * (IFNULL(LEAD(v1) OVER (ORDER BY i), 1) - v1) AS v
FROM UNNEST((
SELECT ARRAY_AGG(STRUCT(v1, v2) ORDER BY v2)
FROM UNNEST(arr)
)) WITH OFFSET i
)
));
v1
に横軸、v2
に縦軸を入れるとそのAUCを計算する。ちょっとあってるか自信はない。
WITH inspect AS (
SELECT score, INSPECT(actual, score, 5) AS ins
FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
(false, 0.47),
(false, 0.27),
(true, 0.68),
(true, 0.93),
(false, 0.71)
])
)
, confusion_matrix AS (
SELECT
threshold,
COUNTIF(class = "TP") AS TP,
COUNTIF(class = "TN") AS TN,
COUNTIF(class = "FP") AS FP,
COUNTIF(class = "FN") AS FN,
FROM inspect, UNNEST(ins)
GROUP BY threshold
)
, indicat AS (
SELECT INDICATOR(TP, TN, FP, FN).*
FROM confusion_matrix
)
SELECT
AUC(ARRAY_AGG(STRUCT(fallout, recall))) AS roc_auc,
AUC(ARRAY_AGG(STRUCT(recall, precision))) AS pr_auc,
FROM indicat
行 | roc_auc | pr_auc |
---|---|---|
1 | 0.5 | 0.16666666666666669 |
データが多いならINSPECT
のc
を増やすと短冊が増えるのでより正確になると思われる。