Machine Learning Advent Calendar向けの記事です。
はじめに
最近、個人的にGo言語を触ることがちょいちょいあります。
型があって割と高速に動いてくれて、ポータビリティの高いとこが気に入ってるのですが、何十万・何百万人に推薦データを提供することが使命な会社にいるなら、これを推薦に使おうかなと画策しています。
というわけで今回は、周囲が皆サーベイ論文紹介とかなので若干毛色違いますが、Goでb-bit Minwise Hashingを実装したことについて書きます。
b-bit Minwise Hashingとは
Minhash、及びb-bit Minhashについては、既にPFIの岡野原さんによる素晴らしいPostがあるため、詳しく知りたい方はそちら+その中で紹介されている論文を読んでいただくのが一番かと思います。
参考: MinHashによる高速な類似検索
http://research.preferred.jp/2011/02/minhash/
この投稿では簡単にb-bit Minhashまでさらっと解説するにとどめます。
Jaccard係数
文書間などの類似度を計算する上で指標となる係数はいくつかありますが、Jaccard係数はその一つです。
「2つのデータ間の共起した要素/2つのデータが含む全要素」として表される係数となります。
sim(a, b) = \frac{|a \cap b|}{|a \cup b|}
Minhash
Jaccard係数の計算においては、データ小さければそれほど問題にならないのですが、計算対象が数千万件を超えてくるとデータ自体の保存やその効率に関して問題が出てきます。
この問題に対して以下の特殊な性質を使って計算を効率化しようというのがMinhashになります。
あるハッシュ関数を2つのデータの全要素に対して適用した時、各データの最小のハッシュ値が一致する確率はJaccard係数に等しい
Sim(a,b) = P(min\{h(d_a)| d_a \in a\} = min\{h(d_b)| d_b \in b\})
なぜこういった関係が成り立つのかは先のブログを参照して下さい。
この性質を用いると、予め幾つかのハッシュ関数を用意して、各データにおける最小値を記録しておくことで、その最小値のセットをインデックス的に用いてJaccard係数の算出を高速、かつ省スペースにすることが出来ます。
ただし、それでもハッシュ値が衝突しにくい関数を用いると各々のハッシュ値のサイズは例えば64bit程度だったとしてもデータが数千万件などになるとサイズ的には無視できない大きさになります。
b-bit Minhash
そこで、Minhashに用いるハッシュ値をサイズの小さなものでも利用可能にしたのがb-bit Minhashです。
b-bit Minhashでは、値域の小さなハッシュ関数を用い、そのハッシュ値の衝突確率を使って、Minhashの値を補正します。
実際の補正式は非常に簡単で、
Similarity = P(ハッシュ値の最小値が一致する確率) - P(ハッシュ値の衝突確率)
なぜこういった関係が成り立つのかは(ry
小さなハッシュが使えるので、例えば1bitのハッシュ値を使って計算を効率化する、なんてことも可能です。
1bitであれば、例えば64個のハッシュ関数を用いた場合、計算結果を64bit値で表現して、各々のXORからのPopcountで最小値が一致した数を効率よく算出できます。
実装してみた
簡単に説明したところで、実際の実装の話をします。
今回の実装ですが、ハッシュ値は1bit、32bitのMurmurhash3の下位1bitを用いています。
Murmurhash3自体も今回は趣味で実装しました。
まずはMurmurhash3です。
package mmh
import (
"bytes"
"encoding/binary"
)
var mask32 = uint32(0xffffffff)
func rotl(x, r uint32) uint32 {
return ((x << r) | (x >> (32 - r))) & mask32
}
func mmix(h uint32) uint32 {
h &= mask32
h ^= h >> 16
h = (h * 0x85ebca6b) & mask32
h ^= h >> 13
h = (h * 0xc2b2ae35) & mask32
return h ^ (h >> 16)
}
func Murmurhash3_32(key string, seed uint32) uint32 {
var h1 uint32 = seed
var k uint32
var c1, c2 uint32 = 0xcc9e2d51, 0x1b873593
buffer := bytes.NewBufferString(key)
keyBytes := []byte(key)
length := buffer.Len()
if length == 0 {
return 0
}
nblocks := length / 4
for i := 0; i < nblocks; i++ {
binary.Read(buffer, binary.LittleEndian, &k)
k *= c1
k = rotl(k, 15)
k *= c2
h1 ^= k
h1 = rotl(h1, 13)
h1 = h1*5 + 0xe6546b64
}
var k1 uint32 = 0
tail := nblocks * 4
switch length & 3 {
case 3:
k1 ^= uint32(keyBytes[tail+2]) << 16
fallthrough
case 2:
k1 ^= uint32(keyBytes[tail+1]) << 8
fallthrough
case 1:
k1 ^= uint32(keyBytes[tail])
k *= c1
k = rotl(k, 15)
k *= c2
h1 ^= k
}
// finalize
h1 ^= uint32(length)
return mmix(h1)
}
そしてこちらが実際のb-bit Minhashの実装。
package main
import (
"./mmh"
"fmt"
"math"
"math/big"
"math/rand"
)
var bitMask = uint32(0x1)
func minKey(l map[string]uint32) (string, uint32) {
var result string
m := uint32(math.MaxUint32)
for k := range l {
if m > l[k] {
m = l[k]
result = k
}
}
return result, m
}
func minHash(data []string, seed uint32) uint32 {
vector := make(map[string]uint32)
for k := range data {
vector[data[k]] = mmh.Murmurhash3_32(data[k], seed)
}
_, value := minKey(vector)
return value
}
func signature(data []string) uint32 {
rand.Seed(1)
sig := uint32(0)
for i := 0; i < 128; i++ {
sig += (minHash(data, rand.Uint32()) & bitMask) << uint32(i)
}
return sig
}
func signatureBig(data []string) *big.Int {
rand.Seed(1)
sigBig := big.NewInt(0)
for i := 0; i < 128; i++ {
sigBig.SetBit(sigBig, i, uint(minHash(data, rand.Uint32())&bitMask))
}
return sigBig
}
func popCount(bits uint32) uint32 {
bits = (bits & 0x55555555) + (bits >> 1 & 0x55555555)
bits = (bits & 0x33333333) + (bits >> 2 & 0x33333333)
bits = (bits & 0x0f0f0f0f) + (bits >> 4 & 0x0f0f0f0f)
bits = (bits & 0x00ff00ff) + (bits >> 8 & 0x00ff00ff)
return (bits & 0x0000ffff) + (bits >> 16 & 0x0000ffff)
}
func popCountBig(bits *big.Int) int {
result := 0
for _, v := range bits.Bytes() {
result += int(popCount(uint32(v)))
}
return result
}
func calcJaccard(v1, v2 []string) float32 {
commonBig := big.NewInt(0)
commonBig.Xor(signatureBig(v1), signatureBig(v2))
return 2.0 * (float32((128.0-popCountBig(commonBig)))/128.0 - 0.5)
}
func calc() {
fmt.Println(calcJaccard([]string{
"21", "歳", "ビール", "飲む", "アサヒ", "リクルート", "連携",
}, []string{
"21", "歳", "ビール", "無料", "アサヒ", "若者", "向け", "企画",
}))
fmt.Println(calcJaccard([]string{
"21", "歳", "ビール", "飲む", "アサヒ", "リクルート", "連携",
}, []string{
"サンクト", "ガーレン", "バレンタイン", "向け", "チョコレート", "風味", "ビール", "4", "種", "発売",
}))
}
func main() {
calc();
}
上のコードでは何となくビールで検索して引っかかったニュースのうち、関係有るもの2タイトルと関係ないもの1タイトルを使ってjaccard係数を計算してみたもの。
出力結果:
$ go run minhash.go
0.375
0.28125
一応は近いタイトルほど係数が大きい=似ていると判定はできているっぽいです。
ただ、似たデータでも例えば全角と半角が混ざってれば遠いデータと判定されますし、事前のデータを如何に綺麗に用意できるかの方が重要です。
あと、上記のコードだと、計算したハッシュ値が保存されていないので、あまりb-bit Minhashの恩恵が受けられるわけでも無いですね。
今後はkey-valueストアチックなAPIを備えた簡易的な近似データ算出DB的なのをつくろうかと思っています。
まとめ
b-bit Minhashすごい。