4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Scalable Supervised Discrete HashingをNumo::NArrayで実装する

Last updated at Posted at 2018-12-08

はじめに

こんにちは、Misocaの洋食(@yoshoku)です。この記事は、Misoca+弥生+ALTOA Advent Calendar 2018の9日目の記事です。昨日は、merotanさんの「CSS Houdini をつかってキミだけの最強レイアウトを組み立てよう!」でした。

仕事では、Misocaの開発の他にデータ分析も担当しています。データ分析なエンジニアらしく、機械学習アルゴリズムの実装に挑戦したいと思います。また、MisocaはRubyな会社なので、Rubyの線形代数ライブラリであるNumo::NArrayを使って実装します。Numo::NArrayは、PythonでいうところのNumpyで、だいたい同じことができます

実装するのは、WWW'18で発表されたScalable Supervised Discrete Hashing for Large-Scale Search です。教師ありハッシングにおいて、ロス関数の設定の仕方や行列計算を工夫することで、従来手法よりも高速なハッシュ関数の学習を実現しています。

バイナリハッシングにより近似最近傍探索

近似最近傍探索(Approximate Nearest Neighbor Seach, ANN)は、クエリに最も近いサンプルを探索する手法の一つです。探索するときに、厳密に近傍であることを求めないことで、高速な探索を実現します。ANNの手法には複数種類ありますが、そのなかでもバイナリハッシングによる手法は、データを短いバイナリコードに変換し、ハミング距離を計算することで、クエリの近傍にあるデータを高速に見つけだします。この手法で重要となるのは、データをバイナリコードに変換するハッシング関数です。最も有名なLocality Sensitive Hashingをはじめ、様々な手法が提案されています。

Scalable Supervised Discrete Hashing

論文では、バイナリコードに変換する手法として、Scalable Supervised Discrete Hashing(SSDH)が提案されています。ラベルが与えられているデータセットから、バイナリコードに変換するハッシング関数を学習します。ロス関数の設定を工夫することで、最適化のデータ数の影響を軽減します。

問題設定

ラベルが付与された$n$個の$d$次元のデータ$\mathbf{x}_{i}\in\mathbb{R}^{d}(i=1,2,\ldots,n)$が与えられているとします。ラベルは$c$種類あり、データにどのラベルが付与されているかはベクトル$\mathbf{y}\in\mathbb{B}^{c}$で表現されます。これらデータで構成される行列を、$X\in\mathbb{R}^{n\times d}$および$Y\in\mathbb{B}^{n\times c}$とします。

さらに、データ間の意味的な関連を表す行列$S\in\{-1,1\}^{n\times n}$を導入します。データ$i$とデータ$j$が同じラベルであれば$s_{ij}=1$、異なるラベルであれば$s_{ij}=-1$となります。

SSDHの目的は、各データに対する$r$ビットのバイナリコード$\mathbf{b}_{i}\in\{-1,1\}^{r}$を得ることです。また、バイナリコードによる行列を$B\in\{-1,1\}^{n\times r}$とします。データをバイナリコードに変換するハッシュ関数を$F(X)=\text{sgn}(XW)=B$とします。行列$W\in\mathbb{R}^{d\times r}$は射影行列であり、関数$\text{sgn}(\cdot)$は、要素毎の符号関数であり、$x\geq 0$の場合に$\text{sgn}(x)=1$、それ以外では$-1$となります。

ここで、各データは中心化されており、$\sum_{i=1}^{n}\mathbf{x}_{i}=0$とします。

ロス関数とアルゴリズム

一般的な教師ありハッシングでは、以下のようなロス関数を設定します。ハッシュコードの内積により意味的な関連を近似します。

\min_B ||rS-BB^{\top}||^{2}_{F}

このロス関数による最適化は、データ数$n$が大きくなるほどに、計算時間を必要とすることが知られています。そこでSSDHでは、ラベルベクトルをバイナリコードに変換する射影行列$G\in\mathbb{R}^{c\times r}$を考え、以下のようなロス関数を設定します。

\min_B ||rS-B(YG)^{\top}||^{2}_{F}+\mu||B-YG||^{2}_{F}

以降で説明するように、行列$B$を行列$YG$で置き換えることで、データ数$n$に従って大きくなる行列$S$に対する直接的な最適化を避けることができます。

SSDHのロス関数による最適化では、バイナリコード行列$B$と射影行列$G$を求める必要があります。これは、$B$を固定して$G$を求め、$G$を固定して$B$を求めるということを繰り返すことで実現されます。まず、$G$を求めるため、ロス関数を$G$で偏微分して$0$と置きます。すると以下を得ます。

G=(Y^\top Y)^{-1}(r(SY)^\top B+\mu Y^\top B)(B^\top B+\mu I)^{-1}

ここで、行列$SY$は、データが与えられた時点で計算できます。よって、最適化の繰り返しのなかで、行列$S$を扱うことを避けることができます。また、一般的にラベル数はデータ数よりもずっと小さい($n>>c$)ことから、大きさ$n\times c$の行列$SY$は、行列$S$よりも小さな行列となります。

もう一方の行列$B$は、得られた射影行列$G$と符号関数から求められます。

B=\text{sgn}(YG)

以上の計算を繰り返すことで、最適な行列$G$と$B$を得ます。

ハッシュ関数

欲しいのはデータ$X$をバイナリコード$B$に変換するハッシュ関数$F(X)=\text{sgn}(XW)$です。射影行列$W\in\mathbb{R}^{d\times r}$は、上述の最適化で得られたバイナリコード$B$に対するリッジ回帰により得ます。

\min_{W} ||B-XW||^2_F+\lambda ||W||^2_F

ここで、$\lambda$は正則化パラメータです。行列$W$で偏微分して$0$と置くことで以下を得ます。

W=(X^\top X+\lambda I)^{-1}X^\top B

以上により、ハッシュ関数が求められました。検索クエリのデータ$X_q$のバイナリコード$B_q$も、このハッシュ関数から得ることができます。

B_q = \text{sgn}(X_q W)

実装

Rubyの線形代数ライブラリのNumo::NArrayとNumo::Linalgを用いてSSDHを実装していきます。また、ラベルベクトルを得るために、RubyのScikit-learnライクなライブラリであるSVMKitのOneHotEncoderを用います。全体的に勢いコードですみません。

まずは必要なライブラリをインストールします。

$ brew install openblas --with-openmp
$ gem install numo-narray numo-linalg svmkit

実験に使うデータセットをLIBSVM Dataからダウンロードします。USPSという手書きの数字画像のデータセットを使います。

$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2
$ bunzip2 usps.bz2
$ bunzip2 usps.t.bz2

Numo::NArrayとNumo::Linalgを使って実装していきましょう。といっても、論文の数式通りに記述していくだけです。まずは学習フェーズ(ハッシュ関数の学習と検索対象のバイナリコード化)から。

require 'numo/narray'
require 'numo/linalg/autoloader'
require 'svmkit'

# 符号関数
def sgn(x)
  2 * (Numo::DFloat.cast(x>=0.0)) - 1
end

# 訓練データを読み込む.
mat_x, labels = SVMKit::Dataset.load_libsvm_file('usps')
mat_x = Numo::DFloat.cast(mat_x) # 特徴量が整数の場合Numo::Int32となるためNumo::DFloatにする.
n_samples, n_features = mat_x.shape # サンプル数と特徴数を得る.

# 中心化する.
mean_vec = mat_x.mean(0)
mat_x = mat_x - mean_vec

# ラベルからラベル行列を作る.
# ※ラベルが0から始まらないためLabelEncoderにより0〜9の整数に変えている.
mat_y = SVMKit::Preprocessing::OneHotEncoder.new.fit_transform(
          SVMKit::Preprocessing::LabelEncoder.new.fit_transform(labels))
n_classes = mat_y.shape[1] # クラス数を得る.

# パラメータを定義する.
n_bits = 32 # ビット数
n_iters = 10 # 繰り返し数
mu_param = 1.0 # 正則化パラメータ
lambda_param = 1.0 # 正則化パラメータ

# 先に計算できる行列を計算する.
mat_s = 2 * mat_y.dot(mat_y.transpose) - 1
mat_inv_yty = Numo::Linalg.inv(mat_y.transpose.dot(mat_y))
mat_sy = mat_s.dot(mat_y).transpose

# 行列GとBを初期化する.
mat_g = Numo::DFloat.new(n_classes, n_bits).rand_norm
mat_b = sgn(mat_y.dot(mat_g))

# 最適化を実行し行列GとBを計算する.
n_iters.times do
  mat_inv_btb = Numo::Linalg.inv(
    mat_b.transpose.dot(mat_b) + mu_param * Numo::DFloat.eye(n_bits))
  mat_tmp = n_bits * mat_sy.dot(mat_b) + mu_param * (mat_y.transpose.dot(mat_b))
  mat_g = mat_inv_yty.dot(mat_tmp).dot(mat_inv_btb)
  mat_b = sgn(mat_y.dot(mat_g))
end

# ハッシュ関数の射影行列Wを求める.
mat_tmp = Numo::Linalg.inv(
  mat_x.transpose.dot(mat_x) + lambda_param * Numo::DFloat.eye(n_features))
mat_w = mat_tmp.dot(mat_x.transpose).dot(mat_b)

# バイナリコードに変換する.
mat_b = mat_x.dot(mat_w) >= 0 # 真偽値型のNumo::Bitとなる.

続いてテストフェーズ(検索クエリのバイナリコード化と検索)です。

# テストデータを読み込む.
# ※本来テストデータにラベルは必要ない.検索結果を評価するために読み込む.
mat_xq, labels_q = SVMKit::Dataset.load_libsvm_file('usps.t')

# バイナリコードに変換する.
mat_bq = (mat_xq - mean_vec).dot(mat_w) >= 0

# ハミング距離を計算する.
# ※テストデータから適当に1つ検索クエリに選んだ.
query_id = 1
query = mat_bq[query_id, true]
hamming_dists = (query ^ mat_b).count(1)

# ソートして上位の5件を取得する.
top_ids = hamming_dists.sort_index[0...5]
puts "query label: #{labels_q[query_id]}"
top_ids.each do |t|
  puts "target label: #{labels[t]}"
end

これを実行すると次のようになります。

query label: 7
target label: 7
target label: 7
target label: 7
target label: 7
target label: 7

うまく同じラベルのものが検索できていますね。

おわりに

論文では、MNISTやCIFAR-10といった有名なデータセットを使って、Mean Average Precisionや学習に要した時間を比較しています。いずれのデータセットでも、SSDHは代表的な従来手法よりも優れた検索精度を得ており、Kernel-based Supervised Hashingなどの教師ありハッシング手法よりも高速に学習できていました。また、論文では、マルチラベルの場合の学習や、カーネル法への拡張についても解説されています。

本稿のように、Numo::NArrayやNumo::Linalgを使用すれば、RubyでもPythonの様に、簡単に機械学習アルゴリズムが実装できます。

明日は、@mugi_unoさんによる 「VSCodeのマルチカーソルの話をかくよ」です!!

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?