LoginSignup
1
0

More than 3 years have passed since last update.

PostgreSQLでjaccard係数を求める

Last updated at Posted at 2019-06-16

経緯

SQLでやろうというのはどうなんだろうと思いながら、
3億ループぐらいの計算させようとPython書いてたら思いの外
パフォーマンスがでなかったので、試しにDBだけで完結させてみたくて
やってみた結果、こっちのほうが速かったのと、個人的に使い所はあるのでメモ。

ストアド

DROP FUNCTION IF EXISTS jaccard(TEXT[], TEXT[]);

-- jaccard(source, dest): Jaccard similarity coefficient of two TEXT arrays.
--
-- Both arrays are treated as sets: duplicate elements are collapsed and NULL
-- elements are not counted (COUNT(DISTINCT ...) skips NULLs). The two arrays
-- do not need to have the same number of elements.
--
-- Parameters:
--   source  first collection of elements
--   dest    second collection of elements
-- Returns:
--   (distinct elements present in both) / (distinct elements present in either)
--   as NUMERIC; 0.0 when there are no elements at all (e.g. both arrays empty).
--
-- The result depends only on the arguments and the function touches no tables,
-- so it is declared IMMUTABLE and PARALLEL SAFE; the default (VOLATILE,
-- PARALLEL UNSAFE) would needlessly restrict the planner.
CREATE OR REPLACE FUNCTION jaccard(source TEXT[], dest TEXT[])
RETURNS NUMERIC AS $$
  DECLARE
    item_count  INTEGER; -- number of distinct elements across both arrays (union size)
    match_count INTEGER; -- number of distinct elements of source also found in dest (intersection size)
  BEGIN
    -- Size of the union: distinct elements of the concatenated arrays
    SELECT COUNT(DISTINCT t) INTO item_count FROM UNNEST(source || dest) AS t;

    -- Nothing to compare: avoid dividing by zero and skip the second query
    IF item_count = 0 THEN
      RETURN 0.0::NUMERIC;
    END IF;

    -- Size of the intersection
    SELECT COUNT(DISTINCT t) INTO match_count
    FROM UNNEST(source) AS t
    WHERE t = ANY(dest);

    -- Cast one operand so the division is NUMERIC rather than integer division
    RETURN match_count::NUMERIC / item_count;
  END;
$$ LANGUAGE plpgsql IMMUTABLE PARALLEL SAFE;
-- The signature is spelled out so the statement stays unambiguous if an overload is added later.
COMMENT ON FUNCTION jaccard(TEXT[], TEXT[]) IS '2つのTEXT[]を受け取り、jaccard係数を算出する';

デモデータの作成

下記のテーブル(answers)は好きなプログラム言語の回答データが格納されているものとしています。(こんな感じ)

id user_name answer
1 user_1 Ruby
2 user_1 Python
3 user_2 C#
4 user_3 Ruby
-- Demo table: each row is a single answer (a programming language) given by a user.
CREATE TABLE answers (
  id        SERIAL,         -- auto-numbered surrogate key
  user_name TEXT NOT NULL,  -- name of the respondent
  answer    TEXT NOT NULL,  -- the language that respondent picked
  PRIMARY KEY (id)          -- a primary key already implies NOT NULL
);

-- Catalog descriptions for the table and its columns
COMMENT ON COLUMN answers.answer IS 'C,C++,C#,Python,go,ruby,php,R,Rustのいずれかが入る。';
COMMENT ON COLUMN answers.user_name IS '回答者名が入ります';
COMMENT ON TABLE answers IS '好きなプログラム言語を任意で答えてもらった結果が格納されているという体のテーブルです。';

-- Demo data: 47 answers from user_1 .. user_20.
-- Loaded with one multi-row INSERT so the fixture is applied atomically
-- (all rows or none) and the statement is parsed only once.
INSERT INTO answers (user_name, answer) VALUES
  ('user_1', 'C'),
  ('user_1', 'C++'),
  ('user_1', 'C#'),
  ('user_2', 'Python'),
  ('user_2', 'ruby'),
  ('user_2', 'C#'),
  ('user_2', 'R'),
  ('user_2', 'C'),
  ('user_3', 'Rust'),
  ('user_3', 'go'),
  ('user_4', 'php'),
  ('user_5', 'C'),
  ('user_5', 'php'),
  ('user_5', 'Python'),
  ('user_6', 'go'),
  ('user_6', 'php'),
  ('user_6', 'ruby'),
  ('user_6', 'Python'),
  ('user_7', 'R'),
  ('user_7', 'Python'),
  ('user_8', 'ruby'),
  ('user_9', 'ruby'),
  ('user_10', 'go'),
  ('user_11', 'C#'),
  ('user_11', 'R'),
  ('user_12', 'Python'),
  ('user_12', 'go'),
  ('user_12', 'C#'),
  ('user_12', 'R'),
  ('user_13', 'php'),
  ('user_13', 'ruby'),
  ('user_14', 'go'),
  ('user_14', 'Python'),
  ('user_14', 'Rust'),
  ('user_15', 'C++'),
  ('user_15', 'C'),
  ('user_16', 'C#'),
  ('user_17', 'Rust'),
  ('user_17', 'ruby'),
  ('user_17', 'Python'),
  ('user_17', 'C#'),
  ('user_18', 'Python'),
  ('user_18', 'Rust'),
  ('user_19', 'go'),
  ('user_19', 'php'),
  ('user_19', 'R'),
  ('user_20', 'C#');

動作検証用のSQL

-- Jaccard score of every other user against the fixed user 'user_1'.
-- Each user's answers are collected into an array once in the CTE, rather than
-- re-running a correlated ARRAY subquery for every output row.
WITH answer_sets AS (
  -- One row per user with all of that user's answers as a TEXT[]
  SELECT
    user_name,
    ARRAY_AGG(answer) AS answer_list
  FROM answers
  GROUP BY
    user_name
)
SELECT
  'user_1' AS source_user,
  dst.user_name AS dest_user,
  JACCARD(
    ARRAY(
      -- The comparison source is fixed; this subquery is uncorrelated and
      -- still yields an empty array when user_1 has no answers
      SELECT answer
      FROM answers
      WHERE
        user_name = 'user_1'
    ),
    -- The comparison target is the user_name of the current row
    dst.answer_list
  ) AS score
FROM answer_sets AS dst
WHERE
  -- user_1 is the source, so compare against everyone except user_1
  dst.user_name <> 'user_1'
ORDER BY
  dst.user_name
;

おまけ(ユーザ毎に類似度の高い上位3位までのユーザを抽出)

-- For every user, list the other users ranked in the top 3 by Jaccard
-- similarity (ties share a rank, so more than 3 rows per user are possible).
WITH answer_sets AS (
  -- One row per user with all of that user's answers as a TEXT[]
  SELECT
    user_name,
    ARRAY_AGG(answer) AS answer_list
  FROM answers
  GROUP BY
    user_name
),
scored AS (
  -- Jaccard coefficient for every ordered pair of distinct users.
  -- Joining the per-user sets avoids cross-joining the raw answer rows and
  -- then de-duplicating the result with GROUP BY.
  SELECT
    src.user_name AS source_user,
    dst.user_name AS dest_user,
    JACCARD(src.answer_list, dst.answer_list) AS score
  FROM answer_sets AS src
  INNER JOIN answer_sets AS dst
    -- exclude pairing a user with themselves
    ON src.user_name <> dst.user_name
),
ranked AS (
  -- Rank the candidates of each source user by descending similarity
  SELECT
    source_user,
    dest_user,
    score,
    RANK() OVER(PARTITION BY source_user ORDER BY score DESC) AS rank
  FROM scored
)
SELECT
  source_user,
  rank,
  dest_user,
  score
FROM ranked
WHERE
  rank <= 3
ORDER BY
  source_user,
  rank,
  -- tiebreaker so rows sharing a rank come back in a stable order
  dest_user
;

やってやれないこともない。
ピアソン積率相関係数も関数化できたので
次はそちらを。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0