0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

LambdaNetworks

Posted at

LambdaNetworksとは

簡単に解説してみる

  • この動画がとっっっても分かりやすいので,おすすめ.(1時間と長尺&英語だが,とにかく分かりやすいので気にならない,はず.)

イメージ

image.png

  • いかにGlobal Contextを捉えるか,というのが命題.
  • AttentionよりもLambdaの方が大分抽象的.
  • Attentionの場合,一つのピクセルごとにAttention Mapを求める必要があるため,コンピューティングコストが高い.
  • Lambdaでは一度だけGlobal Contextに対するLambdaを求めて,Queryとの積を取る.
  • Attentionではそれぞれのピクセルがそれぞれのピクセルに対する重要度を算出するが,Lambdaでは一度だけ集約されたLambdaを求めるだけでいい.
  • QueryはInput(InputをConv2dしてQ=XWとするが.また後述.).

Attention

Lambdaについて説明するために,まずAttentionについて復習する.AttentionはTransformerでも使われているので,既にご理解しているかと思うが,一応.

image.png

  • 上図がこれ以上説明する必要がないくらい簡潔にまとまっている.ポイントは,
    • memoryではkeyとvalueという形で二度使うこと
    • queryとkeyの内積を取って(ベクトルが似ていれば値が大きくなることを利用)softmaxで非線形化してAttention Mapを得ること
    • 再びvalueと積を取ってoutputとする
    • inputとmemoryはしばしば同じ
    • 式にすると以下.

image.png

  • 上記をさらにLambdaNetworksの説明用に書き直すと下図.
    • Context≒memory

image.png

Lambda (content lambda)

image.png

  • 上図がLambdaの動きを表した図であり,Attentionと比較すると違いが分かりやすい.
  • 違いは,QKに対してではなく,Keyのみに対してsoftmaxを取ること.
  • これによって,contentのsummarizeのような働きをしている.(最初に説明したLambdaの抽象的な図)
  • Cからλに変換するときに次元がmからkに減っていることからも,情報が集約されていることが分かる.
  • さらにこれをQueryと積をとってoutputとする.
  • 要は,Attentionを多分に参考にしてAttentionみたいでAttentionとは異なるアーキテクチャを提案したということ.

position lambda

  • 上記content lambdaではコンテキストをsummarizeすることができた一方で,画像の位置関係を捉えることができていない.
  • そこで,content lambdaと合わせて導入したのが,このposition lambda.

image.png

全体像

  • 上記での説明を合わせると以下のようになる.

image.png

  • 数式としては以下.

image.png
image.png

結果

  • パラメータ数が減っているにも関わらず,精度が向上している.

image.png

コード

  • Githubに非公式ながらPyTorch実装があるので参考にされたし.
  • 簡単にまとまっているので,一読をおすすめする.
  • 以下に載せているのは,Lambda Layerの実装.
  • 全体の実装は他の人がしていた.
import torch
from torch import nn, einsum
from einops import rearrange

# helpers functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_rel_pos(n):
    pos = torch.meshgrid(torch.arange(n), torch.arange(n))
    pos = rearrange(torch.stack(pos), 'n i j -> (i j) n')  # [n*n, 2] pos[n] = (i, j)
    rel_pos = pos[None, :] - pos[:, None]                  # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
    rel_pos += n - 1                                       # shift value range from [-n+1, n-1] to [0, 2n-2]
    return rel_pos

# lambda layer

class LambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)

        self.norm_q = nn.BatchNorm2d(dim_k * heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            assert exists(n), 'You must specify the window size (n=h=w)'
            rel_lengths = 2 * n - 1
            self.rel_pos_emb = nn.Parameter(torch.randn(rel_lengths, rel_lengths, dim_k, dim_u))
            self.rel_pos = calc_rel_pos(n)

    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

        k = k.softmax(dim=-1)

        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

        if self.local_contexts:
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
        else:
            n, m = self.rel_pos.unbind(dim = -1)
            rel_pos_emb = self.rel_pos_emb[n, m]
            λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
        return out

einsum

  • 超簡単にテンソルの積の演算をしてくれる.
    • torch.matmulなどでは次元を合わせるのが大変だった.
    • 他にもtorch.transposetorch.viewtorch.squeezeなども使っていた
  • アインシュタインの縮約記法に基づいて,多次元線形代数配列演算を簡略形式で表せる.
import torch as t

X = t.rand(3, 10, 5)
Y = t.rand(3, 20, 5)

t.einsum('bnm, bkm -> bnk', X, Y).size()
>> torch.Size([3, 10, 20])

t.einsum('bnm, bkm -> bkn', X, Y).size()
>> torch.Size([3, 20, 10])
  • 感動的に便利!!! 今まで知らなくてめちゃめちゃ損していた!
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?