当記事ではRefineNetの論文とDepthAnythingのPyTorch
実装におけるRefineNetの実装の確認を行います。
RefineNetの論文
概要
RefineNetはCNNを用いてセグメンテーションなどのDense Classificationタスクを解く際にダウンサンプリングによって起こる情報の損失について取り組んだ研究です。
セグメンテーションタスクは上図のように意味に基づいて領域の分割を行います。Object Detenctionに用いられるFPN(Feature Pyramid Network)と処理が類似するので合わせて抑えておくと良いと思います。
RefineNetでは上図の(c)のように異なるスケールの特徴マップ(Feature Map)を徐々にアップサンプリングしていくことでセグメンテーションを行います。このようなRefineNetの処理はセグメンテーションの処理だけでなくDepth Anythingのような深度推定の手法に応用することもできます(セグメンテーションと深度推定のどちらもDense Prediction)。
RefineNetの処理の流れ
RefineNetの処理の流れの大枠は上図より確認することができます。詳しい処理については以下で確認します。
ResNetを用いた入力の作成
RefineNetではResNetの出力から4つのスケールの特徴マップを抽出し入力に用います。ResNet block-1
〜ResNet block-4
がそれぞれの解像度(1の特徴マップが一番fine-grainedであり4が最も粗い)をRefineNet-1
〜RefineNet-4
にそれぞれ入力します。
また、RefineNet-m
は$m=4$の場合はResNet block-4
のみを入力に取る一方で、$m=1,2,3$の場合はResNet block-m
と合わせてRefineNet-m+1
の出力も入力に取ることに着目しておくと良いです。
Residual Conv Unit(RCU)
RCU(Residual Conv Unit)では2層のResNetの処理が実行されます。
Multi-resolution Fusion
RefineNetの実装
以下、上記で確認したDepth Anythingのリポジトリを元にRefineNetの実装について確認します。
RefineNetが使用されている箇所の確認
from depth_anything.blocks import FeatureFusionBlock, _make_scratch
def _make_fusion_block(features, use_bn, size = None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class DPTHead(nn.Module):
def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
super(DPTHead, self).__init__()
self.nclass = nclass
self.use_clstoken = use_clstoken
...
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
...
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
上記のself.scratch.refinenet1
〜self.scratch.refinenet4
でRefineNetの処理が実行されていることが確認できます。self.scratch.refinenet1
〜self.scratch.refinenet4
はそれぞれ_make_fusion_block
関数の内部でFeatureFusionBlock
クラス(depth_anything/blocks.py
)に基づくオブジェクトが作成されています。
FeatureFusionBlockの実装
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand==True:
out_features = features//2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
self.size=size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(
output, **modifier, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
上記の処理を確認するにあたっては、まずlen(xs)
が1
であるか2
であるかに着目すると良いです。基本的にRefineNet-4
ではlen(xs)
が1
、RefineNet-1
〜RefineNet-3
ではlen(xs)
が2
になります(解像度が粗い順に処理を行い、1つ上の解像度の特徴マップに加える)。
self.resConfUnit1
とself.resConfUnit2
は前節で確認したResidual Conv Unit
、nn.functional.interpolate
(アップサンプリング)とself.skip_add.add(output, res)
(2つの特徴マップを組み合わせる)によってMulti-resolution Fusionがそれぞれ実装されています。
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
if self.bn==True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn==True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn==True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
上記よりFeatureFusionBlock
で用いられるResidualConvUnit
が2層のCNNであることも確認できます。