当記事ではSintelやKITTIで2020年当時のSotAを実現したOptical Flowの研究であるRAFT(Recurrent All-Pairs Field Transforms)の処理の流れについて、PyTorchベースの論文著者実装を元に取りまとめを行いました。
概要
RAFTのリポジトリ
上記はPyTorchベースで構築されたRAFTの論文著者実装です。当記事では上記の実装の確認を行います。
RAFTの処理概要
RAFT論文Figure.1
RAFTの処理概要は上記を元に把握すると良いです。「1. Encoder」、「2. 4D($W \times H \times W \times H$)の相関の計算」、「3. RNNを用いたRefinement」の3つの処理に基づいてRAFTは構築されます。
実際の処理の確認にあたってはdemo.py
から詳細の実装について確認します。demo.py
の主な処理を下記に抽出しました。
import sys
sys.path.append('core')
from raft import RAFT
((中略))
model = torch.nn.DataParallel(RAFT(args))
((中略))
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
上記より、PJ_Root/core/raft.py
に定義されたRAFT
クラスを読み込み処理が行われていることが確認できます。よって、次節ではPJ_Root/core/raft.py
の詳細の確認を行います。
詳細
Encoderの処理
from extractor import BasicEncoder, SmallEncoder
class RAFT(nn.Module):
def __init__(self, args):
if args.small:
self.hidden_dim = hdim = 96
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3
# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
image1 = image1.contiguous()
image2 = image2.contiguous()
hdim = self.hidden_dim
cdim = self.context_dim
# run the feature network
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
RAFTのEncoderの処理は上記のself.fnet([image1, image2])
が主に対応します。SmallEncoder
に基づいてself.fnet
が実行されるので、読み込み元のPJ_Root/core/extractor.py
を次に確認します。
class SmallEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(SmallEncoder, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
def _make_layer(self, dim, stride=1):
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
上記のself.conv1
やself.conv2
にtorch.nn.Conv2d
が用いられていることなどから、Encoderは基本的にCNNによって構築されることが確認できます。また、self.layer1
~self.layer3
に対応する_make_layer
には下記のように定義されるBottleneckBlock
が用いられます。
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
...
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
上記でもtorch.nn.Conv2d
が用いられていることが確認できます。
4Dの相関の計算
from corr import CorrBlock, AlternateCorrBlock
class RAFT(nn.Module):
...
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
...
# run the feature network
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
...
coords0, coords1 = self.initialize_flow(image1)
if flow_init is not None:
coords1 = coords1 + flow_init
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
上記より、PJ_Root/core/raft.py
における相関の計算と特徴量の抽出についての処理が確認できます。上記の詳細の処理はPJ_Root/core/corr.py
より確認できます。
from utils.utils import bilinear_sampler, coords_grid
class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
self.corr_pyramid.append(corr)
for i in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
def __call__(self, coords):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
上記について特に確認すべきなのはbilinear_sampler
の入力と出力です。bilinear_sampler
ではtorch.nn.functional.grid_sample
を用いることで、繰り返し演算(refinement)における暫定的なベクトルを元に4DからのTargetの特徴マップの$W \times H$から$7 \times 7$の領域の抽出を行います。
上記はtorch.nn.functional.grid_sample
のドキュメントです。引数には主にinput
とgrid
があり、$W_{in} \times H_{in}$のInputの領域から$W_{out} \times H_{out}$のグリッドを抽出する処理が行われます。RAFTのデモ用のSintelのFeature mapでは$W_{in}=128, H_{in}=55$であり、グリッドには$W_{out}=H_{out}=7$または$W_{out}=H_{out}=9$が用いられます。
RNNを用いたRefinement
from update import BasicUpdateBlock, SmallUpdateBlock
class RAFT(nn.Module):
def __init__(self, args):
...
# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
...
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
...
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
self.corr = corr
#print(corr.shape)
flow = coords1 - coords0
with autocast(enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
上記より、PJ_Root/core/raft.py
におけるRNNを用いたRefinementについての処理が確認できます。上記のメイン処理であるself.update_block
の引数のnet
とinp
はContext Encoderの出力、corr
は4D Correlation、flow
はゼロ行列を初期値とする暫定的なflow
にそれぞれ対応します。self.update_block
の処理の詳細はPJ_Root/core/update.py
より確認できます。
class SmallUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=96):
super(SmallUpdateBlock, self).__init__()
self.encoder = SmallMotionEncoder(args)
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, None, delta_flow
上記のself.gru
(ConvGRU
)でRNNのメイン処理が行われているので、以下ConvGRU
について確認します。
class ConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
self.convz
、self.convr
、self.convq
のいずれもtorch.nn.Conv2d
の出力のチャネル数がhidden_dim
で指定されているので、h = (1-z) * h + z * q
のようなGRUの計算を行うことができると理解しておくと良いです。