論文のまとめをつらつらと書いていきたいと思います。
解釈がおかしい部分があったらコメントください。
論文:IRGAN: A Minimax Game for Unifying Generative and Discriminative Information Retrieval Models
1. 概要
情報検索(Web search, Item recommendation, Question answering)における Generative model と Discriminative model を最適化するために, GAN で用いられている Minmax ゲーム理論を応用した論文. Generative model は文書のサンプリングの分布が, データセットの真の分布に近づく(関連度のある文書をサンプリングできるようにする)ように学習し, 一方, Discriminative model はサンプリングされた文書がデータセットのものか Generative model によってサンプリングされたものか識別する. 実験の結果, Web search, Item recommendation, Question answering のいずれの場合もベースライン(RankNet や LambdaRank, LambdaMART)を上回る結果となった.
2. 背景
情報検索では生成モデル (Generative model) や識別モデル (Discriminative model) の考え方が主流だが, それぞれメリットデメリットが存在する
- Generative model : クエリq が関連するドキュメントdを選択(生成)する確率をモデル化する
- どんな information needs を基に文書が選択されるかとか, 言語モデルを用いて、クエリがその文書に生成される確率を推定する検索モデル(クエリ尤度モデル)など
-> 被リンク数やクリック数などの多くの関連性に関する情報を適切に扱うことが難しい
- Discriminative model : クエリqとドキュメントdが与えられた時に関連するかどうかを識別するモデル
- ランキング学習(Learning to rank)など
-> Unlabelledなデータからは有益な情報が得ることが難しい
したがって, IRGANでは Minmaxゲーム理論によって, Generative model と Discriminative model を同時に学習させることで, これらの問題を同時に改善する
3. IRGAN FORMULATION
3.1 Minmax Retrieval Framework
Generative model : $p_{\theta}(d|q,r)$ (クエリqが与えられた時に, そのクエリの文書集合の中から関連のある文書をサンプリングする確率分布)を $p_{true}(d|q,r)$ (その文書集合の真の関連度分布)に出来るだけ近づけることが目的
Discriminative model : クエリ $q$ における, 関連度のある文書と関連度のない文書の識別を正確に行うことが目的
3.1.1 Overall Objective
Generative model は 関連のある文書をサンプリングすることで Discriminative model が識別できないようにし, 一方, Discriminative model は 文書が Generative model によってサンプリングされたものか, そうでないかを識別するということを繰り返す Minmax ゲームを行う. これを数式で表すと目的関数は以下のようになる.
J^{G^{*},D^{*}}=\underset{\theta}{min}\,\underset{\varphi}{max}\,\mathop{{\sum}^{N}_{n=1}}(\mathbb{E}_{d\sim p_{true}(d|q_{n},r)}[logD(d|q_{n})]+\mathbb{E}_{d\sim p_{\theta}(d|q_{n},r)}[log(1-D(d|q_{n}))])\tag{1}
また, Discriminative model は文書 $d$ がクエリ $q$ と関連がある確率を推定するが, その確率は以下の数式のように Sigmoid 関数を用いて推定される.
D(d|q)=σ(f_{φ}(d,q))=\frac{exp(f_{φ}(d,q))}{1+exp(f_{φ}(d,q))} \tag{2}
それぞれの情報検索タスクによって $f_{\varphi}(d,q)$ が変わってくる(つまりこの部分をタスク毎に変えるだけでそれぞれのタスクに応用可能). 何れにせよ Generative model, Discriminative model は (1) の同じ目的関数をそれぞれ最大化(Discriminative model), 最小化(Generative model)するように学習する.
3.1.2 Optimize Discriminative Retrieval
Discriminative model は以下の式の最大化を行う.
φ^{*}=arg\underset{φ}{max}\,\mathop{{\sum}^{N}_{n=1}}(\mathbb{E}_{d\sim p_{true}(d|q_{n},r)}[log(σ(f_{φ}(d,q_{n})))]+\mathbb{E}_{d\sim p_{θ^{*}}(d|q_{n},r)}[log(1-σ(f_{φ}(d,q_{n})))]) \tag{3}
$f_{\varphi}$ が $\varphi$ に関して微分可能の時, 上式は確率的勾配降下法で最適化できる.
3.1.3 Optimize Generative Retrieval
一方, Generative model は以下の数式の最小化を行う.
\theta^{*}=arg\underset{\theta}{min}\,\mathop{{\sum}^{N}_{n=1}}(\mathbb{E}_{d\sim p_{true}(d|q_{n},r)}[log(\sigma(f_{\varphi}(d,q_{n})))]+\mathbb{E}_{d\sim p_{\theta}(d|q_{n},r)}[log(1-\sigma(f_{\varphi}(d,q_{n})))])
上式において Generative model の更新を行うときは, 右項だけ考えればいい. (左項はデータセットの文書の場合の期待値を示し, 右項は Generative model によってサンプリングされた文書の場合の期待値). したがって上式は以下のように変形できる
\theta^{*}=arg\underset{\theta}{min}\,\mathop{{\sum}^{N}_{n=1}}(\mathbb{E}_{d\sim p_{\theta}(d|q_{n},r)}[log(1-\sigma(f_{\varphi}(d,q_{n})))])
(2)より,
=arg\underset{\theta}{min}\,\mathop{{\sum}^{N}_{n=1}}(\mathbb{E}_{d\sim p_{\theta}(d|q_{n},r)}[log(1-\frac{exp(f_{\varphi}(d,q))}{1+exp(f_{\varphi}(d,q))})])
=arg\underset{\theta}{min}\,\mathop{{\sum}^{N}_{n=1}}(\mathbb{E}_{d\sim p_{\theta}(d|q_{n},r)}[log(\frac{1}{1+exp(f_{\varphi}(d,q))})])
したがって, 以下のように書き換えることができる,
\theta^{*}=arg\underset{\theta}{max}\,\mathop{{\sum}^{N}_{n=1}}(\underset{denoted\,as\,J^{G}(q_{n})}{\underbrace{\mathbb{E}_{d\sim p_{\theta}(d|q_{n},r)}[log(1+exp(f_{\varphi}(d,q_{n})))]}}) \tag{4}
IRGAN の場合, Generative model は GAN の時と異なり, 文書を生成するわけではなく, $p_{\theta}(d|q,r)$ に基づいて文書をサンプリングするので, 直接勾配降下法で最適化できない. したがって, 強化学習を用いた方策勾配法で最適化を行う. 1クエリ毎の目的関数の勾配は以下のように導出される.
\nabla_{\theta}J^{G}(q_{n})=\nabla_{\theta}\mathbb{E}_{d\sim p_{\theta}(d|q_{n},r)}[log(1+exp(f_{\varphi}(d,q_{n})))]
離散型の確率変数の場合, $\mathbb{E}(X)={\displaystyle \sum_{i}x_{i}f(x_{i})}$ と表せるので,
={\sum}^{M}_{i=1}\nabla_{\theta}p_{\theta}(d_{i}|q_{n,r})log(1+exp(f_{\varphi}(d,q_{n})))
$\nabla_{\theta}p_{\theta}(d_{i}|q_{n,r})=p_{\theta}(d_{i}|q_{n,r})\cdot\nabla_{\theta}logp_{\theta}(d_{i}|q_{n},r)$ なので,
={\sum}^{M}_{i=1}p_{\theta}(d_{i}|q_{n,r})\nabla_{\theta}logp_{\theta}(d_{i}|q_{n},r)log(1+exp(f_{\varphi}(d,q_{n})))
=\mathbb{E}_{d\sim p_{\theta}(d|q_{n},r)}[\nabla_{\theta}logp_{\theta}(d_{i}|q_{n},r)log(1+exp(f_{\varphi}(d,q_{n})))]
$log(1+exp(f_{\varphi}(d,q_{n})))$ は強化学習における報酬, クエリは環境, 文書のサンプリングは行動と対応付けられる. K回サンプリングした報酬の予測値の平均で期待値を近似する(モンテカルロ法)と勾配は以下のように表すことができる.
\simeq\frac{1}{K}{\sum}^{K}_{k=1}\nabla_{\theta}logp_{\theta}(d_{i}|q_{n},r)log(1+exp(f_{\varphi}(d,q_{n})))\tag{5}
この勾配を基に最急降下法でパラメータを最適化する.本論文では学習を収束させるために報酬 $log(1+exp(f_{\varphi}(d,q_{n})))$ を以下の数式のように置き換えている.
log(1+exp(f_{\varphi}(d,q_{n})))-\mathbb{E}_{d\sim p_{\theta}(d|q_{n},r)}[log(1+exp(f_{\varphi}(d,q_{n})))]
また, IRGANでは, Discriminative model, Generative model は教師データを用いて事前学習させてから, (3)や(5)のように学習を行う. IRGANの全体のアルゴリズムをまとめると以下の表のようになる.
3.2 Extension to Pairwise Case
IRGANはpairwiseにも対応可能である. クエリ $q_{n}$ における文書のペアを $R_{n}={\left\langle d_{i},d_{j}\right\rangle |d_{i}\succ d_{j}}$ とする. $d_{i}\succ d_{j}$は文書 $i$ の方が文書 $j$ よりもクエリ $q_{n}$ に対して関連度が高いことを表す. この時, Generative model は $R_{n}$ のような文書のペアをサンプリングする. Discriminative model も Pointwise の時と同様で以下の数式のように Sigmoid 関数を用いてそのペアが Generative model によってサンプリングされたものかどうか識別する.
D(\left\langle d_{u},d_{v}\right\rangle |q)=\sigma(f_{\sigma}(d_{u},q)-f_{\sigma}(d_{v},q))
=\frac{exp(f_{\sigma}(d_{u},q)-f_{\sigma}(d_{v},q))}{1+exp(f_{\sigma}(d_{u},q)-f_{\sigma}(d_{v},q))}
$z=f_{\sigma}(d_{u},q)-f_{\sigma}(d_{v},q)$ とすると以下のように表すことができる.
=\frac{1}{1+exp(-z)}\tag{6}
したがって, これに $-log$ をかけると以下のようにRankNetで使われている損失関数となる.
-logD(\left\langle d_{u},d_{v}\right\rangle |q)=log(1+exp(-z))
そして, Cross entropy を用いると Pairwise における目的関数は以下のようになる.
J^{G^{*},D^{*}}=\underset{\theta}{min}\,\underset{\varphi}{max}\,\mathop{{\sum}_{n=1}^{N}}(\mathbb{E}_{o\sim p_{true}(o|q_{n})}[logD(o|q_{n})]+\mathbb{E}_{o'\sim p_{\theta}(o'|q_{n})}[log(1-D(o'|q_{n}))])\tag{7}
ここで, $o=\left\langle d_{i},d_{j}\right\rangle$ , $o'=\left\langle d_{i}',d_{j}'\right\rangle$ . 実際に Generative model が文書のペアをサンプリングする時, はじめに $R_{n}$ から文書のペア $\left\langle d_{i},d_{j}\right\rangle$ を取ってきて, 関連度が低い $d_{j}$ とペアとなるように Generative model は文書集合の中から文書 $d_{k}$ をサンプリングし, ペア $\left\langle d_{k},d_{j}\right\rangle$ を作る. $d_{k}$ が $d_{j}$ よりも関連度の高い文書となるように Generative model を学習させていく. Generative model の確率分布は以下のように Softmax 関数の形で表される.
p_{\theta}(d_{k}|q,r)=\frac{exp(g_{\theta}(q,d))}{\sum_{d}exp(g_{\theta}(q,d))}\tag{8}
この $g_{\theta}(q,d)$ も情報検索のタスクによって変わり, Pointwoseの時と同様で強化学習を用いて学習させる.
3.3 Discussion
もし $p_{true}(d|q,r)$(その文書集合の真の関連度分布)が分かっていたら IRGAN の minmax game はナッシュ均衡の状態となる. つまり, Generative model は真の関連度分布を基に文書をサンプリングできるようになり($p_{\theta}(d|q,r)=p_{true}(d|q,r)$), Discriminative model は文書が Generative model からきたものか, そうでないかを識別できなくなる. しかし, 実際には真の関連度分布はわからないので, その均衡にどうやって収束させるかが最近の研究における問題となっている. IRGANでは, タスクによって Generative model と Discriminative model の精度に差が生まれることが確認された. IRGANの学習の様子を図で表したのが以下の図である. Observed sample は既にサンプリングされた文書を表し. Unobserved sample はまだサンプリングされていない文書を表す. ここで, Generative model, Discriminative model は, 以下のような役割を持つ.
-
Generative model : Discriminative model を騙すため(学習を困難にするため)に, Generated sample, Unobserved sample(positive, negative)を Discriminative model の決定境界( Decision boundary )まで押し上げる・近づける
-
Discriminative model : Generative model によってサンプリングされた文書を決定境界から押し下げる
Observed positive sample と Unobserved positive sample の間には正の相関があるから, 学習が進むとGenerative model は 他のサンプルよりも Unobserved positive sample を押し上げられるようになる. その結果, Generative model の確率分布が真の関連度分布に近づく.

3.4 Link to Existing Work
3.4.1 Generative Adversarial Networks
IRGANのGANとの主な違いは以下の4点である.
-
GANでは連続的なデータである画像を扱っていたが, IRGANにおいて Generative model は離散的なデータから確率的なサンプリングを行う
-
Generative model の学習に強化学習のアルゴリズムである方策勾配法を用いている
-
GANでは学習に使うデータは連続的で特徴空間が無限であるが, IRGANでは, 離散的で有限の特徴空間である
-
情報検索に特有のPairwise手法の目的関数を提案した
3.4.2 MLE based Retrieval Models
情報検索モデルにおいて, 与えられたデータからそれが従う確率分布を推定する最尤推定がよく使われている. IRGAN の Generative model では, 最尤推定に比べ高精度の確率分布の推定が可能で, Discriminative modelでは, 最尤推定と異なり, ラベルがついていないデータも利用することができるので, 半教師あり学習も可能.
3.4.3 Noise-Contrastive Estimation
ノイズデータと真のデータの識別を行う Noise-Contrastive Estimation という技術があるが, これに比べ, IRGANは 二つのモデルを同時に学習させることが可能.
4 APPLICATIONS
文書のサンプリングにおいて, (8)式に Temperature parameter $\tau$ が組み込まれる.
p_{\theta}(d_{k}|q,r)=\frac{exp(g_{\theta}(q,d)/\tau)}{\sum_{d}exp(g_{\theta}(q,d)/\tau)}\tag{9}
この $\tau$ の値を低くすると, より Generative model は関連度の高い文書をサンプリングしてくることに焦点を当てる. (2) と (8) のスコア関数 $f_{\varphi}(q,d)$ , $g_{\theta}(q,d)$ は, 情報検索タスクによって異なり, 様々な形をとるが, ここでは簡略化のためにこの二つの関数を同じ関数で表す.
g_{\theta}(q,d)=s_{\theta}(q,d)\:and\:f_{\varphi}(q,d)=s_{\varphi}(q,d)
今回はWeb Searchの部分だけ見ていきます.
4.1 Web Search
ウェブ検索において, クエリ-文書ペア $(q,d)$ は $\mathbf{x}_{q,d}\in\mathbb{R}^{k}$ というベクトルで表される. IRGANではスコア関数はRankNetに倣って, 以下のような2層のニューラルネットワークを定義した.
s(q,d)=\mathbf{w}_{2}^{\top}tanh(\mathbf{W}_{1}\mathbf{x}_{q,d}+\mathbf{b}_{1})+w_{0}
ここで,
$\mathbf{W}_{1}\in\mathbb{R}^{l\times k}$ :1層目の全結合層
$\mathbf{b}_{1}\in\mathbb{R}^{l}$ :中間層のバイアス
$\mathbf{w}_{2}\in\mathbb{R}^{l}$ :出力層の重み
$\mathbf{w}_{0}$ :出力層の重み
5 EXPERIMENTS
5.1 Web search
データセットには半教師あり学習で使われるデータセットのMQ2008-semiを用いた. このデータセットでは. -1 $\sim$ 2の関連度があり, -1 は Unknownを表す. 特徴量は46あり, 本実験では, 関連度が -1, 0 文書をUnlabelled sample としている. 全部で 784 個クエリがあるが, それぞれ一つのクエリに対して, 関連度がある文書が 5 個, ラベルがついていない文書が約 1000 個紐づけられている. 比較対象には, RankNet, LambdaRank, そして, とても精度が高いLambdaMARTを用いていて, 評価指標には, Precision, NDCG, MAP, MRR を用いている. IRGAN-pointwise では, Generative model で予測を行い, IRGAN-pairwise では, Discriminative modelで予測を行なっている.
Result and Discussion
学習が完了した IRGAN の Generative model を基に文書のランキングを生成し, 他のベースラインと比較したところ以下の結果が得られた. 全ての場合において, ベースラインよりも IRGAN の方が評価指標の値が高くなっている. IRGAN-pairwise, IRGAN-pointwiseを見てみると, Precision@3, NDCG@3 の場合, IRGAN-pairwise が IRGAN-pointwise よりもスコアが高くなった. これは, Pairwise 手法の方が, クエリとペアになっている文書全体のランキングを考慮できているからであると推測される.
Adversarial training は一般的に学習が不安定(収束しづらい)と言われているので, Pointwise, Pairwise におけるGenerative model, Discriminative model の学習の様子を見てみたところ以下の結果が得られた. Poinwise では 150epoch を過ぎたあたり, Pairwise では 60epoch を過ぎたあたりで best のベースラインである LambdaRank を超えていることがわかる.
次に, (9) の $\tau$ の最適な値を調査するために $\tau$ の値を 0 から少しづつ増加させて評価指標の値の変化をみてみたところ以下の結果が得られた. Precision においても NDCG においても $\tau$=0.2 の時, 評価指標の値が最も高くなることが確認された.
最後に, Generative model とDiscriminative model のモデルの構造の複雑さが $f_{\varphi}(q,d)$ , $g_{\theta}(q,d)$ に与える影響を調査したところ図7のような結果が得られた. 本実験では Pointwise, Pairwise において, Generative model とDiscriminative model をそれぞれ, 線形モデル, 2層のニューラルネットワーク(NN)に変えて調査を行なった. その結果, Pointwise の場合は, Generative model を NN にした時精度が高くなったが, Discriminative model を NN にしても精度が高くなるわけではないことが確認された. 一方, Pairwise では, Discriminative model を NN にした時精度が高くなった. したがって, どちらの手法においても, 予測に用いるモデルをもう一方のモデルよりも複雑にすることで精度が高くなることが判明した.
6. 感想
- Pointwise, Pairwiseによって予測に用いる最適なモデル(Generative model, Discriminative model)が変わる部分が面白いなと感じました
- Listwiseでも敵対的学習できそうですね