6
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

ZOZOAdvent Calendar 2024

Day 18

BigQuery UDF nDCG@k(rel={0,1}) Recall@k Precision@k

Last updated at Posted at 2024-12-17

概要

推薦でよく使用される指標を BigQuery の UDF にしました。

nDCG@k(rel={0,1})

関連度が段階的な値（ランキング）で与えられる場合ではなく、二値（クリックした/していない）で与えられる場合の nDCG@k を計算しています。

$rel_{i} \in \{0, 1\}$
-- nDCG@k for binary relevance (rel_i in {0, 1}).
--
-- Args:
--   rec_items: recommended items in ranked order (element 0 is rank 1).
--   rel_items: relevant (ground-truth) items; their order is ignored.
--   k:         cutoff; only the first k recommended items are scored.
--              A NULL or non-positive k leaves nothing to score, so the result is 0.
-- Returns:
--   DCG@k / IDCG@k, guaranteed to lie in [0, 1]. Returns 0 when there is
--   nothing to divide by (empty or NULL inputs, or no hits in the top k).
-- Raises (via ERROR()):
--   if rel_items contains duplicates, if the first k elements of rec_items
--   contain duplicates, or if the score falls outside [0, 1].
--
-- NOTE(review): IDCG here is the DCG of the top-k recommended list re-sorted
-- so that hits come first, i.e. sum_{i=1..#hits} 1/log2(i+1). The common
-- textbook definition instead uses sum_{i=1..min(k, |rel_items|)} 1/log2(i+1).
-- Under this variant, a list whose only hit is at rank 1 scores 1.0 even when
-- other relevant items are missing. Confirm this is intended.
create temp function ndcg_at_k_binary_rel(rec_items array<int64>, rel_items array<int64>, k int64) returns float64 as (
  (
    with
      -- Ground-truth items, one row per array element (a NULL array yields no rows).
      relevant as (
        select r
        from unnest(rel_items) as r
      )

      -- The first k recommendations with their 1-based rank and binary gain
      -- (1 when the item is relevant, 0 otherwise).
      , ranked as (
        select
          rec
          , offs + 1 as pos
          , cast(rec in (select r from relevant) as int64) as gain
        from unnest(rec_items) as rec with offset as offs
        where offs < k
      )

      -- Number of distinct relevant items that made it into the top k.
      , hit_count as (
        select count(*) as n_hits
        from (
          select rec from ranked
          intersect distinct
          select r from relevant
        ) as h
      )

      -- DCG sums each rank's gain times its discount. IDCG is the DCG of the
      -- same top-k list with every hit moved to the front. That equals the
      -- discount summed over ranks 1..n_hits, because the top-k list always has
      -- at least n_hits rows.
      , scores as (
        select
          sum(gain / log(pos + 1, 2)) as dcg
          , sum(if(pos <= n_hits, 1 / log(pos + 1, 2), 0.0)) as idcg
        from ranked
        cross join hit_count
      )

      -- SAFE_DIVIDE yields NULL when idcg is 0 or NULL (no hits or no rows);
      -- the score is reported as 0 in that case.
      , scored as (
        select coalesce(safe_divide(dcg, idcg), 0) as ndcg
        from scores
      )

    select ndcg
    from scored
    where
      if((select count(r) = count(distinct r) from relevant), true, error('Error: rel_items must be unique'))
      and if((select count(rec) = count(distinct rec) from ranked), true, error('Error: rec_items must be unique'))
      and if(ndcg between 0 and 1, true, error(format('Error: ndcg must be between 0 and 1, but got %t', ndcg)))
  )
);
 
select
  -- rec_items identical to rel_items -> 1
  ndcg_at_k_binary_rel(
    [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], 5) as ndcg_1
  -- no overlap between rec_items and rel_items -> 0
  , ndcg_at_k_binary_rel(
    [1, 2, 3], [4, 5, 6], 3) as ndcg_0
  -- partial overlap -> the hand-computed value below
  -- dcg = 1/log2(2) + 0/log2(3) + 1/log2(4) = 1.5, idcg = 1/log2(2) + 1/log2(3) + 0/log2(4) = 1.63092975357, ndcg = dcg/idcg = 1.5/1.63092975357 = 0.91972078914
  , ndcg_at_k_binary_rel(
    [1, 4, 2], [1, 2, 3], 3) as ndcg_partial
  -- reordering rel_items must not change the result (still 0.91972078914)
  , ndcg_at_k_binary_rel(
    [1, 4, 2], [1, 3, 2], 3) as ndcg_partial_replace
  -- varying k
  -- k = 1: top-1 is [1], a single hit at rank 1 -> dcg = idcg = 1 -> 1
  , ndcg_at_k_binary_rel(
    [1, 4, 2], [1, 2, 3], 1) as ndcg_k_1
  -- k = 2: top-2 is [1, 4], a single hit at rank 1 -> 1
  , ndcg_at_k_binary_rel(
    [1, 4, 2], [1, 2, 3], 2) as ndcg_k_2
  -- k larger than the list: same as k = 3 -> 0.91972078914
  , ndcg_at_k_binary_rel(
    [1, 4, 2], [1, 2, 3], 5) as ndcg_k_5
  -- no relevant items (rel_items empty) -> 0
  , ndcg_at_k_binary_rel(
    [1, 2, 3], [], 3) as ndcg_empty_rel
  -- no recommended items (rec_items empty) -> 0
  , ndcg_at_k_binary_rel(
    [], [1, 2, 3], 3) as ndcg_empty_rec
  -- rel_items is NULL -> 0
  , ndcg_at_k_binary_rel(
    [1, 2, 3], null, 3) as ndcg_null_rel
  -- rec_items is NULL -> 0
  , ndcg_at_k_binary_rel(
    null, [1, 2, 3], 3) as ndcg_null_rec
  -- large inputs still finish in reasonable time (identical lists -> ~1)
  , ndcg_at_k_binary_rel(
    generate_array(1, 100000), generate_array(1, 100000), 100000) as ndcg_large
  -- Duplicate recommended items must raise an error. Kept commented out because
  -- the error aborts the whole query; uncomment only when testing that path.
  -- , ndcg_at_k_binary_rel(
  --   [1, 1, 2], [1, 2, 3], 3) as ndcg_rec_items_duplicates
  -- Duplicate relevant items must raise an error (same caveat as above).
  -- , ndcg_at_k_binary_rel(
  --   [1, 2, 3], [1, 1, 3], 3) as ndcg_rel_items_duplicates

Recall@k

-- Recall@k: the fraction of relevant items that appear among the first k
-- recommendations.
--
-- Args:
--   rec_items: recommended items in ranked order (element 0 is rank 1).
--   rel_items: relevant (ground-truth) items; their order is ignored.
--   k:         cutoff; only the first k recommended items are considered.
-- Returns:
--   |distinct hits in top k| / |rel_items| in [0, 1]; 0 when rel_items is
--   empty or NULL.
-- Raises (via ERROR()):
--   if rel_items contains duplicates, if the first k elements of rec_items
--   contain duplicates, or if the score falls outside [0, 1].
create temp function recall_at_k(rec_items array<int64>, rel_items array<int64>, k int64) returns float64 as (
  (
    with
      -- Ground-truth items, one row per array element.
      relevant as (
        select r
        from unnest(rel_items) as r
      )

      -- The first k recommendations (a NULL or non-positive k keeps none).
      , top_k as (
        select rec
        from unnest(rec_items) as rec with offset as offs
        where offs < k
      )

      -- Distinct relevant items that were recommended within the cutoff.
      , hits as (
        select r from relevant
        intersect distinct
        select rec from top_k
      )

      -- SAFE_DIVIDE returns NULL for an empty relevant set; that is reported as 0.
      , scored as (
        select
          coalesce(
            safe_divide(
              cast((select count(*) from hits) as float64)
              , cast((select count(r) from relevant) as float64)
            )
            , 0
          ) as recall
      )

    select recall
    from scored
    where
      if((select count(r) = count(distinct r) from relevant), true, error('Error: rel_items must be unique'))
      and if((select count(rec) = count(distinct rec) from top_k), true, error('Error: rec_items must be unique'))
      and if(recall between 0 and 1, true, error(format('Error: recall must be between 0 and 1, but got %t', recall)))
  )
);
 
select
  -- rec_items identical to rel_items -> 1
  recall_at_k(
    [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], 5
  ) as recall_1
  -- no overlap between rec_items and rel_items -> 0
  , recall_at_k(
    [1, 2, 3], [4, 5, 6], 3
  ) as recall_0
  -- partial overlap -> recall = 2/3 = 0.66666666667
  , recall_at_k(
    [1, 4, 2], [1, 2, 3], 3
  ) as recall_partial
  -- reordering rel_items must not change the result (still 0.66666666667)
  , recall_at_k(
    [1, 4, 2], [1, 3, 2], 3
  ) as recall_partial_replace
  -- varying k
  -- k = 1: top-1 is [1] -> 1/3
  , recall_at_k(
    [1, 4, 2], [1, 2, 3], 1
  ) as recall_k_1
  -- k = 2: top-2 is [1, 4] -> 1/3
  , recall_at_k(
    [1, 4, 2], [1, 2, 3], 2
  ) as recall_k_2
  -- k larger than the list: same as k = 3 -> 2/3
  , recall_at_k(
    [1, 4, 2], [1, 2, 3], 5
  ) as recall_k_5
  -- no relevant items (rel_items empty) -> 0
  , recall_at_k(
    [1, 2, 3], [], 3
  ) as recall_empty_rel
  -- no recommended items (rec_items empty) -> 0
  , recall_at_k(
    [], [1, 2, 3], 3
  ) as recall_empty_rec
  -- rel_items is NULL -> 0
  , recall_at_k(
    [1, 2, 3], null, 3
  ) as recall_null_rel
  -- rec_items is NULL -> 0
  , recall_at_k(
    null, [1, 2, 3], 3
  ) as recall_null_rec
  -- large inputs still finish in reasonable time (identical lists -> 1)
  , recall_at_k(
    generate_array(1, 100000), generate_array(1, 100000), 100000
  ) as recall_large
  -- Duplicate recommended items must raise an error. Kept commented out because
  -- the error aborts the whole query; uncomment only when testing that path.
  -- , recall_at_k(
  --   [1, 1, 2], [1, 2, 3], 3) as recall_rec_items_duplicates
  -- Duplicate relevant items must raise an error (same caveat as above).
  -- , recall_at_k(
  --   [1, 2, 3], [1, 1, 3], 3) as recall_rel_items_duplicates

Precision@k

-- Precision@k: the fraction of the first k recommendations that are relevant.
--
-- Args:
--   rec_items: recommended items in ranked order (element 0 is rank 1).
--   rel_items: relevant (ground-truth) items; their order is ignored.
--   k:         cutoff; only the first k recommended items are considered.
-- Returns:
--   |distinct hits in top k| / |top-k recommendations| in [0, 1]; 0 when no
--   recommendations fall within the cutoff (empty/NULL rec_items, or k <= 0).
-- Raises (via ERROR()):
--   if rel_items contains duplicates, if the first k elements of rec_items
--   contain duplicates, or if the score falls outside [0, 1].
--
-- NOTE(review): the denominator is the number of recommendations actually
-- present in the top k (min(k, len(rec_items))), not k itself. The textbook
-- Precision@k divides by k, so short lists score higher here (e.g. 2 hits out
-- of 3 recs at k = 5 gives 2/3, not 2/5). Confirm this is intended.
create temp function precision_at_k(rec_items array<int64>, rel_items array<int64>, k int64) returns float64 as (
  (
    with
      -- Ground-truth items, one row per array element.
      relevant as (
        select r
        from unnest(rel_items) as r
      )

      -- The first k recommendations (a NULL or non-positive k keeps none).
      , top_k as (
        select rec
        from unnest(rec_items) as rec with offset as offs
        where offs < k
      )

      -- Distinct relevant items that were recommended within the cutoff.
      , hits as (
        select r from relevant
        intersect distinct
        select rec from top_k
      )

      -- SAFE_DIVIDE returns NULL when the top-k list is empty; that is reported as 0.
      , scored as (
        select
          coalesce(
            safe_divide(
              cast((select count(*) from hits) as float64)
              , cast((select count(rec) from top_k) as float64)
            )
            , 0
          ) as precision
      )

    select precision
    from scored
    where
      if((select count(r) = count(distinct r) from relevant), true, error('Error: rel_items must be unique'))
      and if((select count(rec) = count(distinct rec) from top_k), true, error('Error: rec_items must be unique'))
      and if(precision between 0 and 1, true, error(format('Error: precision must be between 0 and 1, but got %t', precision)))
  )
);
 
select
  -- rec_items identical to rel_items -> 1
  precision_at_k(
    [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], 5) as precision_1
  -- no overlap between rec_items and rel_items -> 0
  , precision_at_k(
    [1, 2, 3], [4, 5, 6], 3) as precision_0
  -- partial overlap -> precision = 2/3 = 0.66666666667
  , precision_at_k(
    [1, 4, 2], [1, 2, 3], 3) as precision_partial
  -- reordering rel_items must not change the result (still 0.66666666667)
  , precision_at_k(
    [1, 4, 2], [1, 3, 2], 3) as precision_partial_replace
  -- varying k
  -- k = 1: top-1 is [1] -> 1/1 = 1
  , precision_at_k(
    [1, 4, 2], [1, 2, 3], 1) as precision_k_1
  -- k = 2: top-2 is [1, 4] -> 1/2 = 0.5
  , precision_at_k(
    [1, 4, 2], [1, 2, 3], 2) as precision_k_2
  -- k larger than the list: the denominator is the 3 items actually present,
  -- not k, so this is 2/3
  , precision_at_k(
    [1, 4, 2], [1, 2, 3], 5) as precision_k_5
  -- no relevant items (rel_items empty) -> 0
  , precision_at_k(
    [1, 2, 3], [], 3) as precision_empty_rel
  -- no recommended items (rec_items empty) -> 0
  , precision_at_k(
    [], [1, 2, 3], 3) as precision_empty_rec
  -- rel_items is NULL -> 0
  , precision_at_k(
    [1, 2, 3], null, 3) as precision_null_rel
  -- rec_items is NULL -> 0
  , precision_at_k(
    null, [1, 2, 3], 3) as precision_null_rec
  -- large inputs still finish in reasonable time (identical lists -> 1)
  , precision_at_k(
    generate_array(1, 100000), generate_array(1, 100000), 100000) as precision_large
  -- Duplicate recommended items must raise an error. Kept commented out because
  -- the error aborts the whole query; uncomment only when testing that path.
  -- , precision_at_k(
  --   [1, 1, 2], [1, 2, 3], 3) as precision_rec_items_duplicates
  -- Duplicate relevant items must raise an error (same caveat as above).
  -- , precision_at_k(
  --   [1, 2, 3], [1, 1, 3], 3) as precision_rel_items_duplicates
6
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
6
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?