LambdaNetworksとは
- ICLR2021採択論文
- Attentionなしに,効率性&精度の両面でAttentionやConvolutionベースのモデルを凌駕
- 研究者界隈では有名.思ったよりは話題になりきらなかった印象.
- Paper: https://arxiv.org/abs/2102.08602
- Github: https://github.com/lucidrains/lambda-networks
簡単に解説してみる
- この動画がとっっっても分かりやすいので,おすすめ.(1時間と長尺&英語だが,とにかく分かりやすいので気にならない,はず.)
イメージ
- いかにGlobal Contextを捉えるか,というのが命題.
- AttentionよりもLambdaの方が大分抽象的.
- Attentionの場合,一つのピクセルごとにAttention Mapを求める必要があるため,コンピューティングコストが高い.
- Lambdaでは一度だけGlobal Contextに対するLambdaを求めて,Queryとの積を取る.
- Attentionではそれぞれのピクセルがそれぞれのピクセルに対する重要度を算出するが,Lambdaでは一度だけ集約されたLambdaを求めるだけでいい.
- QueryはInput(InputをConv2dしてQ=XWとするが.また後述.).
Attention
Lambdaについて説明するために,まずAttentionについて復習する.AttentionはTransformerでも使われているので,既にご理解しているかと思うが,一応.
- 上図がこれ以上説明する必要がないくらい簡潔にまとまっている.ポイントは,
- memoryではkeyとvalueという形で二度使うこと
- queryとkeyの内積を取って(ベクトルが似ていれば値が大きくなることを利用)softmaxで非線形化してAttention Mapを得ること
- 再びvalueと積を取ってoutputとする
- inputとmemoryはしばしば同じ
- 式にすると以下.
- 上記をさらにLambdaNetworksの説明用に書き直すと下図.
- Context≒memory
Lambda (content lambda)
- 上図がLambdaの動きを表した図であり,Attentionと比較すると違いが分かりやすい.
- 違いは,QKに対してではなく,Keyのみに対してsoftmaxを取ること.
- これによって,contentのsummarizeのような働きをしている.(最初に説明したLambdaの抽象的な図)
- Cからλに変換するときに次元がmからkに減っていることからも,情報が集約されていることが分かる.
- さらにこれをQueryと積をとってoutputとする.
- 要は,Attentionを多分に参考にしてAttentionみたいでAttentionとは異なるアーキテクチャを提案したということ.
position lambda
- 上記content lambdaではコンテキストをsummarizeすることができた一方で,画像の位置関係を捉えることができていない.
- そこで,content lambdaと合わせて導入したのが,このposition lambda.
全体像
- 上記での説明を合わせると以下のようになる.
- 数式としては以下.
結果
- パラメータ数が減っているにも関わらず,精度が向上している.
コード
- Githubに非公式ながらPyTorch実装があるので参考にされたし.
- 簡単にまとまっているので,一読をおすすめする.
- 以下に載せているのは,Lambda Layerの実装.
- 全体の実装は他の人がしていた.
import torch
from torch import nn, einsum
from einops import rearrange
# helpers functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def calc_rel_pos(n):
pos = torch.meshgrid(torch.arange(n), torch.arange(n))
pos = rearrange(torch.stack(pos), 'n i j -> (i j) n') # [n*n, 2] pos[n] = (i, j)
rel_pos = pos[None, :] - pos[:, None] # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
rel_pos += n - 1 # shift value range from [-n+1, n-1] to [0, 2n-2]
return rel_pos
# lambda layer
class LambdaLayer(nn.Module):
def __init__(
self,
dim,
*,
dim_k,
n = None,
r = None,
heads = 4,
dim_out = None,
dim_u = 1):
super().__init__()
dim_out = default(dim_out, dim)
self.u = dim_u # intra-depth dimension
self.heads = heads
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
dim_v = dim_out // heads
self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
self.norm_q = nn.BatchNorm2d(dim_k * heads)
self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
self.local_contexts = exists(r)
if exists(r):
assert (r % 2) == 1, 'Receptive kernel size should be odd'
self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
else:
assert exists(n), 'You must specify the window size (n=h=w)'
rel_lengths = 2 * n - 1
self.rel_pos_emb = nn.Parameter(torch.randn(rel_lengths, rel_lengths, dim_k, dim_u))
self.rel_pos = calc_rel_pos(n)
def forward(self, x):
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
q = self.norm_q(q)
v = self.norm_v(v)
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
k = k.softmax(dim=-1)
λc = einsum('b u k m, b u v m -> b k v', k, v)
Yc = einsum('b h k n, b k v -> b h v n', q, λc)
if self.local_contexts:
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
λp = self.pos_conv(v)
Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
else:
n, m = self.rel_pos.unbind(dim = -1)
rel_pos_emb = self.rel_pos_emb[n, m]
λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
Y = Yc + Yp
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
return out
einsum
- 超簡単にテンソルの積の演算をしてくれる.
-
torch.matmul
などでは次元を合わせるのが大変だった. - 他にも
torch.transpose
やtorch.view
,torch.squeeze
なども使っていた
-
- アインシュタインの縮約記法に基づいて,多次元線形代数配列演算を簡略形式で表せる.
import torch as t
X = t.rand(3, 10, 5)
Y = t.rand(3, 20, 5)
t.einsum('bnm, bkm -> bnk', X, Y).size()
>> torch.Size([3, 10, 20])
t.einsum('bnm, bkm -> bkn', X, Y).size()
>> torch.Size([3, 20, 10])
- 感動的に便利!!! 今まで知らなくてめちゃめちゃ損していた!