1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

RefineNetの論文と実装例の確認

Last updated at Posted at 2025-09-10

当記事ではRefineNetの論文とDepthAnythingのPyTorch実装におけるRefineNetの実装の確認を行います。

RefineNetの論文

概要

RefineNetはCNNを用いてセグメンテーションなどのDense Classificationタスクを解く際にダウンサンプリングによって起こる情報の損失について取り組んだ研究です。

RefineNet1.png
RefineNet論文 Figure.1

セグメンテーションタスクは上図のように意味に基づいて領域の分割を行います。Object Detenctionに用いられるFPN(Feature Pyramid Network)と処理が類似するので合わせて抑えておくと良いと思います。

RefineNet2.png
RefineNet論文 Figure.2

RefineNetでは上図の(c)のように異なるスケールの特徴マップ(Feature Map)を徐々にアップサンプリングしていくことでセグメンテーションを行います。このようなRefineNetの処理はセグメンテーションの処理だけでなくDepth Anythingのような深度推定の手法に応用することもできます(セグメンテーションと深度推定のどちらもDense Prediction)。

RefineNetの処理の流れ

RefineNet3.png
RefineNet論文 Figure.3

RefineNetの処理の流れの大枠は上図より確認することができます。詳しい処理については以下で確認します。

ResNetを用いた入力の作成

RefineNet4.png
RefineNet論文 Figure.3を改変

RefineNetではResNetの出力から4つのスケールの特徴マップを抽出し入力に用います。ResNet block-1ResNet block-4がそれぞれの解像度(1の特徴マップが一番fine-grainedであり4が最も粗い)をRefineNet-1RefineNet-4にそれぞれ入力します。

また、RefineNet-mは$m=4$の場合はResNet block-4のみを入力に取る一方で、$m=1,2,3$の場合はResNet block-mと合わせてRefineNet-m+1の出力も入力に取ることに着目しておくと良いです。

Residual Conv Unit(RCU)

RefineNet5.png
RefineNet論文 Figure.3を改変

RCU(Residual Conv Unit)では2層のResNetの処理が実行されます。

Multi-resolution Fusion

RefineNet6.png
RefineNet論文 Figure.3を改変

RefineNetの実装

以下、上記で確認したDepth Anythingのリポジトリを元にRefineNetの実装について確認します。

RefineNetが使用されている箇所の確認

Depth-Anything/depth_anything/dpt.py
from depth_anything.blocks import FeatureFusionBlock, _make_scratch

def _make_fusion_block(features, use_bn, size = None):
    return FeatureFusionBlock(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )

class DPTHead(nn.Module):
    def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
        super(DPTHead, self).__init__()
        
        self.nclass = nclass
        self.use_clstoken = use_clstoken
        
        ...
        
        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None
        
        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        ...
            
    def forward(self, out_features, patch_h, patch_w):
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            else:
                x = x[0]
            
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
            
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            
            out.append(x)
        
        layer_1, layer_2, layer_3, layer_4 = out
        
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)
        
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
        
        out = self.scratch.output_conv1(path_1)
        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        out = self.scratch.output_conv2(out)
        
        return out

上記のself.scratch.refinenet1self.scratch.refinenet4でRefineNetの処理が実行されていることが確認できます。self.scratch.refinenet1self.scratch.refinenet4はそれぞれ_make_fusion_block関数の内部でFeatureFusionBlockクラス(depth_anything/blocks.py)に基づくオブジェクトが作成されています。

FeatureFusionBlockの実装

depth_anything/blocks.py
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups=1

        self.expand = expand
        out_features = features
        if self.expand==True:
            out_features = features//2
        
        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
        
        self.skip_add = nn.quantized.FloatFunctional()

        self.size=size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = nn.functional.interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output

上記の処理を確認するにあたっては、まずlen(xs)1であるか2であるかに着目すると良いです。基本的にRefineNet-4ではlen(xs)1RefineNet-1RefineNet-3ではlen(xs)2になります(解像度が粗い順に処理を行い、1つ上の解像度の特徴マップに加える)。

self.resConfUnit1self.resConfUnit2は前節で確認したResidual Conv Unitnn.functional.interpolate(アップサンプリング)とself.skip_add.add(output, res)(2つの特徴マップを組み合わせる)によってMulti-resolution Fusionがそれぞれ実装されています。

python:depth_anything/blocks.py
class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()
        self.bn = bn
        self.groups=1
        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn==True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: output
        """
        
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn==True:
            out = self.bn1(out)
       
        out = self.activation(out)
        out = self.conv2(out)
        if self.bn==True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

上記よりFeatureFusionBlockで用いられるResidualConvUnitが2層のCNNであることも確認できます。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?