
[Processing Flow] RAFT (Recurrent All-Pairs Field Transforms)

Posted at 2024-11-25

This article summarizes the processing flow of RAFT (Recurrent All-Pairs Field Transforms), an optical flow method that achieved the state of the art on Sintel and KITTI as of 2020, based on the paper authors' PyTorch implementation.

Overview

The RAFT repository

RAFT GitHub

The link above is the paper authors' PyTorch implementation of RAFT. This article walks through this implementation.

Overview of RAFT's processing

(Figure 1 from the RAFT paper)

Figure 1 above is a good starting point for understanding RAFT's overall processing. RAFT is built from three steps: (1) an encoder, (2) computation of a 4D ($W \times H \times W \times H$) correlation volume, and (3) RNN-based refinement.

To trace the actual processing, we start from demo.py and work toward the detailed implementation. The main steps of demo.py are extracted below.

PJ_Root/demo.py
import sys
sys.path.append('core')
from raft import RAFT
# (omitted)
model = torch.nn.DataParallel(RAFT(args))
# (omitted)
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)

From the above, we can confirm that inference runs through the RAFT class defined in PJ_Root/core/raft.py. The next section therefore examines PJ_Root/core/raft.py in detail.
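As a supplementary note, RAFT(args) expects a handful of flags that demo.py normally supplies via argparse. A minimal construction sketch, assuming the repo's flag names (small, mixed_precision, alternate_corr) and a hypothetical checkpoint path:

import sys
sys.path.append('core')
from argparse import Namespace

import torch
from raft import RAFT

# flags consumed by RAFT.__init__ / RAFT.forward (normally set via demo.py's argparse)
args = Namespace(small=True, mixed_precision=False, alternate_corr=False, dropout=0.0)
model = torch.nn.DataParallel(RAFT(args))
# model.load_state_dict(torch.load('models/raft-small.pth'))  # hypothetical checkpoint path
model.eval()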

Details

Encoder processing

PJ_Root/core/raft.py
from extractor import BasicEncoder, SmallEncoder

class RAFT(nn.Module):
    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)        
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

RAFT's encoder processing mainly corresponds to self.fnet([image1, image2]) above. Since self.fnet is a SmallEncoder, we next check its source, PJ_Root/core/extractor.py.

PJ_Root/core/extractor.py
class SmallEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn
        # (the definition of self.norm1 according to norm_fn is omitted here)

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32,  stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

    def _make_layer(self, dim, stride=1):
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
    
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x

Since self.conv1 and self.conv2 use torch.nn.Conv2d, we can confirm that the encoder is essentially a CNN. In addition, _make_layer, which builds self.layer1 through self.layer3, uses the BottleneckBlock defined below.

PJ_Root/core/extractor.py
class BottleneckBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()
  
        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8
        ...  # (the normalization layers and self.downsample are defined here)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)  # residual connection

Here too we can confirm that torch.nn.Conv2d is used; BottleneckBlock is a residual block that adds its input x to the convolution output y.
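As a quick sanity check of the encoder's output resolution: conv1, layer2, and layer3 each use stride 2, so the features come out at 1/8 of the input resolution with output_dim channels. A minimal sketch, assuming PJ_Root/core is on sys.path:

import sys
sys.path.append('core')
import torch
from extractor import SmallEncoder

fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=0.0)
fnet.eval()
x = torch.randn(1, 3, 440, 1024)  # dummy frame; H and W divisible by 8
with torch.no_grad():
    fmap = fnet(x)
print(fmap.shape)  # torch.Size([1, 128, 55, 128]) -- 1/8-resolution feature map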

Computing the 4D correlation

PJ_Root/core/raft.py
from corr import CorrBlock, AlternateCorrBlock

class RAFT(nn.Module):
    ...
    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        ...
        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])        
        
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        ...
        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

The above shows how raft.py extracts the features and computes/indexes the correlation. The detailed processing can be found in PJ_Root/core/corr.py.

PJ_Root/core/corr.py
from utils.utils import bilinear_sampler, coords_grid

class CorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)
        
        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd) 
        
        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr  / torch.sqrt(torch.tensor(dim).float())
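
CorrBlock.corr above reduces the all-pairs correlation to a single matrix multiplication, and __init__ then builds a 4-level pyramid by average-pooling only the last two (target) dimensions. The following standalone sketch reproduces the shapes, using the Sintel demo feature-map size as an example:

import torch
import torch.nn.functional as F

batch, dim, ht, wd = 1, 128, 55, 128      # feature-map size of the Sintel demo
fmap1 = torch.randn(batch, dim, ht, wd)
fmap2 = torch.randn(batch, dim, ht, wd)

# all-pairs correlation: dot products between every pair of feature vectors
f1 = fmap1.view(batch, dim, ht * wd)
f2 = fmap2.view(batch, dim, ht * wd)
corr = torch.matmul(f1.transpose(1, 2), f2) / dim ** 0.5
corr = corr.view(batch * ht * wd, 1, ht, wd)  # one H x W correlation map per source pixel

# 4-level pyramid: only the target dimensions are pooled
pyramid = [corr]
for _ in range(3):
    pyramid.append(F.avg_pool2d(pyramid[-1], 2, stride=2))
print([tuple(p.shape[-2:]) for p in pyramid])
# [(55, 128), (27, 64), (13, 32), (6, 16)]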

What deserves particular attention above is the input and output of bilinear_sampler. Using torch.nn.functional.grid_sample, bilinear_sampler extracts a $(2r+1) \times (2r+1)$ window ($7 \times 7$ for radius 3, $9 \times 9$ for radius 4) from each $W \times H$ correlation map of the 4D volume, centered at the position indicated by the provisional flow vector during the iterative refinement.

(Illustration from the PyTorch documentation for grid_sample)

The above is the documentation for torch.nn.functional.grid_sample. Its main arguments are input and grid; the operation samples an $H_{out} \times W_{out}$ grid of values out of an $H_{in} \times W_{in}$ input. For RAFT's Sintel demo feature maps, $W_{in}=128$ and $H_{in}=55$, and the grid uses $W_{out}=H_{out}=7$ or $W_{out}=H_{out}=9$.
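To make the coordinate convention concrete, here is a minimal, self-contained grid_sample example. Note that grid coordinates are normalized to $[-1, 1]$; RAFT's bilinear_sampler (PJ_Root/core/utils/utils.py) rescales pixel coordinates into this range before calling grid_sample.

import torch
import torch.nn.functional as F

inp = torch.arange(16.0).view(1, 1, 4, 4)  # (N, C, H_in, W_in)
# grid has shape (N, H_out, W_out, 2) with (x, y) in [-1, 1];
# (-1, -1) corresponds to the top-left corner of the input
grid = torch.tensor([[[[-1.0, -1.0], [1.0, 1.0]]]])  # H_out=1, W_out=2
out = F.grid_sample(inp, grid, align_corners=True)
print(out)  # tensor([[[[ 0., 15.]]]]) -- the top-left and bottom-right pixels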

RNN-based refinement

PJ_Root/core/raft.py
from update import BasicUpdateBlock, SmallUpdateBlock

class RAFT(nn.Module):
    def __init__(self, args):
    ...
    # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)        
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

    ...

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        ...
        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            
            flow_predictions.append(flow_up)

The above shows the RNN-based refinement in PJ_Root/core/raft.py. In the main call, self.update_block(net, inp, corr, flow), the arguments net and inp correspond to the context encoder output (in raft.py, the output of self.cnet is split along the channel dimension, with tanh applied to net and relu to inp), corr to the indexed 4D correlation, and flow to the provisional flow, which is initialized to zero. The details of self.update_block can be found in PJ_Root/core/update.py.

PJ_Root/core/update.py
class SmallUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow
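
The input_dim=82+64 above can be accounted for as follows (based on SmallMotionEncoder, defined in the same file): the correlation features indexed from the pyramid have corr_levels $\times (2 \cdot \text{corr radius} + 1)^2 = 4 \times 7^2 = 196$ channels, which SmallMotionEncoder compresses to 80 channels and concatenates with the 2-channel flow, giving $80 + 2 = 82$ motion-feature channels; the remaining 64 channels are the context features inp ($cdim = 64$ for the small model).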

Since the main RNN computation is performed by self.gru (ConvGRU) above, we examine ConvGRU next.

PJ_Root/core/update.py
class ConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        z = torch.sigmoid(self.convz(hx))  # update gate
        r = torch.sigmoid(self.convr(hx))  # reset gate
        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))  # candidate hidden state

        h = (1-z) * h + z * q
        return h

Since self.convz, self.convr, and self.convq are all torch.nn.Conv2d layers whose output channel count is hidden_dim, the gates have the same shape as the hidden state, which is what makes the GRU update h = (1-z) * h + z * q well-defined.
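Written out, the forward pass above is the convolutional GRU used in the RAFT paper:

$z_t = \sigma(\mathrm{Conv}_{3 \times 3}([h_{t-1}, x_t], W_z))$
$r_t = \sigma(\mathrm{Conv}_{3 \times 3}([h_{t-1}, x_t], W_r))$
$\tilde{h}_t = \tanh(\mathrm{Conv}_{3 \times 3}([r_t \odot h_{t-1}, x_t], W_h))$
$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$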
