概要
推薦でよく使用される指標を BigQuery の UDF にしました。
nDCG@k(rel={0,1})
関連商品がランキングで与えられる場合ではなく二値(クリックした/していない)の時の nDCG@k を計算しています。
{\displaystyle rel_{i}\in \{0,1\}}
create temp function ndcg_at_k_binary_rel(rec_items array<int64>, rel_items array<int64>, k int64) returns float64 as (
(
with
rel_items as (
select item from unnest(rel_items) as item
)
, rec_items as (
select
item
, position + 1 as idx
, cast(item in (select item from rel_items) as int) as rel
from unnest(rec_items) as item with offset as position
where position + 1 <= k
)
, true_positive as (
select item from rec_items
intersect distinct
select item from rel_items
)
, false_positive as (
select item from rec_items
except distinct
select item from rel_items
)
, ideal_rec_items as (
with
t as (
select
item
, 1 as rel
from true_positive
union all
select
item
, 0 as rel
from false_positive
)
select
item
, rank() over (order by rel desc, item asc) as idx
, rel
from t
)
, _dcg as (
select sum(rel / log(idx + 1, 2)) as dcg
from rec_items
)
, _idcg as (
select sum(rel / log(idx + 1, 2)) as idcg
from ideal_rec_items
)
, result as (
select coalesce(safe_divide((select dcg from _dcg), (select idcg from _idcg)), 0) as ndcg
)
select ndcg
from result
where
if((select count(item) = count(distinct item) from rel_items), true, error('Error: rel_items must be unique'))
and if((select count(item) = count(distinct item) from rec_items), true, error('Error: rec_items must be unique'))
and if(ndcg between 0 and 1, true, error(format('Error: ndcg must be between 0 and 1, but got %t', ndcg)))
)
);
select
-- rec_items と rel_items が完全一致する時に 1 を返すことをテスト
ndcg_at_k_binary_rel(
[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], 5) as ndcg_1
-- rec_items と rel_items が一つも一致しない時に 0 を返すことをテスト
, ndcg_at_k_binary_rel(
[1, 2, 3], [4, 5, 6], 3) as ndcg_0
-- rec_items と rel_items が部分的に一致する時に意図した ndcg を返すことをテスト
-- dcg = 1/log2(2) + 0/log2(3) + 1/log2(4) = 1.5, idcg = 1/log2(2) + 1/log2(3) + 0/log2(4) = 1.63092975357, ndcg = dcg/idcg = 1.5/1.63092975357 = 0.91972078914
, ndcg_at_k_binary_rel(
[1, 4, 2], [1, 2, 3], 3) as ndcg_partial
-- 上記の rel_items の順番を入れ替えても ndcg が 0.91972078914 から変化しないことをテスト
, ndcg_at_k_binary_rel(
[1, 4, 2], [1, 3, 2], 3) as ndcg_partial_replace
-- k を変化させても出力が正常であることをテスト
, ndcg_at_k_binary_rel(
[1, 4, 2], [1, 2, 3], 1) as ndcg_k_1
, ndcg_at_k_binary_rel(
[1, 4, 2], [1, 2, 3], 2) as ndcg_k_2
, ndcg_at_k_binary_rel(
[1, 4, 2], [1, 2, 3], 5) as ndcg_k_5
-- 関連アイテムが無い時に 0 を返すことをテスト
, ndcg_at_k_binary_rel(
[1, 2, 3], [], 3) as ndcg_empty_rec
-- 推薦アイテムが無い時に 0 を返すことをテスト
, ndcg_at_k_binary_rel(
[], [1, 2, 3], 3) as ndcg_empty_rel
-- 関連アイテムが null の時に 0 を返すことをテスト
, ndcg_at_k_binary_rel(
[1, 2, 3], null, 3) as ndcg_null_rec
-- 推薦アイテムが null の時に 0 を返すことをテスト
, ndcg_at_k_binary_rel(
null, [1, 2, 3], 3) as ndcg_null_rel
-- 大きなリストを入れた時に計算が現実的であることをテスト
, ndcg_at_k_binary_rel(
generate_array(1, 100000), generate_array(1, 100000), 100000) as ndcg_large
-- -- 重複する推薦アイテムがある時に、エラーを返すことをテスト(エラーをテストするので、テスト時以外はコメントアウトする)
-- , ndcg_at_k_binary_rel(
-- [1, 1, 2], [1, 2, 3], 3) as ndcg_rec_items_duplicates
-- -- 重複する関連アイテムがある時に、エラーを返すことをテスト(エラーをテストするので、テスト時以外はコメントアウトする)
-- , ndcg_at_k_binary_rel(
-- [1, 2, 3], [1, 1, 3], 3) as ndcg_rel_items_duplicates
Recall@k
create temp function recall_at_k(rec_items array<int64>, rel_items array<int64>, k int64) returns float64 as (
(
with
rel_items as (
select item from unnest(rel_items) as item
)
, rec_items as (
select item from unnest(rec_items) as item with offset as position
where position + 1 <= k
)
, true_positive as (
select item from rel_items
intersect distinct
select item from rec_items
)
, _recall as (
select coalesce(safe_divide(cast(count(*) as float64), cast((select count(item) from rel_items) as float64)), 0) as recall
from true_positive
)
select recall
from _recall
where
if((select count(item) = count(distinct item) from rel_items), true, error('Error: rel_items must be unique'))
and if((select count(item) = count(distinct item) from rec_items), true, error('Error: rec_items must be unique'))
and if(recall between 0 and 1, true, error(format('Error: recall must be between 0 and 1, but got %t', recall)))
)
);
select
-- rec_items と rel_items が完全一致する時に 1 を返すことをテスト
recall_at_k(
[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], 5
) as recall_1
-- rec_items と rel_items が一つも一致しない時に 0 を返すことをテスト
, recall_at_k(
[1, 2, 3], [4, 5, 6], 3
) as recall_0
-- rec_items と rel_items が部分的に一致する時に意図した recall を返すことをテスト
-- recall = 2/3 = 0.66666666667
, recall_at_k(
[1, 4, 2], [1, 2, 3], 3
) as recall_partial
-- 上記の rel_items の順番を入れ替えても recall が 0.66666666667 から変化しないことをテスト
, recall_at_k(
[1, 4, 2], [1, 3, 2], 3
) as recall_partial_replace
-- k を変化させても出力が正常であることをテスト
, recall_at_k(
[1, 4, 2], [1, 2, 3], 1
) as recall_k_1
, recall_at_k(
[1, 4, 2], [1, 2, 3], 2
) as recall_k_2
, recall_at_k(
[1, 4, 2], [1, 2, 3], 5
) as recall_k_5
-- 関連アイテムが無い時に 0 を返すことをテスト
, recall_at_k(
[1, 2, 3], [], 3
) as recall_empty_rec
-- 推薦アイテムが無い時に 0 を返すことをテスト
, recall_at_k(
[], [1, 2, 3], 3
) as recall_empty_rel
-- 関連アイテムが null の時に 0 を返すことをテスト
, recall_at_k(
[1, 2, 3], null, 3
) as recall_null_rec
-- 推薦アイテムが null の時に 0 を返すことをテスト
, recall_at_k(
null, [1, 2, 3], 3
) as recall_null_rel
-- 大きなリストを入れた時に計算が現実的であることをテスト
, recall_at_k(
generate_array(1, 100000), generate_array(1, 100000), 100000
) as recall_large
-- -- 重複する推薦アイテムがある時に、エラーを返すことをテスト(エラーをテストするので、テスト時以外はコメントアウトする)
-- , recall_at_k(
-- [1, 1, 2], [1, 2, 3], 3) as recall_rec_items_duplicates
-- -- 重複する関連アイテムがある時に、エラーを返すことをテスト(エラーをテストするので、テスト時以外はコメントアウトする)
-- , recall_at_k(
-- [1, 2, 3], [1, 1, 3], 3) as recall_rel_items_duplicates
Precision@k
create temp function precision_at_k(rec_items array<int64>, rel_items array<int64>, k int64) returns float64 as (
(
with
rel_items as (
select item from unnest(rel_items) as item
)
, rec_items as (
select item from unnest(rec_items) as item with offset as position
where position + 1 <= k
)
, true_positive as (
select item from rel_items
intersect distinct
select item from rec_items
)
, _precision as (
select coalesce(safe_divide(cast(count(*) as float64), cast((select count(item) from rec_items) as float64)), 0) as precision
from true_positive
)
select precision
from _precision
where
if((select count(item) = count(distinct item) from rel_items), true, error('Error: rel_items must be unique'))
and if((select count(item) = count(distinct item) from rec_items), true, error('Error: rec_items must be unique'))
and if(precision between 0 and 1, true, error(format('Error: precision must be between 0 and 1, but got %t', precision)))
)
);
select
-- rec_items と rel_items が完全一致する時に 1 を返すことをテスト
precision_at_k(
[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], 5) as precision_1
-- rec_items と rel_items が一つも一致しない時に 0 を返すことをテスト
, precision_at_k(
[1, 2, 3], [4, 5, 6], 3) as precision_0
-- rec_items と rel_items が部分的に一致する時に意図した precision を返すことをテスト
-- precision = 2/3 = 0.66666666667
, precision_at_k(
[1, 4, 2], [1, 2, 3], 3) as precision_partial
-- 上記の rel_items の順番を入れ替えても precision が 0.66666666667 から変化しないことをテスト
, precision_at_k(
[1, 4, 2], [1, 3, 2], 3) as precision_partial_replace
-- k を変化させても出力が正常であることをテスト
, precision_at_k(
[1, 4, 2], [1, 2, 3], 1) as precision_k_1
, precision_at_k(
[1, 4, 2], [1, 2, 3], 2) as precision_k_2
, precision_at_k(
[1, 4, 2], [1, 2, 3], 5) as precision_k_5
-- 関連アイテムが無い時に 0 を返すことをテスト
, precision_at_k(
[1, 2, 3], [], 3) as precision_empty_rec
-- 推薦アイテムが無い時に 0 を返すことをテスト
, precision_at_k(
[], [1, 2, 3], 3) as precision_empty_rel
-- 関連アイテムが null の時に 0 を返すことをテスト
, precision_at_k(
[1, 2, 3], null, 3) as precision_null_rec
-- 推薦アイテムが null の時に 0 を返すことをテスト
, precision_at_k(
null, [1, 2, 3], 3) as precision_null_rel
-- 大きなリストを入れた時に計算が現実的であることをテスト
, precision_at_k(
generate_array(1, 100000), generate_array(1, 100000), 100000) as precision_large
-- -- 重複する推薦アイテムがある時に、エラーを返すことをテスト(エラーをテストするので、テスト時以外はコメントアウトする)
-- , precision_at_k(
-- [1, 1, 2], [1, 2, 3], 3) as precision_rec_items_duplicates
-- -- 重複する関連アイテムがある時に、エラーを返すことをテスト(エラーをテストするので、テスト時以外はコメントアウトする)
-- , precision_at_k(
-- [1, 2, 3], [1, 1, 3], 3) as precision_rel_items_duplicates