Depth Anything V3 の backbone multi-view - local/globalでcross-view mixingを起こす
これは 生成AIアドベントカレンダー16日目 に投稿させてもらった depthanything3 の話 の続きになります。
ここではその続きとして 「backboneでmulti-viewをどう扱ってるか」 にフォーカスしてみていこうかと。
今回はそこだけに絞って、実装ベースで深掘りします。
0. 今日の結論(TL;DR)
Depth Anything V3のmulti-view対応は、専用のcross-attention層を新たに足して実現しているわけではありません。DinoV2の仕組みを重みを極力利用できる形でびっくりするような工夫をしていました。
同じViTブロック([B, N, C])を使い回したまま、
ブロックに入れる直前にトークンをreshapeして、
view間のSelf-Attentionが起こる状況を作っています。view間でcross-attnのような振る舞いをするようにテンソルを組み替えて使ってます。
補足:
ViTは[B,N,C]のテンソルを入力とします。
B次元はバッチ次元なので何であっても、その重みで動きます。
N次元は空間次元のH,Wをパッチサイズに区切ってトークングリッドにしたものを
真っ直ぐにしただけなので、ここも何であってもその重みで動きます。
-> よって第一次元、第二次元の次元数は設計としてはなんでもOKです(メモリが許す限り)
-> 最終次元はattentionのd_modelなので、ここはViTのサイズによって固定です。
このことを前提として・・・
入力: x = [B, S, N, C] (S: view数, N: patch token数)
local : [B, S, N, C] -> [B*S, N, C] (view独立)
global : [B, S, N, C] -> [B, S*N, C] (view混合 = 実質cross-view mixing)
ここでやってるのはlocalはmultiviewをバッチ次元として扱ってます。なのでlocalと呼ばれる処理では、各viewは互いに見えない状態です(テンソルはバッチ次元で基本的に常に独立してる)。
対してglobalではパッチトークンの次元を拡張することでmultiviewをそこに格納。これは各viewが互いを影響し合うことが可能になります。self-attnなのに各ビューが注意を向けることができてmultiviewの整合性を高める形で(整合性を高めるかどうかは損失の設計次第)、学習が可能です。
しかも、常時globalにするのではなく、後半ブロックだけ & 交互にglobal という「現実的な妥協」が入っています。
注意重みはCのd_modelのみでパラメーター数が決定できます。
ただただ考えた人を尊敬します。
0.1 参照資料と、追うべきファイル
実装を追うなら、最低限ここだけ見れば追えます:
- multi-view対応DINOv2ラッパー:
src/depth_anything_3/model/dinov2/dinov2.py - local/global切替の本体:
src/depth_anything_3/model/dinov2/vision_transformer.py - ViTブロック(ここは基本そのまま):
src/depth_anything_3/model/dinov2/layers/attention.pysrc/depth_anything_3/model/dinov2/layers/block.py
- RoPE(2D rotary):
src/depth_anything_3/model/dinov2/layers/rope.py - 典型config:
src/depth_anything_3/configs/da3-large.yaml - head側(viewは融合しない):
src/depth_anything_3/model/dpt.pysrc/depth_anything_3/model/dualdpt.py
1. “single plain transformer” の意味
Depth Anything V3が掲げる “Single plain transformer” は、ざっくり言うと:
• multi-view用の新規クロスアテンション層は追加しない
• 入力トークンの扱い(畳み方)でmulti-viewを成立させる
*(multi-view用の新規クロスアテンション層は追加しないで、既存のTransformer Block(self-attn + MLP)をそのまま使い、入力トークンのreshape(local/global)だけでcross-view mixingを起こす)
この設計だと、DINOv2系の 事前学習重み(ViTとして自然な形) を比較的素直に流用できます。
そして実装も「どこで何が起きてるか」を追いやすい。
1.1 backboneのI/O shape(multi-viewの “API”)
DA3のbackboneは 最初からmulti-viewの形 で受け取ります。
入力: x = [B, S, C, H, W] (S: view数 / video window長)
backbone出力: 4層ぶんの特徴
feats = (
(patch_tokens_L1, camera_token_L1),
(patch_tokens_L2, camera_token_L2),
(patch_tokens_L3, camera_token_L3),
(patch_tokens_L4, camera_token_L4),
)
patch_tokens: [B, S, N_patch, C] もしくは [B, S, N_patch, 2C](cat_token=Trueのとき)
図にするとこうです(“4層をheadへ渡す” のいつものDPT系):
dark mode向け(背景暗め)版
ポイントは headは “per-view decode”(後述)で、view融合はbackbone側でやる、という分業です。
2. どこでviewが混ざる?(答え:reshapeだけ)
実装の中心は src/depth_anything_3/model/dinov2/vision_transformer.py の process_attention() です。
中身は本質的にこれだけ:
# x: [B, S, N, C]
if attn_type == "local":
x = rearrange(x, "b s n c -> (b s) n c") # [B*S, N, C]
elif attn_type == "global":
x = rearrange(x, "b s n c -> b (s n) c") # [B, S*N, C]
x = block(x) # ← ここは普通のViTブロック(Self-Attention)
# もどす
if attn_type == "local":
x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S)
elif attn_type == "global":
x = rearrange(x, "b (s n) c -> b s n c", b=B, s=S)
つまりglobalブロックでは、Self-Attentionのキー/バリューが S viewぶん 同じ系列内 に入るので、view間で情報が混ざります。
dark mode向け(背景暗め)版
2.1 “cross-attnっぽい” の正体
cross-attentionを明示的に書いていなくても、globalブロックは:
- Query: あるviewのtoken
- Key/Value: 全viewのtoken
という状況になるので、結果として cross-view mixing が起きます。
(「cross-attention層を足す」より、だいぶミニマルです)
3. いつglobalになる?(alt_start と odd block)
ずっとglobalにすれば最強かというと、計算量が爆発します。
そこでV3は alt_start 以降だけ、しかも 奇数indexブロックだけglobal にしています。
実装(概念)はこんな感じ:
if alt_start != -1 and i >= alt_start and i % 2 == 1:
attn = "global"
else:
attn = "local"
例えば da3-large.yaml(depth=24, alt_start=8)だと:
- global: 9,11,13,15,17,19,21,23(8個)
- local : それ以外(16個)
で、local:global ≒ 2:1 になります。
図にするとこういうスケジュールです。8層目以降から交互に起こってます。
3.1 なぜ交互が効くのか
直感的にはこうです:
- localで「各viewの理解」を進める(view独立に特徴を整える)
- globalで「view間の整合」をとる(同じシーンだと合意形成する)
を繰り返す。
この “localで整えてglobalで混ぜる” を、Transformer自身は改造せずにやってのけてます。
4. 4層特徴のexportと、cat_token=True の意味
DA3は、backboneから「4つの中間層特徴」をheadに渡します(DPT系の定番)。
ここで重要なのが cat_token=True の挙動です。
export時、local表現とglobal表現を特徴次元でconcat します:
out_x = torch.cat([local_x, x], dim=-1) # [B,S,N,2C]
-
x: その時点の最新状態(globalブロック後だと “混ざった表現”) -
local_x: 直近のlocalブロック後に保存された “混ざってない表現”
これを 両方headに投げる のが cat_token=True の意味です。
図で見るとこんな感じ:
dark mode向け(背景暗め)版
4.1 なぜexport層が “odd (global)” になりがちか
もしexportがlocalブロック直後だと、local_x == x になりがちで:
cat([x, x]) # ほぼ冗長
になります。
だからconfigの out_layers は、だいたいglobalになりやすい 奇数index を選んでいます。
例:da3-large.yaml の out_layers: [11, 15, 19, 23]
ちなみに、DA2のDPTに送る中間特徴レイヤーは、
self.intermediate_layer_idx = {
'vits': [2, 5, 8, 11],
'vitb': [2, 5, 8, 11],
'vitl': [4, 11, 17, 23],
'vitg': [9, 19, 29, 39]
}
なのでここではあえて4層目を使わずに15層目を使ってるのがわかります。この辺りは実験を繰り返して、性能の良いところを取ってきたんだろうと想像できます。
5. 位置情報の扱い(RoPEの “pos_nodiff” が地味に効いてる)
globalは S*N の系列にするので、位置エンコーディングの扱いを雑にやると破綻しがちです。
DA3は2種類の位置要素があります:
-
learned absolute pos_embed
- トークン準備段階で一回だけ足される
-
RoPE(2D rotary pos embedding)
- attention内部で使う(設定で有効/無効)
そしてRoPEが有効なとき、DA3は:
- localブロック: 通常の
(y, x)のposを使う - globalブロック:
pos_nodiff(ほぼ定数)を使う
という切替をしています。
目的は単純で、S*N を「巨大な2Dグリッド」と誤解させないこと。
global mixingを 位置依存にしすぎない ための工夫に見えます。
図にするとこう:
dark mode向け(背景暗め)版
6. おまけ:さらに工夫
まだいくつか面白い仕掛けがあります。
6.1 camera token を差し込む
alt_start のタイミングで x[:,:,0](cls token位置)に camera token を入れます。
「ここから先はmulti-viewとして扱う」スイッチと同時に、
camera条件をトークン列に注入しているイメージです。
cls tokenに忍ばせてくるとかこの開発者はどんだけ工夫してるんだって思います。
6.2 head側はviewを融合しない(“backboneで混ぜて、headは読むだけ”)
DPT/DualDPT のheadは、multi-view特徴を受け取りますが、やっていることは基本:
[B, S, ...] -> [B*S, ...] に畳んで 2D conv/pyramid を回す -> [B, S, ...] に戻す
なので view融合はheadでは起きません。
multi-viewの「整合の取り方」は、backboneのglobalブロックに責務が寄っています。
7. なぜ重いのか(globalは二乗で効く)
patch token数を N、view数を S とすると:
- local: S本の長さNのSelf-Attention(だいたい
S * N^2) - global: 1本の長さ(SN)のSelf-Attention(だいたい
(S*N)^2)
なので、S が増えるとglobalのコストが一気に支配的になります。
だからこそ:
- 後半だけglobal(
alt_start) - さらに交互だけglobal(odd blocks)
という設計が、実装としても思想としても開発者の逞しさを感じます。
8. まとめ
Depth Anything V3のbackbone multi-viewの肝は:
1. ViTブロックはそのまま(`[B,N,C]` APIを崩さない)
2. local/globalのreshapeでcross-view mixingを起こす
3. 後半だけ・交互だけglobalにして現実的な計算量に落とす
4. export時にlocal/globalをconcatしてheadに渡す(`cat_token=True`)
5. RoPEのposをlocal/globalで切り替えて破綻を避ける(`pos_nodiff`)