はじめに
AAAI2018より以下の論文
[1] S. Yan, et. al."Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition"
のまとめ
公式らしきコード:
https://github.com/yysijie/st-gcn
arXiv:
https://arxiv.org/abs/1801.07455
既にyukihiro domae氏のslide share などにまとめ記事が存在するが、後学のため敢えてまとめ。
https://www.slideshare.net/yukihirodomae/spatial-temporal-graph-convolutional-networks-for-skeletonbased-action-recognition
概要
- 骨格推定からの行動認識の分野にグラフ畳み込みを持ち込んだ最初の論文
- 時間方向と空間方向に対して畳み込むが、空間方向はグラフ畳み込みを行う
以下の図のような流れ。前提として、videoデータから骨格推定モデル等を用いて、各フレームに対して骨格推定を行っている。本手法ST-GCNsはこの時空間なデータを入力とし、推定した行動のクラスを出力とする。
先行研究
書きかけ
Spatial Temporal Graph ConvNet
本論文のmainである時空間なgraph convについて。
Skelton Graph Construction
まずグラフの構成から。
ノード V を以下で定義する。T フレームからなるvideoと N 関節からなる骨格を考え、
V = \{v_{ti} | t=1, \dots , T, i=1, \dots, N \}
とする。実際には ST-GCN にはノードの3次元空間上(?)の座標
F(v_{ti})
を入力する。
次に edge だが、これは同じフレーム内の隣接する関節との edge $E_S$ と同じ関節のフレーム間にわたる edge $E_F$ からなる。
\begin{eqnarray*}
E_S &=& \{ v_{ti} v_{tj} \ | \ (i,j) \in H \} \\
E_F &=& \{ v_{ti} v_{(t+1)i} \} \\
\end{eqnarray*}
今回取り扱う graph はこれらノードとエッジ E から構成される。
G = (V, E)
グラフ畳み込み
以下、本論文の主題、「グラフ畳み込みについて」
画像における畳み込み
まず一般的な画像に対する畳み込みの基礎から。
$K$ :kernel size
$f_{in}$ :入力のfeature map
${\bf x}$ :画像上の位置。例えば横に3 pix, 縦に8 pixなど。
${\bf p} :Z^2 \times Z^2 \to Z^2$ :サンプリングの関数。例えば x の 2 pix右隣かつ 3pix下隣、など。
${\bf w} : Z^2 \to \mathbb{R}^c$ :重み関数
とする。stride が 1 を考え、入力の feature map におけるあるチャンネル c に注目する。出力側を1 チャンネルとすれば、出力側の位置 ${\bf x}$ の値は
f_{out} ({\bf x}) = \sum^K_{h=1} \sum^K_{w=1} f_{in} ({\bf p} ({\bf x}, h, w)) \cdot {\bf w} (h, w) \tag{1}
となる。ただし、重みは同じ画像において単一の値。つまりアインシュタインの縮約記法で書くと xy, c -> cxy って感じか?
以下ではこれをベースにして、グラフ畳み込みに適応するように変形する。
グラフ畳み込みのサンプリング関数
上記例の画像であればサンプリング関数 p は
{\bf p}({\bf x}, h, w)) = {\bf x} + {\bf p}' (h, w)
といった感じで、画像上の位置 x の縦 h 隣、横 w 隣となる位置をサンプリングすればよいが、グラフにおいても似たことを行う。
node $v_{ti}$ に対し、ノード間距離が D 以内となるノードを D 隣と定義する。
B(v_{ti}) = \{ v_{tj} \ | \ d(v_{tj}, v_{ti}) \} \leq D
サンプリング関数はこの B を満たすノードだけを取得する $B(v_{ti}) \to V$
{\bf p} (v_{ti}, v_{tj}) = v_{tj} \tag{2}
グラフ畳み込みの重み関数
画像であれば (c, K, K)の大きさのtensorだが、グラフの場合は工夫が必要。
まずあるノードの隣 $B(v_{ti})$ を K のタイプに分ける。
l_{ti} : B(v_{ti}) \to \{0, \ldots, K-1 \}
これは例えば当該ノード自身とノードの1つ隣に分けるなら2個だし、当該ノード自身、根元側の隣、末端側の隣に分けるなら3個となる。
重みは、分けたタイプごとに固有の値とする。
{\bf w}(v_{ti}, v_{tj}) = {\bf w}' (l_{ti} ( v_{tj}))
なお画像同様、出力側の c チャンネル毎に異なる値とするので、入力側のあるノードのペアに対して c 個値をとる。
{\bf w}(v_{ti}, v_{tj}) : B(v_{tj}) \to R^c
空間方向のみのグラフ畳み込み
以上からサンプリング関数と重み関数が定まったので、グラフ畳み込みは以下のようになる。
\begin{eqnarray}
f_{out} ( v_{ti}) &=& \sum_{v_{tj} \in B(v_{ti})} \frac{1}{z_{ti}(v_{tj})} f_{in} ( {\bf p}(v_{ti}, v_{tj})) \cdot {\bf w}(v_{ti}, v_{tj}) \tag{4} \\
&=& \sum_{v_{tj} \in B(v_{ti})} \frac{1}{z_{ti}(v_{tj})} f_{in} ( v_{tj}) \cdot {\bf w}(l_{ti} ( v_{tj})) \tag{5}
\end{eqnarray}
$1 \ / \ Z_{ti} (v_{tj})$ は正規化項。
まとめた表記方法
$A$ :隣接行列
$I$ :単位行列
$\Lambda^{ii} = \sum_j (A^{ij} + I^{ij})$ :隣接行列+単位行列を行に関して足し合わせたもの
として
{\bf f}_{out} = {\bf \Lambda}^{-\frac{1}{2}} ({\bf A} + {\bf I} ) {\bf \Lambda}^{-\frac{1}{2}} {\bf f}_{in} {\bf W} \tag{9}
${\bf \Lambda}^{-\frac{1}{2}} ({\bf A} + {\bf I} ) {\bf \Lambda}^{-\frac{1}{2}}$ の部分が正規化された隣接(+自分自身)行列。
ネットワークのアーキテクチャの学習方法
- scaleに対して正規化するため、入力直後にbatch normする
- residualなblockで構成・・・時空間で畳み込んだもの(ST-GCN unit)に、その入力を足す
- ST-GCN unitからなるlayerは9つ
- channel数は、最初の3 layerで64, 次の3 layerで128, 最後の3 layerで256
- 時間方向のkernelのサイズは9
- ST-GCN unitの最後に0.5のdropoutを行い汎化性能を上げる
- 4th layerと7th layerの後にstirde=2で小さくする
- 最後のlayerの256 channel の feature map に対して GAP をし、256個の値を得る
- loss:softmax cross-entropy
- SGDでlearning rate 0.01から始め、10epochごとに0.1 で decay
- 以下の2つのaugmentationを行う
1)affineとscale変化を1つのsequenceの全てのdataに対して行う -> view point自体が動くことを真似る
2)元の骨格から一部のみを使用する
実験と結果
NTU RGB+D datasetによる定量的評価
以下はNTU RGB+D datasetの 60 class の方で当時のSOTA modelと比較したもの。
poseのみを入力とする他のモデルと比較すると、当時としては最高性能。
Kinetics datasetによる定量的評価
以下は Kinetics datasetにおいて他のモデルと比較したもの。
ablation study
以下が ablation study。
上段から、Baseline TCNはtemporalなconvのみ、Local Convolutionは空間方向にも畳み込むが重みは共有しない。その下4つは「隣接」の定義による違い。Uni-labelingはdistance=0 つまり当該のノードのみに対して畳み込む。Distance partitioningは隣のノードも含むが、根元側と末端側を同一として扱う。Spatial Configurationは根元側と末端側を区別する。
最後のST-GCN + Impはlearnable edge importance weighting。