Scalable Supervised Discrete HashingをNumo::NArrayで実装する


はじめに

こんにちは、Misocaの洋食(@yoshoku)です。この記事は、Misoca+弥生+ALTOA Advent Calendar 2018の9日目の記事です。昨日は、merotanさんの「CSS Houdini をつかってキミだけの最強レイアウトを組み立てよう!」でした。

仕事では、Misocaの開発の他にデータ分析も担当しています。データ分析なエンジニアらしく、機械学習アルゴリズムの実装に挑戦したいと思います。また、MisocaはRubyな会社なので、Rubyの線形代数ライブラリであるNumo::NArrayを使って実装します。Numo::NArrayは、PythonでいうところのNumpyで、だいたい同じことができます

実装するのは、WWW'18で発表されたScalable Supervised Discrete Hashing for Large-Scale Search です。教師ありハッシングにおいて、ロス関数の設定の仕方や行列計算を工夫することで、従来手法よりも高速なハッシュ関数の学習を実現しています。


バイナリハッシングにより近似最近傍探索

近似最近傍探索(Approximate Nearest Neighbor Seach, ANN)は、クエリに最も近いサンプルを探索する手法の一つです。探索するときに、厳密に近傍であることを求めないことで、高速な探索を実現します。ANNの手法には複数種類ありますが、そのなかでもバイナリハッシングによる手法は、データを短いバイナリコードに変換し、ハミング距離を計算することで、クエリの近傍にあるデータを高速に見つけだします。この手法で重要となるのは、データをバイナリコードに変換するハッシング関数です。最も有名なLocality Sensitive Hashingをはじめ、様々な手法が提案されています。


Scalable Supervised Discrete Hashing

論文では、バイナリコードに変換する手法として、Scalable Supervised Discrete Hashing(SSDH)が提案されています。ラベルが与えられているデータセットから、バイナリコードに変換するハッシング関数を学習します。ロス関数の設定を工夫することで、最適化のデータ数の影響を軽減します。


問題設定

ラベルが付与された$n$個の$d$次元のデータ$\mathbf{x}_{i}\in\mathbb{R}^{d}(i=1,2,\ldots,n)$が与えられているとします。ラベルは$c$種類あり、データにどのラベルが付与されているかはベクトル$\mathbf{y}\in\mathbb{B}^{c}$で表現されます。これらデータで構成される行列を、$X\in\mathbb{R}^{n\times d}$および$Y\in\mathbb{B}^{n\times c}$とします。

さらに、データ間の意味的な関連を表す行列$S\in\{-1,1\}^{n\times n}$を導入します。データ$i$とデータ$j$が同じラベルであれば$s_{ij}=1$、異なるラベルであれば$s_{ij}=-1$となります。

SSDHの目的は、各データに対する$r$ビットのバイナリコード$\mathbf{b}_{i}\in\{-1,1\}^{r}$を得ることです。また、バイナリコードによる行列を$B\in\{-1,1\}^{n\times r}$とします。データをバイナリコードに変換するハッシュ関数を$F(X)=\text{sgn}(XW)=B$とします。行列$W\in\mathbb{R}^{d\times r}$は射影行列であり、関数$\text{sgn}(\cdot)$は、要素毎の符号関数であり、$x\geq 0$の場合に$\text{sgn}(x)=1$、それ以外では$-1$となります。

ここで、各データは中心化されており、$\sum_{i=1}^{n}\mathbf{x}_{i}=0$とします。


ロス関数とアルゴリズム

一般的な教師ありハッシングでは、以下のようなロス関数を設定します。ハッシュコードの内積により意味的な関連を近似します。

\min_B ||rS-BB^{\top}||^{2}_{F}

このロス関数による最適化は、データ数$n$が大きくなるほどに、計算時間を必要とすることが知られています。そこでSSDHでは、ラベルベクトルをバイナリコードに変換する射影行列$G\in\mathbb{R}^{c\times r}$を考え、以下のようなロス関数を設定します。

\min_B ||rS-B(YG)^{\top}||^{2}_{F}+\mu||B-YG||^{2}_{F}

以降で説明するように、行列$B$を行列$YG$で置き換えることで、データ数$n$に従って大きくなる行列$S$に対する直接的な最適化を避けることができます。

SSDHのロス関数による最適化では、バイナリコード行列$B$と射影行列$G$を求める必要があります。これは、$B$を固定して$G$を求め、$G$を固定して$B$を求めるということを繰り返すことで実現されます。まず、$G$を求めるため、ロス関数を$G$で偏微分して$0$と置きます。すると以下を得ます。

G=(Y^\top Y)^{-1}(r(SY)^\top B+\mu Y^\top B)(B^\top B+\mu I)^{-1}

ここで、行列$SY$は、データが与えられた時点で計算できます。よって、最適化の繰り返しのなかで、行列$S$を扱うことを避けることができます。また、一般的にラベル数はデータ数よりもずっと小さい($n>>c$)ことから、大きさ$n\times c$の行列$SY$は、行列$S$よりも小さな行列となります。

もう一方の行列$B$は、得られた射影行列$G$と符号関数から求められます。

B=\text{sgn}(YG)

以上の計算を繰り返すことで、最適な行列$G$と$B$を得ます。


ハッシュ関数

欲しいのはデータ$X$をバイナリコード$B$に変換するハッシュ関数$F(X)=\text{sgn}(XW)$です。射影行列$W\in\mathbb{R}^{d\times r}$は、上述の最適化で得られたバイナリコード$B$に対するリッジ回帰により得ます。

\min_{W} ||B-XW||^2_F+\lambda ||W||^2_F

ここで、$\lambda$は正則化パラメータです。行列$W$で偏微分して$0$と置くことで以下を得ます。

W=(X^\top X+\lambda I)^{-1}X^\top B

以上により、ハッシュ関数が求められました。検索クエリのデータ$X_q$のバイナリコード$B_q$も、このハッシュ関数から得ることができます。

B_q = \text{sgn}(X_q W)


実装

Rubyの線形代数ライブラリのNumo::NArrayとNumo::Linalgを用いてSSDHを実装していきます。また、ラベルベクトルを得るために、RubyのScikit-learnライクなライブラリであるSVMKitのOneHotEncoderを用います。全体的に勢いコードですみません。

まずは必要なライブラリをインストールします。

$ brew install openblas --with-openmp

$ gem install numo-narray numo-linalg svmkit

実験に使うデータセットをLIBSVM Dataからダウンロードします。USPSという手書きの数字画像のデータセットを使います。

$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2

$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2
$ bunzip2 usps.bz2
$ bunzip2 usps.t.bz2

Numo::NArrayとNumo::Linalgを使って実装していきましょう。といっても、論文の数式通りに記述していくだけです。まずは学習フェーズ(ハッシュ関数の学習と検索対象のバイナリコード化)から。

require 'numo/narray'

require 'numo/linalg/autoloader'
require 'svmkit'

# 符号関数
def sgn(x)
2 * (Numo::DFloat.cast(x>=0.0)) - 1
end

# 訓練データを読み込む.
mat_x, labels = SVMKit::Dataset.load_libsvm_file('usps')
mat_x = Numo::DFloat.cast(mat_x) # 特徴量が整数の場合Numo::Int32となるためNumo::DFloatにする.
n_samples, n_features = mat_x.shape # サンプル数と特徴数を得る.

# 中心化する.
mean_vec = mat_x.mean(0)
mat_x = mat_x - mean_vec

# ラベルからラベル行列を作る.
# ※ラベルが0から始まらないためLabelEncoderにより0〜9の整数に変えている.
mat_y = SVMKit::Preprocessing::OneHotEncoder.new.fit_transform(
SVMKit::Preprocessing::LabelEncoder.new.fit_transform(labels))
n_classes = mat_y.shape[1] # クラス数を得る.

# パラメータを定義する.
n_bits = 32 # ビット数
n_iters = 10 # 繰り返し数
mu_param = 1.0 # 正則化パラメータ
lambda_param = 1.0 # 正則化パラメータ

# 先に計算できる行列を計算する.
mat_s = 2 * mat_y.dot(mat_y.transpose) - 1
mat_inv_yty = Numo::Linalg.inv(mat_y.transpose.dot(mat_y))
mat_sy = mat_s.dot(mat_y).transpose

# 行列GとBを初期化する.
mat_g = Numo::DFloat.new(n_classes, n_bits).rand_norm
mat_b = sgn(mat_y.dot(mat_g))

# 最適化を実行し行列GとBを計算する.
n_iters.times do
mat_inv_btb = Numo::Linalg.inv(
mat_b.transpose.dot(mat_b) + mu_param * Numo::DFloat.eye(n_bits))
mat_tmp = n_bits * mat_sy.dot(mat_b) + mu_param * (mat_y.transpose.dot(mat_b))
mat_g = mat_inv_yty.dot(mat_tmp).dot(mat_inv_btb)
mat_b = sgn(mat_y.dot(mat_g))
end

# ハッシュ関数の射影行列Wを求める.
mat_tmp = Numo::Linalg.inv(
mat_x.transpose.dot(mat_x) + lambda_param * Numo::DFloat.eye(n_features))
mat_w = mat_tmp.dot(mat_x.transpose).dot(mat_b)

# バイナリコードに変換する.
mat_b = mat_x.dot(mat_w) >= 0 # 真偽値型のNumo::Bitとなる.

続いてテストフェーズ(検索クエリのバイナリコード化と検索)です。

# テストデータを読み込む.

# ※本来テストデータにラベルは必要ない.検索結果を評価するために読み込む.
mat_xq, labels_q = SVMKit::Dataset.load_libsvm_file('usps.t')

# バイナリコードに変換する.
mat_bq = (mat_xq - mean_vec).dot(mat_w) >= 0

# ハミング距離を計算する.
# ※テストデータから適当に1つ検索クエリに選んだ.
query_id = 1
query = mat_bq[query_id, true]
hamming_dists = (query ^ mat_b).count(1)

# ソートして上位の5件を取得する.
top_ids = hamming_dists.sort_index[0...5]
puts "query label: #{labels_q[query_id]}"
top_ids.each do |t|
puts "target label: #{labels[t]}"
end

これを実行すると次のようになります。

query label: 7

target label: 7
target label: 7
target label: 7
target label: 7
target label: 7

うまく同じラベルのものが検索できていますね。


おわりに

論文では、MNISTやCIFAR-10といった有名なデータセットを使って、Mean Average Precisionや学習に要した時間を比較しています。いずれのデータセットでも、SSDHは代表的な従来手法よりも優れた検索精度を得ており、Kernel-based Supervised Hashingなどの教師ありハッシング手法よりも高速に学習できていました。また、論文では、マルチラベルの場合の学習や、カーネル法への拡張についても解説されています。

本稿のように、Numo::NArrayやNumo::Linalgを使用すれば、RubyでもPythonの様に、簡単に機械学習アルゴリズムが実装できます。

明日は、@mugi_unoさんによる 「VSCodeのマルチカーソルの話をかくよ」です!!