0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

DETR(End-to-End Object Detection with Transformers)物体検知

Last updated at Posted at 2023-12-09

この動画で使っているプログラムを、こちらの記事に転記しています。

記事の概要

こちらの記事でGoogle Colab用に実装されている、DETRのファインチューニングプログラムをWindows用にカスタマイズしています。

また、自前で用意した画像データに対して、ラベリング(labelme)・ラベルの確認(coco-viewer)用の手順を追加し、自前データを用いたファインチューニングの手順も、別の記事として作成しています。

環境

OS:Windows 11
GPU:GeForce RTX 4090
CPU:i9-13900KF
memory:64G
python:3.10.10
pytorch:2.0.1
CUDA:11.8
cuDNN:8.8

以下の環境でも動作確認済み。
GPU:GeForce RTX 3060 laptop
CPU:i7-10750H
memory:16G

参考記事のWindows版

以下のコマンドを実行し、環境構築を構築する。

# 仮想環境を構築・アクティベイト
python -m venv detr_env
cd detr_env\Scripts
activate
cd ..

# 対象のライブラリをインストール
git clone https://github.com/EscVM/OIDv4_ToolKit.git
pip install urllib3==1.25.11 folium==0.2.1
pip install -r OIDv4_ToolKit/requirements.txt

# 学習用データをダウンロード
python OIDv4_ToolKit/main.py downloader -y --classes Apple Orange Ball Balloon Clock --type_csv train
python OIDv4_ToolKit/main.py downloader -y --classes Apple Orange Ball Balloon Clock --type_csv validation

以下のプログラムを実行し、学習用データをCOCO formatに変換する。

import os, json, glob

def OID2JSON(OIDFiles, saveName, subset):
    """
    アノテーションをOpenImage format(txt)からCOCO format(json)に変換
    Parameters
    ----------
    OIDFiles : string
        OpenImageDatasetのフォルダパス
    saveName : string
        保存ファイル名(json)
    subset : string
        変換したいtype_csv。train、validation、testのいずれかを指定。
    """
    attrDict = dict()
    # categories要素の設定
    attrDict['categories'] = []
    categories = sorted(os.listdir(os.path.join(OIDFiles, 'Dataset', subset)))
    for i in range(len(categories)):
        attrDict['categories'].append({'supercategory': 'none', 'id': i, 'name': categories[i]})
    
    images = list()
    annotations = list()
    filenames = list()
    image_id = 1
    anno_id = 1
    for category in attrDict['categories']:
        for jpg_file in glob.glob(os.path.join(OIDFiles, 'Dataset', subset, category['name'], '*.jpg')):
            filename = os.path.splitext(os.path.basename(jpg_file))[0]
            # カテゴリ全体で同じファイル名が存在する場合、imageとannoをリネーム
            if filename in filenames:
                rename_filename = filename + '_' + str(image_id)
                os.rename(jpg_file, os.path.join(OIDFiles, 'Dataset', subset, category['name'], rename_filename + '.jpg'))
                os.rename(os.path.join(OIDFiles, 'Dataset', subset, category['name'], 'Label', filename + '.txt'),
                                    os.path.join(OIDFiles, 'Dataset', subset, category['name'], 'Label', rename_filename + '.txt'))
                filename = rename_filename
            
            filenames.append(filename)
            # images要素の設定
            # ※DETRではheightとwidthを使わないので、'none'を設定
            image = {'file_name': filename + '.jpg', 'height': 'none', 'width': 'none', 'id': image_id}
            images.append(image)
            # annotations要素の設定
            anno_path = os.path.join(OIDFiles, 'Dataset', subset, category['name'], 'Label', filename + '.txt')
            with open(anno_path) as f:
                for line in f:
                    splitline = line.split(' ')
                    # カテゴリがcategories要素に存在しないバウンディングボックスは使わない
                    if splitline[0] in [d.get('name') for d in attrDict['categories']]:
                        # OpenImageの座標は(xmin, ymin, xmax, ymax)、COCOの座標は(x, y, width, height)
                        x1 = int(float(splitline[1]))
                        y1 = int(float(splitline[2]))
                        x2 = int(float(splitline[3])) - x1
                        y2 = int(float(splitline[4])) - y1
                        # areaはピクセル数(float)
                        area = float(x2 * y2)
                        # segmentationは(x1, y1, x2, y2, ...)と順番に定義
                        segmentation = [[x1, y1, x1, (y1+y2), (x1+x2), (y1+y2), (x1+x2), y1]]
                        annotation = {'iscrowd': 0, 'image_id': image_id, 'bbox': [x1, y1, x2, y2], 'area': area,
                                      'category_id': category['id'], 'ignore': 0, 'id': anno_id, 'segmentation': segmentation}
                        anno_id += 1
                        annotations.append(annotation)
                
            image_id = image_id + 1
    
    attrDict['images'] = images
    attrDict['annotations'] = annotations
    attrDict['type'] = 'instances'
    jsonString = json.dumps(attrDict)
    with open(saveName, 'w') as f:
        f.write(jsonString)


OID2JSON('OID', 'custom_train.json', 'train')
OID2JSON('OID', 'custom_val.json', 'validation')

以下のプログラムを実行し、学習用データを移動する。

import shutil
import glob
import os

source_train_paths = glob.glob(os.path.join('OID\\Dataset', 'train', '**\\'))
source_val_paths = glob.glob(os.path.join('OID\\Dataset', 'validation', '**\\'))
train_path = 'data\\custom\\train2017\\'
val_path = 'data\\custom\\val2017\\'
convert_anno_path = 'data\\custom\\annotations\\'
# ディレクトリ作成
os.makedirs(train_path, exist_ok=True)
os.makedirs(val_path, exist_ok=True)
os.makedirs(convert_anno_path, exist_ok=True)
# train移動
for source_train_path in source_train_paths:
    for img_path in glob.glob(os.path.join(source_train_path, '*.jpg')):
        shutil.move(img_path, train_path)

# val移動
for source_val_path in source_val_paths:
    for img_path in glob.glob(os.path.join(source_val_path, '*.jpg')):
        shutil.move(img_path, val_path)

# anno移動
shutil.move('custom_train.json', convert_anno_path)
shutil.move('custom_val.json', convert_anno_path)

以下のコマンドで必要なライブラリをインストールする。

pip install torch torchvision torchtext -f https://download.pytorch.org/whl/cu118/torch_stable.html
pip install torch torchvision torchtext
pip install matplotlib
pip install pycocotools
pip install scipy

以下のコマンドを実行し、DETRのライブラリを取得する。

# detr_envフォルダ配下で
rd /s detr
git clone https://github.com/woctezuma/detr.git

# ブランチの切替
cd detr/
git checkout finetune

以下のプログラムを実行し、学習済みモデルを保存する。

import torch, torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
import requests

# 学習済みモデルの取得
checkpoint = torch.hub.load_state_dict_from_url(
    url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth',
    map_location='cpu',
    check_hash=True
)

# 分類ヘッドの削除
del checkpoint['model']['class_embed.weight']
del checkpoint['model']['class_embed.bias']

# 保存
torch.save(checkpoint, 'detr-r50_no-class-head.pth')

以下のコマンドで出力先フォルダを作成する。

rd /s outputs
mkdir outputs

以下のコマンドでチューニングを実施する。(制度を向上するなら、epochsを50~100などに変更してもよい。)

python main.py --dataset_file "custom" --coco_path "..\\data\\custom\\" --output_dir "outputs" --resume "detr-r50_no-class-head.pth" --num_classes 5 --epochs 15

以下のプログラムを実行し、学習結果を確認する。

import torch, torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
from pathlib import Path
from io import BytesIO
from PIL import Image
import requests

# %cd /content/detr/
# log_directory = [Path('/content/detr/outputs')]
log_directory = [Path('\\outputs')]

# 実線 ... トレーニング結果(train_loss)
# 破線 ... 検証結果(val_loss)
fields_of_interest = (
    'loss',
    'mAP',
)
# plot_logs(log_directory, fields_of_interest)

finetuned_model = torch.hub.load('facebookresearch/detr',
                       'detr_resnet50',
                       pretrained=False,
                       num_classes=5)
checkpoint = torch.load('detr\\outputs\\checkpoint_epoch.pth',
                        map_location='cpu')
finetuned_model.load_state_dict(checkpoint['model'], strict=False)
finetuned_model.eval()

original_model = torch.hub.load('facebookresearch/detr', 'detr_resnet50_dc5', pretrained=True)
original_model.eval()

# 可視化用クラスラベル
oid_labels = [
  'Apple',
  'Ball',
  'Balloon',
  'Clock',
  'Orange',
]
coco_labels = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]
# 可視化用COLOR
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# 標準的なPyTorchのmean-std入力画像の正規化
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def box_cxcywh_to_xyxy(x):
    """
    (center_x, center_y, width, height)から(xmin, ymin, xmax, ymax)に座標変換
    """
    # unbind(1)でTensor次元を削除
    # (center_x, center_y, width, height)*N → (center_x*N, center_y*N, width*N, height*N)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    # (center_x, center_y, width, height)*N の形に戻す
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    """
    バウンディングボックスのリスケール
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # バウンディングボックスの[0~1]から元画像の大きさにリスケール
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

def filter_bboxes_from_outputs(outputs, threshold=0.7):
    # 閾値以上の信頼度を持つ予測値のみを保持
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold
    probas_to_keep = probas[keep]
    # [0, 1]のボックスを画像のスケールに変換
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return probas_to_keep, bboxes_scaled

# 結果の表示
def plot_finetuned_results(pil_img, prob=None, boxes=None, labels=None):
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    if prob is not None and boxes is not None:
        for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin,
                                                                 fill=False, color=c, linewidth=3))
            cl = p.argmax()
            print(labels, p)
            text = f'{labels[cl]}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                            bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()

# 物体検出
def run_worflow(my_image, my_model, labels, threshold=0.7):
    # mean-std入力画像の正規化(バッチサイズ : 1)
    img = transform(my_image).unsqueeze(0)
    
    # モデルに反映
    outputs = my_model(img)
    probas_to_keep, bboxes_scaled = filter_bboxes_from_outputs(outputs, threshold=threshold)
    plot_finetuned_results(my_image, probas_to_keep, bboxes_scaled, labels)

url = 'https://farm7.staticflickr.com/52/106887535_a29c34113b_o.jpg'
response = requests.get(url)
im = Image.open(BytesIO(response.content))

# Fine-Tuningモデルで物体検出(閾値0.1) ※15epochだと精度がよくないので、閾値0.1
run_worflow(im, finetuned_model, oid_labels, 0.1)
# Fine-Tuningしていないモデルで物体検出(閾値0.9)
run_worflow(im, original_model, coco_labels, 0.9)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?