経緯
SQLでやるのはどうなんだろうと思いながらも、
3億ループほどの計算をさせようとPythonで書いていたら思いのほか
パフォーマンスが出なかったので、試しにDBだけで完結させてみた。
その結果、こちらのほうが速かったのと、個人的にも使い所はあるのでメモしておく。
ストアドファンクション
DROP FUNCTION IF EXISTS jaccard(TEXT[], TEXT[]);
-- Jaccard similarity coefficient |A ∩ B| / |A ∪ B| of two text arrays.
--   source, dest : arrays of any length; they do NOT need the same number of elements.
--                  Elements are treated as a set: duplicates are collapsed and NULL
--                  elements are ignored (COUNT(DISTINCT ...) skips NULLs).
--   returns      : NUMERIC in [0, 1]; 0 when the union is empty (both arrays empty/NULL).
-- IMMUTABLE / PARALLEL SAFE: the result depends only on the arguments, so the planner
-- may evaluate calls in parallel workers when scoring a large number of user pairs.
CREATE OR REPLACE FUNCTION jaccard(source TEXT[], dest TEXT[])
RETURNS NUMERIC
LANGUAGE plpgsql
IMMUTABLE
PARALLEL SAFE
AS $$
DECLARE
    item_count  INTEGER := 0;  -- |A ∪ B|: distinct elements across both arrays
    match_count INTEGER := 0;  -- |A ∩ B|: distinct elements present in both arrays
BEGIN
    -- Union size. Concatenating with a NULL array yields the other array unchanged.
    SELECT COUNT(DISTINCT t) INTO item_count FROM UNNEST(source || dest) AS t;

    -- Intersection size: distinct elements of source that also occur in dest.
    SELECT COUNT(DISTINCT t) INTO match_count
    FROM UNNEST(source) AS t
    WHERE t = ANY(dest);

    -- Avoid division by zero when there is nothing to compare.
    IF item_count IS NULL OR item_count = 0 THEN
        RETURN 0.0::NUMERIC;
    END IF;

    RETURN match_count::NUMERIC / item_count::NUMERIC;
END;
$$;
COMMENT ON FUNCTION jaccard IS '2つのTEXT[]を集合として受け取り(要素数は異なってもよい・重複は無視)、jaccard係数を算出する';
デモデータの作成
下記のテーブル(answers)は好きなプログラム言語の回答データが格納されているものとしています。(こんな感じ)
id | user_name | answer |
---|---|---|
1 | user_1 | ruby |
2 | user_1 | Python |
3 | user_2 | C# |
4 | user_3 | ruby |
-- One row per (user, favourite language) answer; a user may have several rows.
CREATE TABLE answers (
    id        SERIAL NOT NULL PRIMARY KEY,
    user_name TEXT   NOT NULL,
    answer    TEXT   NOT NULL
);
-- Answers are looked up per user (e.g. ARRAY(SELECT answer FROM answers WHERE user_name = ...)).
-- Without an index every such lookup is a full table scan; covering (user_name, answer)
-- also lets those lookups and per-user aggregation be served by an index-only scan.
CREATE INDEX answers_user_name_answer_idx ON answers (user_name, answer);
COMMENT ON TABLE answers IS '好きなプログラム言語を任意で答えてもらった結果が格納されているという体のテーブルです。';
COMMENT ON COLUMN answers.user_name IS '回答者名が入ります';
COMMENT ON COLUMN answers.answer IS 'C,C++,C#,Python,go,ruby,php,R,Rustのいずれかが入る。';
-- Demo data: each tuple is one (user_name, answer) row.
-- Rows are listed in the intended insertion order, so SERIAL ids come out the same
-- as inserting them one statement at a time.
INSERT INTO answers (user_name, answer)
VALUES
    ('user_1',  'C'),
    ('user_1',  'C++'),
    ('user_1',  'C#'),
    ('user_2',  'Python'),
    ('user_2',  'ruby'),
    ('user_2',  'C#'),
    ('user_2',  'R'),
    ('user_2',  'C'),
    ('user_3',  'Rust'),
    ('user_3',  'go'),
    ('user_4',  'php'),
    ('user_5',  'C'),
    ('user_5',  'php'),
    ('user_5',  'Python'),
    ('user_6',  'go'),
    ('user_6',  'php'),
    ('user_6',  'ruby'),
    ('user_6',  'Python'),
    ('user_7',  'R'),
    ('user_7',  'Python'),
    ('user_8',  'ruby'),
    ('user_9',  'ruby'),
    ('user_10', 'go'),
    ('user_11', 'C#'),
    ('user_11', 'R'),
    ('user_12', 'Python'),
    ('user_12', 'go'),
    ('user_12', 'C#'),
    ('user_12', 'R'),
    ('user_13', 'php'),
    ('user_13', 'ruby'),
    ('user_14', 'go'),
    ('user_14', 'Python'),
    ('user_14', 'Rust'),
    ('user_15', 'C++'),
    ('user_15', 'C'),
    ('user_16', 'C#'),
    ('user_17', 'Rust'),
    ('user_17', 'ruby'),
    ('user_17', 'Python'),
    ('user_17', 'C#'),
    ('user_18', 'Python'),
    ('user_18', 'Rust'),
    ('user_19', 'go'),
    ('user_19', 'php'),
    ('user_19', 'R'),
    ('user_20', 'C#');
動作検証用のSQL
-- Jaccard similarity between one fixed source user and every other user.
-- The source user is set in exactly one place (source_param) instead of three literals.
WITH source_param AS (
    SELECT 'user_1'::TEXT AS source_user
),
-- Each user's answers as one TEXT[], built once per user rather than once per output row.
user_answers AS (
    SELECT
        user_name,
        ARRAY_AGG(answer) AS answer_list
    FROM answers
    GROUP BY user_name
)
SELECT
    p.source_user,
    d.user_name AS dest_user,
    JACCARD(
        -- An empty array (score 0) if the source user has no answers at all.
        COALESCE(s.answer_list, ARRAY[]::TEXT[]),
        d.answer_list
    ) AS score
FROM source_param AS p
-- Compare against every user except the source user.
INNER JOIN user_answers AS d
    ON d.user_name <> p.source_user
LEFT JOIN user_answers AS s
    ON s.user_name = p.source_user
ORDER BY
    d.user_name
;
おまけ(ユーザ毎に類似度の高い上位3位までのユーザを抽出)
-- For each user, the most similar other users (top 3 by Jaccard score).
-- RANK() gives tied scores the same rank, so a user can get more than 3 rows on ties.
WITH user_answers AS (
    -- Each user's answers as one TEXT[], built once per user instead of
    -- two correlated ARRAY(...) subqueries per user pair.
    SELECT
        user_name,
        ARRAY_AGG(answer) AS answer_list
    FROM answers
    GROUP BY user_name
),
pair_scores AS (
    -- Every ordered pair of distinct users with their Jaccard coefficient.
    -- Joining per-user rows (not per-answer rows) avoids a rows² intermediate
    -- that then has to be de-duplicated with GROUP BY.
    SELECT
        s.user_name AS source_user,
        d.user_name AS dest_user,
        JACCARD(s.answer_list, d.answer_list) AS score
    FROM user_answers AS s
    INNER JOIN user_answers AS d
        ON d.user_name <> s.user_name
),
ranked AS (
    -- Rank candidate users within each source user, highest score first.
    SELECT
        source_user,
        dest_user,
        score,
        RANK() OVER (PARTITION BY source_user ORDER BY score DESC) AS rank
    FROM pair_scores
)
SELECT
    source_user,
    rank,
    dest_user,
    score
FROM ranked
WHERE
    rank <= 3
ORDER BY
    source_user,
    rank,
    dest_user  -- deterministic order among users with tied scores
;
やってやれないこともない。
ピアソン積率相関係数も関数化できたので、
次回はそちらを紹介する予定。