LoginSignup
8
12

More than 3 years have passed since last update.

【要約】Transformerを用いた物体検出モデル「End-to-End Object Detection with Transformers」

Posted at

はじめに

「End-to-End Object Detection with Transformers」(DETR) が気になったので、論文を読んで少し動作確認もしてみました。簡潔に記録として残しておきます。
[論文, Github]

DETRとは(要約)

・Facebook AI Researchが今年5月に公開したモデル

・自然言語処理分野で有名なTransformerを初めて物体検出に活用

・下図のように、CNN + Transformerのシンプルなネットワーク構成

・NMSやAnchorBoxのデフォルト値等、人手による調整が必要な部分を排除し「End-to-End」な物体検出を実現

・上記を実現するためのポイントとして、"Bipartite Matching Loss"と"Parallel Decoding"の効果を主張

・物体検出だけでなく、セグメンテーションタスクへも適用可能

detr_flow.png

推論コード例

論文よりコードを引用します。以下のように、モデル定義から推論処理までが40行程度でシンプルに書けます。

import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):

    def __init__(self, num_classes, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        super().__init__()
        # We take only convolutional layers from ResNet-50 model
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        self.transformer = nn.Transformer(hidden_dim, nheads,
                                          num_encoder_layers, num_decoder_layers)
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        x = self.backbone(inputs)
        h = self.conv(x)
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1))
        return self.linear_class(h), self.linear_bbox(h).sigmoid()

detr = DETR(num_classes=91, hidden_dim=256, nheads=8,
            num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)

※この論文のコードのままだと、公開されている学習済みモデルをloadする事はできません。(model定義の書き方が違うため)
実際に学習済みモデルで検出を行う場合は、本家Githubのdetr_demo.ipynbのようにすると良いです。

動作確認

学習済みモデルを使って実際に動作を確認してみました。環境は以下の通りです。
・OS : Ubuntu 18.04.4 LTS
・GPU : GeForce RTX 2060 SUPER(8GB) x1
・PyTorch 1.5.1 / torchvision 0.6.0

検出処理の実装は本家のdetr_demo.ipynbを参考にし、OpenCVでwebカメラからキャプチャした画像に対して検出を行いました。使用したモデルはResNet-50ベースのDETRです。

以下は実際の検出結果です。正常に検出が実行されていることが確認できました。(COCOのクラスに含まれない物も撮影対象にしてしまっていますが)
frame_290.jpg

自分の環境では、推論処理自体は45msec(約22FPS)程度で回っているようでした。
※文献値は、V100上で28FPS

おわり

DETRの勉強と動作確認を行いました。新しいタイプの手法なので、精度や速度の面でこれからさらに発展していくのかもと思います。
Transformerは自然言語処理専用というようなイメージを持っていましたが、最近では画像を扱うモデルへの導入も進んでいるのですね。今後、もっと勉強していきたいと思います。(Image GPTも動かしてみたいです。)

参考

・「DETR」Transformerの物体検出デビュー
https://medium.com/lsc-psd/detr-transformer%E3%81%AE%E7%89%A9%E4%BD%93%E6%A4%9C%E5%87%BA%E3%83%87%E3%83%93%E3%83%A5%E3%83%BC-dc18e582dec1
・End-to-End Object Detection with Transformers (DETR) の解説
https://qiita.com/sasgawy/items/61fb64d848df9f6b53d1
・Transformer を物体検出に採用!話題のDETRを詳細解説!
https://deepsquare.jp/2020/07/detr/

8
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
12