LoginSignup
7
3

More than 1 year has passed since last update.

【計算量削減】3DCNNで適応的に時間特徴を圧縮するSGS(Similarity Guided Sampling)の紹介

Last updated at Posted at 2021-12-06

2021年のディープラーニング論文を1人で読むAdvent Calendar7日目の記事です。今日は動画分析で主に使う3DCNNの計算量削減について紹介します。CVPR2021に採択されています。ファーストとセカンドオーサーがドイツのボン大学の方です。

3DCNNは需要はあるのですが、計算コストが膨大すぎるというどうしようもないデメリットがあります。例えば、1サンプルで200フレームあるとしたら、画像で言う所の200枚の計算が必要になります。Conv2Dは計算量がカーネルサイズの2乗に比例しますが、Conv3Dになると3乗に比例するので、画像の200枚以上に計算量は膨大です。動画の200フレームは大したことなくて、30FPSならたかだが5~6秒の内容です。フレーム数に対してスケールできることがかなり求められています。

この論文では、動画の内容(動きの速い、動きがあんまりない)に注目し、CNNの中間でフレームをグルーピングしたり、不要なフレームを捨てたりするサンプリングの研究です。タイトルにもあるとおり「Adaptive Temporal Feature Resolutions:適応的な時間特徴の解像度」なので、時間(フレーム)方向に特徴量をいい感じに圧縮してくれるようなものと理解すればOKです。ただし、ネットワークの中でサンプリングするので微分可能でなければいけません。

最初に断っておきますと、この論文読解するの結構しんどいです。自分もコードとにらめっこして、簡単なサンプル作ってようやく「わかったようなわからないような」なレベルなので、無理に理解しなくていいです。ただ、動画をやろうとすると確実に必要になるテーマの論文で、しんどいからといって逃げてもいられないので、頑張って読んでいきます。微分可能なサンプラーとは一体何でしょう。

問題設定

画像に代表される2DのCNNでは、画像の特徴量が、

$$(N, C, H, W)$$

というshapeのテンソルで表されました(表記はPyTorchにしています)。ここで$N$はサンプルサイズ、$C$はチャンネル数、$H$は縦方向の解像度、$W$は横方向の解像度です。

もしこれが3Dとなると、時間方向の$T$というパラメーターが加わります。PyTorchのConv3Dでは、

$$(N, C, T, H, W)$$

というshapeを入出力で使っています。今フレーム数が多すぎるので、$T$をより小さな$B'(B'<T)$という値に変えたいのです。つまり、

$$(N, C, B', H, W)$$

としたいのです。

StrideやPoolingで削ればいいじゃん

最も簡単なやり方は、時間方向にStrideやPoolingをかけて、時間解像度を半分や1/4にしてしまう方法です。

ただ、これは固定で圧縮するやり方なので、動画の内容によってはうまくいきません。例えば動きの大きい動画では情報を削りすぎてしまうの対し、動きがゆっくりした動画では情報があまり削れません。動画の内容に応じて削るフレーム数を調整する機構があればとても便利です。

SGSの考え方

07_01.png

この考え方のもとに作られたのがSGS(Similarity Guided Sampling)です。上は動きの少ない動画、下は動きの多い動画です。動きの少ない動画では持っておくフレーム数は少なくていいのですが、動きの多い動画ではいっぱいフレームを持っておく必要があります1。$B'$の値は固定ではなく、内容ごとに可変であることが望ましいです。

「微分可能なサンプラーを作るだけでも大変なのに、可変フレーム数なんてどうするんだ」って思うかもしれませんが、いきなりフレーム数を可変の$B'$に落とさずに、$B'\leq B < T$なる$B$に一度落とします。ここで$B$は固定値で、$T\to B$への操作は、フレームのグルーピングです。

$B$の各フレーム(グループ)について、有効なフレーム情報を持っていなければそのフレームは落とします。こうすることでフレーム数が可変の$B'$まで落とされるわけです。これがSGSの基本的な考え方です。これをネットワーク内で実装します。

SGSの理論

類似度の空間へ

まずはフレームのグルーピングをするための類似度を計算します。この類似度は$\mathcal{Z}\in\mathbb{R}^{T\times L}$とします。ここで$L$は類似度の空間の次元数です。

ある時点$t$のフレームを$\mathcal{I}_t\in \mathbb{R}^{C\times H\times W}$とましょう。実装的に書けば$\mathcal{I}$は、

$$(N, C, T, H, W)$$

のshapeを持つテンソルです。空間方向にGlobal Average Poolingをかけます。すると、

$$(N, C, T)$$

というshapeになります。$C, T$の軸を入れ替え、複数のConv1Dレイヤー(実装上ではFCでも代用できます)をかけ、最終的な出力チャンネル数を$L$にすると、

$$(N, C, T)\xrightarrow{swap} (N, T, C) \xrightarrow{conv} (N, T, L)$$

というshapeになります。これが求めたい類似度$\mathcal{Z}$になります。

類似度のbin

ここからレイヤーをまとめる作業に入ります。類似度$\mathcal{Z}$の大きさを考えます。フレーム$t$における類似度$\mathcal{Z}$の大きさ$\Delta_t$を、

$$\Delta_t = |\mathcal{Z}_t | $$

とします。ここで$|\cdot|$は任意の距離関数です。具体的に何かは言及ありませんでした。コードでは複数の関数が提示されていましたが、その中にL2やL1があったので、L2やL1で認識しておけばいいのではないかと思います。ここでの距離関数は特徴量の次元$L$方向を集約します。つまり、

$$\mathcal{Z}:(N, T, L) \xrightarrow{dist} \Delta : (N, T)$$

というshapeになります。次にこの$\Delta$を最大値で正規化します。ここらへんは公式実装を参考にしているため論文本文とは少し表記が異なります。正規化された$\Delta$を$\Delta^{norm}$とすれば、

\Delta^{norm}_t = \Delta_t \frac{2B}{\Delta_{max}}, \qquad \Delta_{max}=\max(\Delta_1, \cdots, \Delta_t)

今やりたいのは、$T$個あるフレームを、$B$個のbinに集約することです。binというとヒストグラムの棒の数をイメージするとわかりやすいかもしれません。「$T\to B$のマッピングをどのように作りますか?」という問題を解きたいのです。

このマッピングは最終的に$\Psi$というカーネルに格納されます。$\Psi$は$(N, B, T)$というshapeになります。$\Psi$と、フレームの特徴量$\mathcal{I}$の行列積を計算することで、フレーム数を$B$や$B'$に圧縮できるというわけです。

この$\Psi$の計算のためには各binsの中心値を使います。$B=b$のときのbinの中心を$\beta_b$とすると、$\beta_b$は次のように推定できます。

$$\beta_b = (2b+1)\frac{\Delta_{max}}{2B}\qquad \forall_b\in(0,\cdots, B-1)$$

$\beta$はnp.arange(B)的に再サンプリングしたものなので、$(N, B)$というshapeを持ちます。これをカーネルの計算に使っていきます。

微分可能なbinsのサンプリング

微分可能なサンプリングとは、実装上はカーネル$\Psi$の計算です。これを関数として表記すれば、

$$\Psi(\Delta_t^{norm}, \beta_b) $$

となります。$\Delta^{norm}$が$(N, T)$、$\beta$が$(N, B)$という異なるshapeを持つため、実装上は互いに軸を追加してブロードキャストして計算します。この結果、カーネル値$\Psi$は$(N, B, T)$というshapeになります。直感的には$N, B$それぞれの軸で二重のforループをして計算しているとも捉えられます2

このカーネル関数ですが、論文では2つ用意しています。

クロネッカーのデルタ関数δ

$$\Psi(\Delta_t^{norm}, \beta_b) = \delta(|\Delta_t^{norm}-\beta_b|)$$

クロネッカーのデルタ関数というとぎょっとしますが、実装上は下のコードでいいです。

kernels = torch.zeros_like(bins) # [N x B x T]
bin_r = bin_sizes / 2 # [N x B x T]
kernels[torch.abs(distances - bins) <= bin_r] = 1.0 # [N x B x T]

ここでのbinsは$\beta$, bin_sizesは$\Delta_{max}/B$、distancesは$\Delta_t^{norm}$を表します。shapeがすべて$(N, B, T)$で統一されているのは、ブロードキャストして揃えているからです。

なぜわざわざデルタ関数としているかというと、微分計算を定義する必要があるからです。単なる代入操作だと微分ができないのかと思います。コードでも別途backpropを定義していました。実装上は代入でいいんだけど、微分計算の数学的な裏付けのためにデルタ関数としているのだと自分は解釈しました。

線形のサンプリングカーネル

2つ目のサンプリングカーネルは、

$$\Psi(\Delta_t^{norm}, \beta_b) = \max(0, 1-|\Delta_t^{norm} - \beta_b|)$$

とするものです。デルタ関数がOne-hotに近いものだったのに対し、こちらはソフトなサンプリングになります。数式的にはこちらのほうがわかりやすいかもしれません。

これらのカーネルの微分について論文に記載されていますが、煩雑になるので省略します。気になる方は論文を見てください。カーネルの実装は公式実装のここにあります。自分はコード見てなにやっているのかようやく理解できました。

使われていないbinを削る

デルタ関数・線形カーネルなどのカーネル計算を通じ、カーネル値$\Psi$が求められましたが、計算結果使われていないbinも相当存在します。$\Psi$は

$$(N, B, T)$$

というshapeですが、使われていないbinを削ることで、

$$(N, B', T), \qquad B'<B$$

に削ることができます。これが適応的なサンプリングとなっています。

実装上は、カーネル値がOne-hotに近い形で得られるため(線形カーネルだともう少し数値にばらつきがありますが)、カーネル値のnon-zeroをスライスするなど、比較的簡単な行列演算で可能です。サンプリングをカーネル関数に帰着させたのがこの論文の大きな特徴でしょう。

特徴量の集約

SGSのサンプリングの特徴マップを可視化したものがこちらです。

07_02.png

最初は32個のフレームがありましたが、これをSGSによって4個のbinに集約しています。集約された特徴量は前後のフレームをまとめたもので、集約前と比べて大きく乱れるということはありません。

この集約の実装は、集約された特徴量を$\mathcal{O}_b$とし、

\mathcal{O}_b=\sum_{t=1}^T\mathcal{I}_t\Phi(\Delta_t^{norm}, \beta_b)

とします。実装では$\mathcal{I}$を$(N, T, K), K=CHW$という3階テンソルに変形し、$(N, B', T)$のshapeのカーネルと、torch.bmm(kernels, input)のように計算すれば終わりです。この結果は$(N, B', K)$となります。$K$の軸をreshapeやswapで戻せば、$(N, C, B', H, W)$という時間解像度のみダウンサンプルされたテンソルとなります。後は3DCNNに入れればOKです。

これで理論は終わりですが、なかなかしんどかったですね。次は簡単なコードでSGSの処理をトレースしてみましょう。

SGSをコードでトレースする

SGSの処理を簡単なコードでトレースしてみました。これを見ていくと大まかな流れがわかると思います。このコードは公式実装をベースに簡単な例にあてはめたものです。

類似度Z

今類似度の$\mathcal{Z}:(N, T, L)$について、決定的に数値を与えてみます。$N=1, T=8, L=2$としています。

import torch

# 類似度Zを与える
zt = torch.tensor([[0.1, 0.5],
                   [0.2, 0.3],
                   [-0.2, 0.1],
                   [-0.5, -0.6],
                   [1.2, -0.3],
                   [0.6, 0.8],
                   [1.3, 1.5],
                   [1.8, 1.1]]).unsqueeze(0) # [N x T x L]
print(zt.shape) # torch.Size([1, 8, 2])

⊿とβの計算

次に$\mathcal{Z}_t$の大きさ$\Delta_t$を考えます。ここではL2ノルムとしています。

# ⊿tの計算
delta_t = torch.sqrt(torch.sum(zt**2, -1)) # [N x T]
print(delta_t.shape) # torch.Size([1, 8])
print(delta_t) # tensor([[0.5099, 0.3606, 0.2236, 0.7810, 1.2369, 1.0000, 1.9849, 2.1095]])

最大値を計算し、$\Delta_t$を正規化します。次元は変わりません。ここでbinsの個数$B=4$とします。

# ⊿tの正規化
num_bins = 4
max_distances = delta_t.max(dim=1)[0].view(-1, 1) # [N x 1]
distances = delta_t * (2 * num_bins / max_distances)
print(distances) # tensor([[1.9337, 1.3674, 0.8480, 2.9619, 4.6909, 3.7924, 7.5276, 8.0000]])

binを作ります。bins_muはbinの中心で、数式では$\beta$にあたるものです。カーネル計算で用います。

# binの作成
step_size = max_distances / num_bins # [N x 1]
mu_start = step_size / 2 # [N x 1]
index_tensor = (
    torch.arange(start=0, end=num_bins).view(1, -1).expand(step_size.shape[0], -1)
) # [N x B]
bins_mu = index_tensor * step_size + mu_start # [N, B]
print(bins_mu) # tensor([[0.2637, 0.7911, 1.3184, 1.8458]])

カーネルの作成

いよいよカーネルを作ります。目的のshape$(N, B, T)$にブロードキャストします

# ブロードキャスト
bins = bins_mu.unsqueeze(2) # [N x B x 1]
bins = bins.expand(-1, -1, distances.shape[1]) # [N x B x T]

distances = distances.unsqueeze(2)  # [N x T x 1]
distances = distances.expand(-1, -1, bins.shape[1])  # [N x T x B]
distances = distances.permute(0, 2, 1)  # [N x B x T]

bin_sizes = step_size.unsqueeze(2)
bin_sizes = bin_sizes.expand(-1, bins.shape[1], bins.shape[2])  # [N x B x T]

デルタ関数のカーネルで求めます。

# Delta kernel
kernels = torch.zeros_like(bins) # [N x B x T]
bin_r = bin_sizes / 2
kernels[torch.abs(distances - bins) <= bin_r] = 1.0
print(distances - bins)
print(kernels)

ここでのdistances-binskernelsの値は以下のようになります。

tensor([[[ 1.6700,  1.1037,  0.5843,  2.6982,  4.4272,  3.5287,  7.2639,
           7.7363],
         [ 1.1427,  0.5763,  0.0569,  2.1709,  3.8998,  3.0013,  6.7366,
           7.2089],
         [ 0.6153,  0.0489, -0.4704,  1.6435,  3.3725,  2.4739,  6.2092,
           6.6816],
         [ 0.0879, -0.4785, -0.9978,  1.1161,  2.8451,  1.9465,  5.6818,
           6.1542]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0.]]])

カーネル値がOne-hotっぽくなるというのはこういうことです。もう少し$\Delta_t$の間隔が短ければ複数レイヤーをまたいでグループ化できたかもしれません。

使われていないbinを落とす

ここでのカーネルは$(N, B, T)$ですが、最初の1行目がすべて0なので削ることができます。「使われていないbin」というのは1行目のような例です。ここでは$B'=3$とし、カーネルを$(N, B', T)$とします。

# 0のbinを落とす
active_bins = kernels.sum(dim=2)  # [N x B]
max_active_bin = 0
# fixme: Check the full vectorized no loop way
for n in range(kernels.shape[0]): # loop over N
    i = active_bins[n].nonzero().view(-1)
    i_prime = torch.arange(i.shape[0])
    print(i) # tensor([1, 2, 3])
    print(i_prime) # tensor([0, 1, 2])
    kernels[n][i_prime] = kernels[n][i]
    kernels[n][i_prime[-1] + 1 :] = 0.0
    if max_active_bin < i.shape[0]:
        max_active_bin = i.shape[0]

kernels = kernels[:, 0:max_active_bin]  # [N x B' x T]
print(kernels)
# tensor([[[0., 0., 1., 0., 0., 0., 0., 0.],
#          [0., 1., 0., 0., 0., 0., 0., 0.],
#          [1., 0., 0., 0., 0., 0., 0., 0.]]])

1行目だけ削ることができました。これは公式コードからのほぼコピペですが、ここがforループなしで実装できるといいですね。

特徴量の集計

最後にフレームの特徴量をbinで集計し、マージします。ここではフレームの特徴量を乱数で与え、

input = torch.randn((1, 32, 8, 16, 16)) # [N x C x T x H x W]

とします。ここでは$N=1$, チャンネル数$C=32$, フレーム数$T=8$, 縦横解像度$H=W=16$としています。この値は適当なので「モデルの中間層の出力なんだな」ぐらいに見てください。

input = input.swapaxes(1, 2) # [N x T x C x H x W]
input = input.flatten(start_dim=2) # [N x T x K]  (K=CHW)
output = torch.bmm(kernels, input) # [N x B'x K]
print(input.shape) # torch.Size([1, 8, 8192])
print(output.shape) # torch.Size([1, 3, 8192])

inputを$(N, T, K)$というフォーマットに変え、カーネルとtorch.bmmします。shapeを確認すると、input:(1, 8, 8192)output:(1, 3, 8192)となっていることがわかります。これはフレーム数が$T=8\to B'=3$に圧縮できたことを示します。あとはview(reshape)などで元の5階テンソルに戻せばOKです。

これら一連の操作は微分可能なので、SGSをモデルの中で使えば、時間解像度を圧縮し計算量を下げつつ訓練ができるというわけです。

実際にマッピングがうまくできているかという点ですが、inputとoutputについてinput[0, :, 0]のようなスライスをしてみます。

print(input[0, :, 0])
# tensor([-0.1115,  1.3532, -1.0616, -0.4676, -0.1187, -0.5522,  1.8831, -1.0669])
print(output[0, :, 0])
# tensor([-1.0616,  1.3532, -0.1115])

カーネルが3, 2, 1の順番でOne-hot形式だったので、インデックス通りにマッピングされているのがわかります。フレームの順番がマージ後に保証されないのが怪しそうですが、これでも精度は出ているのでそこまで大きな問題ではないのでしょう。

実験結果と評価

フレームレートと入力フレーム数

論文に戻ります。この論文では、Mini-Kinetics, Kinetics-400, Kinetics-600, Something-Something-V2, UCF-101, HMDB-51のデータセットについて実験しています。モデルはR(2+1)D, I3D, X3D, modified 3DResNetについて実験しています。結局精度が一番良かったのが3DResNetでした。これはMini-KineticsのValデータでの精度です。

07_04.png

ATFR(Adaptive Temporal Feature Resolutions)とはSGSのことです3。いずれもATFRを入れるとGFLOPsが下がることが確認できます。ATFRを入れても精度もほとんど下がらず、一番精度が良かったのが3DResNet18に64フレーム入れてstride2で回す例でした。ここでのstrideとは、Conv層のstrideではなく、動画のフレームレートを30fpsから15fpsに下げるというフレームレベルの間引きではないかと思われます。

stride2で32フレーム入力と、stride1で64フレーム入力は、参照している動画の時間が同じです。前者よりも後者のほうが精度が上がるのは、前者は入力時に間引きし情報を捨てているからです。GFLOPsに注目しましょう。ATFRがない例では、入力フレーム数に連動するようにGFLOPsが増加します。一方で、ATFRがある例では、入力フレーム数が2倍になっているのにGFLOPsは1.5倍(3DResNet)になっています。これはATFRが冗長なフレームをマージしているためです。

ATFRでGFLOPsの余裕が出てきたら、例えばフレーム数を128、strideを2にすればもっと精度上がると思われます。ところで、時間軸を圧縮すればいいのだったら、SGSのような面倒なことやらなくても

$$(N, C, T, H, W) \xrightarrow{conv / dilated_conv(k, 1, 1)} (N, C, T, H, W) \xrightarrow{swap}(N, T, C, H, W)\xrightarrow{conv(1, 1, 1))}(N, B, C, H, W)\xrightarrow{swap} (N, C, B, H, W)$$

みたいにニューラルネットワークのモジュールで対応できそうな気がしますが、こんな誰でも思いつきそうなものはおそらく誰かがやっているのでしょう。自分だったらSGS入れる前にこれやってみるかな。

SGS+3DResNet-18のアーキテクチャー

なお、ATFRを入れたときの3DResNet-18のアーキテクチャーは次の通りです。

07_03.png

こんな小さいモデルで計算量厳しい言うくらいなので、動画はなかなか修羅の道ですね。SGSはResBlock2の後に入れています。

ハイパーパラメータ

07_05.png

  • Table 1が類似度の$\mathcal{Z}_t, \Delta_t$の計算。他にも角度やスフィアの形状の類似度をやったが、これは指しすぎで単にベクトルの大きさを取るのが一番良かったとのこと。
  • Table 2がカーネル関数。クロネッカーのデルタの⊿より、線形カーネルのほうがいいらしい(デルタ関数いらんやん)。
  • Table 3が類似度の空間の$L$の次元。8が一番良く、ここを深くしすぎても逆効果

07_06.png

Mini-Kineticsによる$B$の数の比較。$B$の数を下げすぎるのもよくはなく、32ぐらいがちょうどいいとのこと。しかし、$B=32$でも実際に有効なbinsはそこまで多くなく、

07_07.png

有効なbins($B'$)はせいぜい20弱でした。最初から$B$を減らすと精度が下がるのに、ネットワークの途中で落とすと結構落とせるのが面白いです。

より大きなデータセットへ

次はKinetics-400,-600という大規模なデータです。

07_08.png

左が-400、右が-600です。400のケースでも600のケースでも、XTFRにより半分強ぐらいにGFLOPsが落とせています。-600のケースは計算量が地獄ですね。×30ってのがやばい。

07_09.png

Kinetics-400のX3D-Sでも、訓練時間がATFRなしで131時間、ATFRありだと121時間となかなか生々しい計算量でした。ただ推論速度はなしで2834fps, ありで4295fpsと爆速でした。もしかしたらモデルの大きさの割に訓練時間が長くなるような設計なのかもしれません。

まとめと感想

この記事では、適応的な時間特徴の解像度のサンプリングという3DCNNの計算量を削減する取り組みを紹介しました。微分可能なサンプリングをカーネルで実装するのは唸らされました。時間解像度を下げるのは高速化において正しいアプローチでしょう。ただ、やり方がSGSみたいに面倒なものを本当に使う必要があるのかは、自分自身は疑問として残りました。

もっとネットワークの構造でサクッと時間解像度を削減したほうが(2DCNNのBottleneck層を時間軸に適用するようにしたほうが)いいのではないかなと、半信半疑なところはあります。もっというと適応的にフレーム数を削りたいなら、前処理で動画のエンコーダーで可変フレームレートにするのもいいかもしれません。これは素人の疑問なので、多分うまく行かない or 先行研究があるのだと思います。ここらへんがいまいち腑に落ちないぶん、3DCNNがまだ発展途上ということなのでしょう。

告知

このアドベントカレンダーが本になりました!
https://koshian2.booth.pm/items/3595424
Amazonでも扱いあります詳しくは👉 https://shikoan.com


  1. 潜在的な特徴量の空間の時間軸の次元なので、厳密には動画のフレーム数とは特徴量の次元が異なりますが、わかりやすいのでフレーム数と呼んでおくことにします。 

  2. 実装上、二重のforループで書くとものすごく遅くなるので、ブロードキャストして書いています 

  3. ATFRの実装例としてのSGSなので、SGSと表記すればよかったと思うのですが、なぜかATFRと書いているんですよね。もしかしたらATFRの例がない(これまでの例は静的なサンプリングだった)から、「うちが作ったんだ」的なニュアンスを込めてATFRと書いたのかもしれません。読む側にとっては「SGSとATFRって違うニュアンスで使っているのかな」と迷うところではあります。 

7
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
3