当記事ではOptical Flowの学習にあたってTransformerを用いた研究であるFlowFormerの論文について解説を取りまとめました。
概要
FlowFormerの論文
FlowFormerの処理概要
FlowFormerのパフォーマンス
詳細
Building the 4D Cost Volume
FlowFormerで入力される画像のサイズを$H_{I} \times W_{I} \times 3$、Siamese Encoderによって得られるFeature Mapのサイズを$H \times W \times D_{F}$とおきます。FlowFormerでは$(H, W) = (H_{I}/8, W_{I}/8)$が基本的に用いられます。
このとき2枚の画像の$H \times W \times D_{F}$のサイズのFeature Mapをチャネル方向に内積(dot-product)を計算することで、4D($H \times W \times H \times W$)のCost Volumeを得ることができます。FlowFormerにおけるここまでのEncoderの処理はTwins-SVT(ViTの一種)に基づいて計算されます。
Cost Volume Encoder
Sourceの位置$\mathbf{x} \in \mathbb{R}^{2}$におけるCost Map(4DのCost VolumeからSourceの1つの位置に対応する2Dのテンソルを抽出したもの)を$M_{\mathbf{x}} \in \mathbb{R}^{H \times W}$と定義し、以下の3つの手順によってCost Volume Encoderの処理が構築されます。
1) Cost Map Patchification
2) Patch Feature Tokenization
3) Alternate-Group Transformer(AGT)
上記の3つの処理について以下詳しく確認を行います。
Cost Map Patchification
Cost Map PatchificationではCost MapにCNNを適用し、Feature Mapを得る処理を実行します。CNNにおける畳み込み演算ではstride=2の畳み込みを3回行うことで、縦方向横方向共に1/8のサイズのFeature Mapを取得します。CNN演算によってテンソルのサイズは下記のように変化します。
\begin{align}
H \times W & \longrightarrow \frac{H}{2} \times \frac{W}{2} \times \frac{D_{p}}{4} \\
& \longrightarrow \frac{H}{4} \times \frac{W}{4} \times \frac{D_{p}}{2} \\
& \longrightarrow \frac{H}{8} \times \frac{W}{8} \times D_{p} \\
\end{align}
FlowFormerの論文では上記の処理の出力が$F_{\mathbf{x}} \in \mathbb{R}^{H/8 \times W/8 \times D_{p}}$のように定義されます。ここで$F_{\mathbf{x}}$の空間方向の値は$M_{\mathbf{x}}$の$8 \times 8$のパッチにそれぞれ対応することには注意しておくと良いです。
Patch Feature Tokenization
$F_{\mathbf{x}} \in \mathbb{R}^{H/8 \times W/8 \times D_{p}}$に対し、下記の演算を行うことでTransformerのKey$K_{\mathbf{x}}$とValue$V_{\mathbf{x}}$を作成します。
\begin{align}
K_{\mathbf{x}} &= \mathrm{Conv}_{1 \times 1}(\mathrm{Concat}(F_{\mathbf{x}},PE)) \in \mathbb{R}^{H/8 \times W/8 \times D} \\
V_{\mathbf{x}} &= \mathrm{Conv}_{1 \times 1}(\mathrm{Concat}(F_{\mathbf{x}},PE)) \in \mathbb{R}^{H/8 \times W/8 \times D} \\
F_{\mathbf{x}} & \in \mathbb{R}^{H/8 \times W/8 \times D_{p}}, \, PE \in \mathbb{R}^{H/8 \times W/8 \times D_{p}}
\end{align}
上記の演算によって得られたKey$K_{\mathbf{x}}$とValue$V_{\mathbf{x}}$に対し、下記のような演算を行います。
\begin{align}
T_{\mathbf{x}} &= \mathrm{Attention}(C, K_{\mathbf{x}}, V_{\mathbf{x}}) \in \mathbb{R}^{K \times D} \\
C & \in \mathbb{R}^{K \times D}
\end{align}
上記の$T_{\mathbf{x}} \in \mathbb{R}^{K \times D}$は位置$\mathbf{x}$ごとのテンソルであるので、全ての位置に関するテンソルを$\mathbf{T}$とすると、$T \in \mathbb
{R}^{H \times W \times K \times D}$が成立します。ここで一般的には$K \times D << H \times W$となるように$K$と$D$を定義することで、$H \times W \times H \times W$の4DのCost Volumeを$H \times W \times K \times D$へ要約を行うことができます。
Alternate-Group Transformer(AGT)
Alternate-Group Transformer(AGT)処理の概要は下図に基づいて把握すると良いです。
上図から空間方向のAttention処理であるIntra-cost-map Self-Attentionと、潜在ベクトル方向のAttention処理であるInner-cost-map Self-AttentionによってAlternate-Group Transformer(AGT)が構成されることが確認できます。Alternate-Group Transformer(AGT)は基本的に$T \in \mathbb
{R}^{H \times W \times K \times D}$を入力とするSelf-Attentionであることから、出力されるテンソルのサイズは入力と同様に$H \times W \times K \times D$となります。また、この出力はCost Memoryと称され、$H \times W \times K$のトークンそれぞれについて$D$次元のベクトルを持つと解釈することができます。
Cost Memory Decoder for Flow Estimation
当項では前項のCost Volume Encoderによって得られたCost MemoryのDecodeとOptical Flowの推定について確認を行います。
Cost Memory Decoder
\begin{align}
Q_{\mathbf{x}} &= \mathrm{FFN}(\mathrm{FFN}(\mathbf{q}_{\mathbf{x}}+PE(\mathbf{p}))) \in \mathbb{R}^{1 \times D} \\
K_{\mathbf{x}} &= \mathrm{FFN}(T_{\mathbf{x}}) \in \mathbb{R}^{K \times D} \\
V_{\mathbf{x}} &= \mathrm{FFN}(T_{\mathbf{x}}) \in \mathbb{R}^{K \times D} \\
\mathbf{q}_{\mathbf{x}} &= \mathrm{Crop}_{9 \times 9}(M_{\mathbf{x}}, \mathbf{p}) \in \mathbb{R}^{9 \times 9} \\
\mathbf{p} &= \mathbf{x} + \mathbf{f}(\mathbf{x}) \in \mathbb{R}^{2}
\end{align}
上記のように定義された$Q_{\mathbf{x}}$、$K_{\mathbf{x}}$、$V_{\mathbf{x}}$を用いて下記で表されるCross-Attentionの計算が行われます。
\mathbf{c}_{\mathbf{x}} = \mathrm{Attention}(Q_{\mathbf{x}}, K_{\mathbf{x}}, V_{\mathbf{x}}) \in \mathbb{R}^{D_{c}}
Flow Estimation
下記のような演算に基づいて点$\mathbf{x}$におけるflowの差分の$\Delta \mathbf{f}(\mathbf{x})$が計算されます。
\begin{align}
\Delta \mathbf{f}(\mathbf{x}) &= \mathrm{ConvGRU}(\mathrm{Concat}(\mathbf{c}_{\mathbf{x}}, \mathbf{q}_{\mathbf{x}}), \mathbf{t}_{\mathbf{x}}, \mathbf{f}(\mathbf{x})) \\
\mathbf{c}_{\mathbf{x}} & \in \mathbb{R}^{D_{c}}
\end{align}