はじめに
NIPS2015にて採択された "Spatial Transformer Networks" について読みましたので、まとめていこうと思います。
Google DeepMindに所属する方々による論文であり、大雑把な内容としては、画像の空間的補正(アフィン変換など)のパラメータを予測するネットワークを全体のネットワークに組み込むことで、BackPropにより画像補正のパラメータを適切に予測することができるという内容の論文です。
これにより、例えば入力画像に歪みが生じていたり、対象物の周囲の景色なども写っているようなデータに対して、うまく対象物のみを切り出し、対象物の姿勢を修正してからネットワークへと入力を流し込むことができるので、予測精度の向上が見込まれます。
本記事では実装ではなく論文自体を理解することを目的に説明していきます。なので英語論文をさくさく読める方はそちらを読んだほうが良いかと思われます。
基本的に論文の構成のまま、関連研究など本筋に無関係なもの以外については省略せずに説明していきます。
Tips
- Title
- Spatial transformer networks [論文リンク]
- Conference
- NIPS (Advances in neural information processing systems) 2015
- Authors
- Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
0. Abstract
従来のCNN (Convolutional Neural Networks) では入力画像として画像全体を入力し、その画像に対して予測値を出力する形式でした。なので、例えば画像全体のうち、予測に必要な部分は一部だけであるような画像(例:図1)に対しては、余分な箇所が予測に悪影響を及ぼす可能性があります。他にも、入力画像に含まれる対象物が傾いている場合や、歪んでいる可能性も存在します。
図1: '5'と書かれている画像。このような画像の場合、従来のCNNでは余分な黒枠の部分もCNNに入力するので、その分予測性能が低下する恐れが存在します。なのでできれば'5'の存在している箇所あたりを切り取ってからCNNに入力したいと思われます。これに対して筆者らは STNs (Spatial Transformer Networks) というネットワークを導入します。STNsは、CNNに画像を入力する前に画像を自動で変換(移動・縮小・回転・切り取りなど)することで、画像の歪みを修正し、予測に必要な部分のみをネットワークに流すようなアーキテクチャです。STNsを用いることでより頑健性の高いモデルを構築することができます。
1. Introduction
近年特にコンピュータビジョンの分野ではCNNについての研究が盛んであり、分類・識別・Semantic Segmentation (画素レベルでの画像認識)・映像認識などなどの多彩な分野で活用されています。
画像を対象とするタスクとしては、画像に映る物体の姿勢を検知し、画像に変形を加えることで物体の歪みを解消することが重要です。CNNでは Max-Poolingレイヤーを導入することで空間的な不変性を実現していますが、Max-Poolingレイヤーは厳密には入力データに近い、浅い層での大きな変換に対しては不変ではありません。
そこで本研究により Spatial Transformer モジュールを導入することにより、空間的不変性を解消します。これは通常のニューラルネットワークに空間的歪みを解消する能力を与えるもので、特殊な教師データを用いずに入力データのみから空間的変形モジュールの学習を行います。学習も通常の BackProp で行うことができます。Max-Poolingとは違い画像全体に作用することで、画像から必要な領域を選択し、姿勢を修正して後続のニューラルネットへ繋ぐことができます。
図2: 歪んだMNISTデータに対する実験。画像から認識に必要な箇所のみを切り出し、その箇所の空間的歪みを補正しています。(a)から(c)への変換が Spatial Transformer であり、歪みを補正したのちに後続の CNN へとデータを流しています。
2. Related Work
省略
3. Spatial Transformers
本章でSTNsの定式化について説明していきます。前提として、STNsは入力として複数のチャンネルを持つ場合、各チャンネルに同一の変形を施します(実際、RGBの3チャンネルを入力したときにRチャンネルとGチャンネルとBチャンネルでそれぞれ別の変換をするのは不自然ですよね?)。また問題を簡単にするために以降の説明では単一の変換を適用する場合について説明していきます。
STNsは大まかに3つの機構に分かれています(図3)。これら3つの機構それぞれについて3.1~3.3節にて説明した後、全体について3.4節にて説明していきます。
図3: STNsの概略図
- Localisation Net: 入力された特徴空間 $U$ に対し、特徴空間を変換するための変数 $\theta$を出力するネットワーク。
- Grid Generator: 変数 $\theta$ を元にして、サンプリンググリッド $G$ (変換前の画像の各ピクセルが変換後にどの座標にくるか対応付けたグリッド) を作成します。
- Image Sampler: 特徴空間 $U$ とSampling Grid を入力にとり、変換後の特徴空間 $V$ を出力します。$V_i$ は、$U_i = (x^s_i, y^s_i)$ を中心点とするサンプリングにより算出されます。
3.1. Localisation Network
入力: $U \in {\mathbb R}^{H \times W \times C}$ ($H$: 画像の縦幅、$W$: 画像の横幅, $C$: チャンネル数)
出力: $\theta$ (特徴空間を変換する関数${ \mathcal T}_\theta$のパラメータ)
この変換を $\theta = f_{loc}(U)$ とおきます。$\theta$ の次元数については特徴空間に適用する変換に依存します(例: アフィン変換なら6次元)。
$f_{loc}()$ については、最終層が $\theta$ を出力するものであれば、全結合層や畳み込み層など、間のネットワークは特に形を制限されません。
3.2. Parameterised Sampling Grid
ここでは最終的な出力空間 $V$ を作成するために、サンプリング用グリッド $G$ を求めます。$V$ 上の格子点(グリッド)を $G$ とし、グリッドの各点 $G_i$ の値 $V_i$ を求める際は、点 $U_i = (x^s_i, y^s_i)$ ($s$: source) を中心にサンプリングカーネルを適用することによって求めることとなります。
例えば2Dアフィン変換の場合、$G_i = (x^t_i, y^t_i)$ ($t$: target) とすると、点 $U_i$ は以下のような式によって求めることとなります(入力画像のどの点 $(x^s_i, y^s_i)$ を中心点とするのか定義しています)。ここで、$x^s_i, y^s_i$ は整数である必要はありません。各点についてサンプリング中心点を求める様子は図4のようになります。
$$
\begin{pmatrix}
x^s_i \\
y^s_i \
\end{pmatrix}
= {\mathcal T_\theta} (G_i) = A(\theta)
\begin{pmatrix}
x^t_i \
y^t_i \
1 \\
\end{pmatrix} =
\begin{bmatrix}
\theta_{11} & \theta_{12} & \theta_{13} \\
\theta_{21} & \theta_{22} & \theta_{23} \
\end{bmatrix}
\begin{pmatrix}
x^t_i \
y^t_i \
1 \\
\end{pmatrix} (1)
$$
図4: ${\mathcal T_\theta} (G_i)$ による変換の様子(2Dアフィン変換の場合)。$V$ 上の各点 $G_i$ について、サンプリング中心点を${\mathcal T_\theta} (G_i)$ により算出している。
本論文では入出力それぞれについて正規化された座標を使用します。
($-1 \leq x^t_i, y^t_i \leq 1, -1 \leq x^s_i, y^s_i \leq 1$)
ここで、空間的変換の種類について、例えば上式(1)の $A_\theta$ では一般的なアフィン変換を適用することで、6つのパラメータにより画像の切り取り・移動・回転・拡大・歪曲を可能にしていましたが、例えば次式
$$
A_\theta =
\begin{bmatrix}
s & 0 & t_x \\
0 & s & t_y \
\end{bmatrix}
$$
を用いれば、画像の切り取り・移動・拡大(縦横の比率は変えない)といった変換のみを可能にします。
逆に変換 ${\mathcal T_\theta}$ をもっと一般的に拡張することもでき、8つのパラメータを用いた射影変換や、Thin Plate Spline (画像上の特定の点を任意の場所にズラし、画像全体を歪曲させる変換)などを用いることができます。
実際には、パラメータ $\theta$ に関して微分可能であれば、あらゆる変換を適用することができます。$\theta$ に関して微分可能な変換を定義することで、サンプル点 ${\mathcal T_\theta} (G_i)$ から Localisation Network の出力 $\theta$ まで誤差を逆伝搬させることができるからです。
3.3. Differentiable Image Sampling
入力された特徴空間 $U$ から変換後の特徴空間 $V$ を得るために、前節でサンプリンググリッド $G$ を作成しました。これを用いて、$V$ 上の各点 $V_i$ を求める際に ${\mathcal T_\theta} (G_i)$ を、つまりは $(x^s_i, y^s_i)$ を中心点としてサンプリングすることとなります。
サンプリングについての一般式は次式のようになります。
$$
V^c_i = \sum^H_n\sum^W_m U^c_{nm} \ k(x^s_i - m; \Phi_x) \ k(y^s_i - n; \Phi_y)\
(\forall i \in [1...H'W'], \forall c \in [1...C])
$$
$H', W'$ : 出力画像の縦幅、横幅
$\Phi_x, \Phi_y$ : サンプリングカーネル $k$ のパラメータ
$U^c_{nm}$ : 入力された特徴空間のチャンネル $c$ での点 $(n, m)$ の値
$V^c_i$ : チャンネル $c$ での点 $(x^t_i, y^t_i)$ の出力値
空間変換は各チャンネルについて他のチャンネルの値によらず同様に計算されるので、各チャンネルについて同様に上式によって出力値が計算されます。
サンプリングカーネルは理論的に勾配を計算できる限りは任意のカーネルを使用することができます。以下に2つのカーネル例を紹介します。
- サンプリングカーネル例1
$$
V^c_i = \sum^H_n\sum^W_m U^c_{nm} \delta(\lfloor x^s_i + 0.5 \rfloor - m) \delta(\lfloor y^s_i + 0.5 \rfloor - n)
$$
これは、$\lfloor x^s_i + 0.5 \rfloor$ によって $x^s_i$ に最も近い整数値を求め、その値が $m$ であるときのみ値を出力するカーネルです($y$ についても同様)。要するに、点 $(x^s_i, y^s_i)$ に最も近い格子点の値を出力する最近傍サンプリングカーネルとなっています。
- サンプリングカーネル例2
$$
V^c_i = \sum^H_n\sum^W_m U^c_{nm} \ max(0, 1 - \mid x^s_i - m \mid ) \
max(0, 1 - \mid y^s_i - n \mid )
$$
近傍のグリッド4点を線形的に補完した値を出力する線形サンプリングカーネルとなっています。
これらのサンプリング機構を通るような誤差逆伝搬を考えるために、$U$, $G$ についての勾配を考えます。線形サンプリングカーネルを例にとると、以下の2式が導かれます。
$$
\frac{\partial V^c_i}{\partial U^c_{nm}} = \sum^H_n\sum^W_m \ max(0, 1 - \mid x^s_i - m \mid ) \
max(0, 1 - \mid y^s_i - n \mid ) \
\frac{\partial V^c_i}{\partial x^s_i} = \sum^H_n\sum^W_m U^c_{nm} \ max(0, 1 - \mid y^s_i - n \mid )
\left\{
\begin{array}{}
0 & \mbox{if} \ \ \mid x^s_i - m \mid \ \geq 1 \
1 & \mbox{if} \ \ m \geq x^s_i \
-1 & \mbox{if} \ \ m < x^s_i
\end{array}
\right.
$$
$\frac{\partial V^c_i}{\partial y^s_i}$ についても同様に算出することができます。
これらの式により、損失の勾配を入力された特徴空間まで伝搬させることや、サンプリンググリッドまで伝搬させることが可能になります。$\frac{\partial x^s_i}{\partial \theta}$, $\frac{\partial y^s_i}{\partial \theta}$ のそれぞれが (1) 式から計算できるので、結局、変換のパラメータ $\theta$ にまで伝搬させ、Localisation Net を学習させることが可能となります。
サンプリング関数が不連続となるので厳密には勾配ではなく劣勾配を用いることになります(劣勾配については、不連続な点に対して擬似的に勾配を定義するようなものと捉えれば問題ないかと思います)。
各出力ピクセルに対して、入力ピクセル全てを見るのではなく近傍点のみを見ることで勾配が計算できるので、この計算はGPU上で非常に効率的に実装できます。
3.4. Spatial Transformer Networks
3.1~3.3節で説明した、Localisation Net、Grid Generator、Image Sampler の3つを組み合わせることで、STNs を作成しました。この機構はCNN内の任意のタイミングで設置することができ、設置した段階で入力された特徴空間を変形します。
変換について学習された情報は Localisation Net の重みに圧縮されることとなります。また、Localisation Netの出力である、変換に用いるパラメータ $\theta$ は、物体の姿勢についての情報として扱うことができるので、このパラメータを後ろのネットワークに流すことも有用であると考えられます。
複数の STNs を異なる階層に、または同じ階層に並列に設置することもできるので、「複数のオブジェクトを個別に抽出したい」というようなタスクにも応用できます。その場合、並列に設置する STNs の数だけしかオブジェクトを検知できないので注意する必要があります。
4. Experiments
4.1. Distorted MNIST
MNISTデータに対して様々な'歪み'を加えたデータセットを用いて実験を行います。歪みの種類は以下の4種類を考えます。
- R: 回転 (Rotation)
- RTS:回転、拡大、移動 (Rotation, Scale, Translation)
- P: 射影変換 (Projective transformation)
- E: 弾性的な歪み (Elastic warping) (場合によっては STNs で完全に戻すことが不可能になる)
実験については以下の8種類の手法を比較します。
- FCN (Fully-Connected Net)
- CNN (Convolutional Neural Net)
- ST-FCN ('線形サンプリングカーネルを用いたSTNs' を導入した FCN)
- Aff (アフィン変換)
- Proj (射影変換)
- TPS (Thin Plate Spline (画像上の特定の点(本実験では16点)を任意の場所にズラし画像全体を歪曲させる変換))
- ST-CNN
- Aff
- Proj
- TPS
図5: MNIST を用いた実験結果(誤答率)の比較(左)と、(ST-CNN TPS) による空間補正の例(中心)と、 (ST-CNN Aff) による空間補正の例(右)。本実験においては(ST-CNN TPS)が非常に強力なモデルであることが分かります。中心及び右の変換例については、CNNでは正しいクラスに分類できなかったがST-CNNでは分類できた例を出しています。他の様々なデータに対する変換例を見たい方は (https://goo.gl/qdEhUu) まで。
4.2. Street View House Numbers
- SVHN データセット
- 約200k枚の、現実の画像 (家の番号を写している)。
- 各画像には 1~5文字の数字が含まれている。
- 画像により、数字が斜めになっていたり、数字部分の大きさが違います。
先行研究では、画像から数字を含む部分を 64×64 のサイズで切り出して使用していましたが、本実験ではさらに条件を緩く、128×128 のサイズで切り出し、数字以外の範囲を多く含む場合の実験も行います。
実験については以下の5種類の手法を比較します。
- Maxout CNN Goodfellow ’13
- CNN (11層のCNN)
- DRAM Ba ’15
- ST-CNN Single (11層CNNの一番最初にSTNs(4層CNN)を導入したネットワーク)
- ST-CNN Multi (11層CNNのうち、最初の4層のCNNの前にそれぞれSTNs(2層FCN)を導入したネットワーク)
図6: SVHNデータセットを用いた実験結果(誤答率)の比較(左)と、ST-CNN Multi による空間変換の例(右)。ST-CNN Single, ST-CNN Multi共に既存手法を凌駕していることが分かります。空間変換の例としては、4つのSTNsの変換全てを合わせた変換を表しています。
4.3. Fine-Grained Classification
STNsを並行に導入した場合の予測性能を測るために、鳥類データセット(CUB-200-2011)を使用します。
- CUB-200-2011
- 約6k枚の学習用画像を有しています。
- 約5.8k枚のテスト用画像を有しています。
- 200種類もの種類の鳥が含まれています。
- 各画像について、鳥のみを切り取った画像ではなく、周囲の景色もある程度含む画像。
実験については以下の9種類の手法を比較します。
- Cimpoi ’15
- Zhang ’14
- Branson ’14
- Lin ’15
- Simon ’15
- CNN 224px : ImageNetを学習させたInceptionを転移学習したもの(これだけでも先行研究を超える結果となっています)。
- 2×ST-CNN 224px : 2つのSTNsを並列に使用したもの。入力画像の解像度は224px
- 2×ST-CNN 448px : 2つのSTNsを並列に使用したもの。入力画像の解像度は448px
- 4×ST-CNN 448px : 4つのSTNsを並列に使用したもの。入力画像の解像度は448px
ST-CNNについては計算コストを大きく上昇させることなく448pxの解像度の画像を入力できます。
図7: 鳥類データセット(CUB-200-2011)を用いた実験結果(正答率)の比較(左)と、2×ST-CNN及び4×ST-CNNのそれぞれが切り取った画像(右)。(4×ST-CNN 448px)が既存手法を大きく上回る性能を発揮していることが分かります。2×ST-CNNの切り取り方を見ると、一方(赤色)のSTNsが鳥の顔周辺を切り取っているのに対し、もう一方(緑色)のSTNsが鳥の胴周辺を切り取っていることが分かります。
5. Conclusion
筆者らは、自己完結型であり他のネットワークの構造に影響を及ぼさない形で、入力画像に空間的補正を加える Spatial Transformer モジュールを導入しました。これはCNNなど既存のネットワークにそのまま接続するだけで、学習方法などは特に変更することなくCNNの精度を向上させ、state-of-the-artを実現させることに成功しました。
実験によって、MNISTデータセットをはじめ、現実のデータについても既存手法を上回る実験結果を算出し、画像認識や物体検出などの分野で応用できることが期待されます。
以上。