##はじめに
「End-to-End Object Detection with Transformers」(DETR) が気になったので、論文を読んで少し動作確認もしてみました。簡潔に記録として残しておきます。
[論文, Github]
##DETRとは(要約)
・Facebook AI Researchが今年5月に公開したモデル
・自然言語処理分野で有名なTransformerを初めて物体検出に活用
・下図のように、CNN + Transformerのシンプルなネットワーク構成
・NMSやAnchorBoxのデフォルト値等、人手による調整が必要な部分を排除し「End-to-End」な物体検出を実現
・上記を実現するためのポイントとして、"Bipartite Matching Loss"と"Parallel Decoding"の効果を主張
・物体検出だけでなく、セグメンテーションタスクへも適用可能
##推論コード例
論文よりコードを引用します。以下のように、モデル定義から推論処理までが40行程度でシンプルに書けます。
import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
super().__init__()
# We take only convolutional layers from ResNet-50 model
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
self.transformer = nn.Transformer(hidden_dim, nheads,
num_encoder_layers, num_decoder_layers)
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
self.linear_bbox = nn.Linear(hidden_dim, 4)
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, inputs):
x = self.backbone(inputs)
h = self.conv(x)
H, W = h.shape[-2:]
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
self.query_pos.unsqueeze(1))
return self.linear_class(h), self.linear_bbox(h).sigmoid()
detr = DETR(num_classes=91, hidden_dim=256, nheads=8,
num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)
※この論文のコードのままだと、公開されている学習済みモデルをloadする事はできません。(model定義の書き方が違うため)
実際に学習済みモデルで検出を行う場合は、本家Githubのdetr_demo.ipynbのようにすると良いです。
##動作確認
学習済みモデルを使って実際に動作を確認してみました。環境は以下の通りです。
・OS : Ubuntu 18.04.4 LTS
・GPU : GeForce RTX 2060 SUPER(8GB) x1
・PyTorch 1.5.1 / torchvision 0.6.0
検出処理の実装は本家のdetr_demo.ipynbを参考にし、OpenCVでwebカメラからキャプチャした画像に対して検出を行いました。使用したモデルはResNet-50ベースのDETRです。
以下は実際の検出結果です。正常に検出が実行されていることが確認できました。(COCOのクラスに含まれない物も撮影対象にしてしまっていますが)
自分の環境では、推論処理自体は45msec(約22FPS)程度で回っているようでした。
※文献値は、V100上で28FPS
##おわり
DETRの勉強と動作確認を行いました。新しいタイプの手法なので、精度や速度の面でこれからさらに発展していくのかもと思います。
Transformerは自然言語処理専用というようなイメージを持っていましたが、最近では画像を扱うモデルへの導入も進んでいるのですね。今後、もっと勉強していきたいと思います。(Image GPTも動かしてみたいです。)
##参考
・「DETR」Transformerの物体検出デビュー
https://medium.com/lsc-psd/detr-transformer%E3%81%AE%E7%89%A9%E4%BD%93%E6%A4%9C%E5%87%BA%E3%83%87%E3%83%93%E3%83%A5%E3%83%BC-dc18e582dec1
・End-to-End Object Detection with Transformers (DETR) の解説
https://qiita.com/sasgawy/items/61fb64d848df9f6b53d1
・Transformer を物体検出に採用!話題のDETRを詳細解説!
https://deepsquare.jp/2020/07/detr/