1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【Object Detection】PyTorch の Faster R-CNN モデルを使ってファッション画像の領域を検出

Posted at

概要

ファッション画像のデータを利用して、帽子、トップス、ボトムスなどのオブジェクトを検出しました。

データをロード

CSV ファイルには画像URL、種類、領域などの情報が含まれている。
この情報をロードし、可視化して確認してみる。

import os
import json
import pandas as pd
import numpy as np
from urllib.parse import *
from requests.utils import requote_uri
import requests
from tqdm.notebook import tqdm
BASE_DIR = os.getcwd()
# Output directories: images/annotations hold the training split,
# test_images/test_annotations hold the evaluation split.
annotations_path = os.path.join(BASE_DIR, 'annotations')
images_path = os.path.join(BASE_DIR, 'images')
test_annotations_path = os.path.join(BASE_DIR, 'test_annotations')
test_images_path = os.path.join(BASE_DIR, 'test_images')
dirs = [annotations_path, images_path, test_annotations_path, test_images_path]
# A plain loop states the intent (create directories). Building a throwaway
# list from map() only to trigger side effects hides it.
for directory in dirs:
    os.makedirs(directory, exist_ok=True)
# One CSV row per bounding box: image URL, image size, class, box corners.
origin_labels = pd.read_csv('sample.csv')
origin_labels.head()

1行目には帽子、2行目にはズボンの情報が含まれていることを確認した。

Unnamed: 0 filename width height class xmin ymin xmax ymax
0 0 https://sample.com/image_source/test1.jpg 591 1137 Hat 247 61 474 284
1 1 https://sample.com/image_source/test2.jpg 591 1137 Pants 219 367 590 848

種類 (class) でグルーピングすると、帽子、パンツなど8種類があることがわかる。

# Group the box rows by class name. The group keys (in pandas' groupby key
# order) become the class list, and a name's position in class_indexes is
# the integer label used from here on.
classes = origin_labels.groupby('class')
class_indexes = list(classes.groups)
print(class_indexes)
class_indexes.index('Pants')
['Hat', 'Pants', 'One Piece', 'T-Shirt', 'Shoes', 'Jaket', 'Skirt', 'Coat']
1

データ前処理

1枚の画像に複数の領域・種類が含まれるため、画像URLでグルーピングした。

def get_bbox(group, class_names=None):
    """Convert the CSV rows of one image into (class_index, box) pairs.

    Args:
        group: DataFrame with all rows that share one ``filename``. It must
            have ``class``, ``xmin``, ``ymin``, ``xmax`` and ``ymax`` columns.
            Columns are looked up by name, so their position (and whether
            pandas includes the grouping column) does not matter.
        class_names: ordered list of class names. A row's label is the
            position of its ``class`` value in this list. Defaults to the
            module-level ``class_indexes``.

    Returns:
        Object ndarray of shape (n_rows, 2). Each row is
        ``(class_index, [xmin, ymin, xmax, ymax])`` with float coordinates.

    Raises:
        ValueError: if a row's class is not in ``class_names``.
    """
    names = class_indexes if class_names is None else class_names
    # NOTE(review): torchvision detection models treat label 0 as background,
    # so names[0] may never be learned. Consider shifting labels by +1 and
    # building the model with len(names) + 1 classes.
    new_bbox = [
        (names.index(cls), [float(xmin), float(ymin), float(xmax), float(ymax)])
        for cls, xmin, ymin, xmax, ymax in zip(
            group['class'], group['xmin'], group['ymin'], group['xmax'], group['ymax']
        )
    ]
    # Each entry mixes an int and a list, so the array is ragged. NumPy >= 1.24
    # raises ValueError for ragged input unless dtype=object is explicit.
    return np.asarray(new_bbox, dtype=object)

labels は画像URLをキー(インデックス)とする辞書のような構造 (pandas の Series) になっており、値は [(クラス番号, ボックス領域), (クラス番号, ボックス領域), ...] という形のデータ構造になっている。

# Series indexed by image URL. Each value is get_bbox's array of
# (class_index, [xmin, ymin, xmax, ymax]) pairs for that image.
labels = origin_labels.groupby('filename').apply(get_bbox)
print(labels.keys())
# First image's boxes (group keys are unique, so position 0 == first key).
print(labels.iloc[0])
['https://sample.com/image_source/test1.jpg']
[(0, [247, 61, 474, 284]), (1, [0, 10, 50, 100])]

画像URLの一覧を、訓練データと評価データに8:2の割合で分割する。

# Unique image URLs. np.unique returns them sorted, so the split below is
# ordered by URL, not shuffled.
total_values = np.unique(labels.index.values)
total_len = len(total_values)

# Despite the name, train_ratio is a count: the first 80% of the URLs go to training.
train_ratio = int(total_len * 0.8)
train_image_ids, val_image_ids = total_values[:train_ratio], total_values[train_ratio:]

画像データをダウンロードしてファイルに保存し、あわせてアノテーションを JSON ファイルとして書き出す。

def parse_url(url):
    """Normalize *url* into an absolute URL string, defaulting to http.

    A scheme-less input such as ``sample.com/a.jpg`` parses with an empty
    netloc and everything in the path. That whole path is promoted to the
    netloc so the result becomes ``http://sample.com/a.jpg``. URLs that
    already have a netloc are returned unchanged.
    """
    parsed = urlparse(url, 'http')
    if not parsed.netloc:
        parsed = parsed._replace(netloc=parsed.path, path='')
    return parsed.geturl()
def download_image_from_url(image_url, source='train', timeout=30):
    """Download one image into the train or test image directory.

    Best-effort: failures are reported through the return value (and a
    printed message), not raised.

    Args:
        image_url: image URL. Its last path segment becomes the local file name.
        source: 'train' saves into images_path; any other value saves into
            test_images_path.
        timeout: seconds to wait for the server before giving up.

    Returns:
        True if the file already existed or was downloaded with HTTP 200,
        False otherwise.
    """
    filename = image_url.split('/')[-1]
    target_dir = images_path if source == 'train' else test_images_path
    filepath = os.path.join(target_dir, filename)

    # Check the local copy before touching the network.
    if os.path.exists(filepath):
        return True

    try:
        target_url = parse_url(image_url)
        # The with-block releases the connection even on non-200 replies.
        with requests.get(requote_uri(target_url), stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                return False
            with open(filepath, 'wb') as f:
                # Stream to disk in chunks instead of buffering the whole body.
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return True
    except (requests.RequestException, OSError, ValueError) as e:
        # Remove any truncated file. Otherwise it would pass the exists()
        # check next time and never be re-downloaded.
        try:
            os.remove(filepath)
        except OSError:
            pass
        print(f'download failed: {image_url} ({e})')
        return False
def write_annotation(image_url, source='train'):
    """Write the boxes of one image to ``<image basename>.json``.

    The file is a JSON list of ``{"class_index": int, "annotation":
    [xmin, ymin, xmax, ymax]}`` objects built from ``labels[image_url]``.

    Args:
        image_url: key into the module-level ``labels`` Series. Its last path
            segment, minus the extension, names the JSON file.
        source: 'train' writes into annotations_path; any other value writes
            into test_annotations_path.

    Returns:
        True once the annotation file exists (already present or just
        written). The original returned None after writing, which is
        inconsistent with download_image_from_url.
    """
    basename = os.path.splitext(image_url.split('/')[-1])[0]
    target_dir = annotations_path if source == 'train' else test_annotations_path
    filepath = os.path.join(target_dir, basename + '.json')

    if os.path.exists(filepath):
        return True

    # int()/float() make the values JSON-serializable even if they arrive as
    # NumPy scalars (json cannot encode e.g. np.int64).
    new_annotations = [
        {'class_index': int(class_index), 'annotation': [float(v) for v in box]}
        for class_index, box in labels[image_url]
    ]
    # The reader (generate_target) opens these files as UTF-8.
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(new_annotations, f)
    return True
# NOTE(review): make_dataset is not defined anywhere in this file. It presumably
# calls download_image_from_url(url, source) and write_annotation(url, source)
# for each id. Define it before running this cell.
make_dataset(train_image_ids, source='train')

image.png

# Build the evaluation split (any source other than 'train' goes to the test_* dirs).
# NOTE(review): make_dataset is not defined anywhere in this file. Confirm its
# definition before running this cell.
make_dataset(val_image_ids, source='test')

image.png

# データパイプライン
すべてのデータを一度に読み込むとメモリに負荷がかかる。そのため、訓練時にはバッチサイズ分だけデータを読み込んで加工する。

import numpy as np
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from bs4 import BeautifulSoup
from PIL import Image
import torchvision
from torchvision import transforms, datasets, models
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import time
import torch
def generate_target(file):
    """Load one annotation JSON file into a torchvision-style detection target.

    Args:
        file: path to a JSON list of ``{"class_index": int, "annotation":
            [xmin, ymin, xmax, ymax]}`` objects, as written by write_annotation.

    Returns:
        dict with ``boxes`` (float32 tensor of shape (N, 4)) and ``labels``
        (int64 tensor of shape (N,)).
    """
    with open(file, encoding='utf-8') as f:
        data = json.load(f)

    boxes = [item['annotation'] for item in data]
    class_ids = [item['class_index'] for item in data]

    # reshape(-1, 4) keeps the (N, 4) layout when the file has no boxes.
    # An empty list would otherwise become a 1-D tensor of shape (0,).
    return {
        "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
        "labels": torch.as_tensor(class_ids, dtype=torch.int64),
    }
class Dataset(object):
    """Map-style dataset pairing each image file with its annotation JSON.

    Images are listed (sorted) from ``BASE_DIR/path``. The target for
    ``<name>.<ext>`` is read from ``<name>.json`` in the annotation directory.
    """

    def __init__(self, transforms, path, annotation_dir=None):
        """
        Args:
            transforms: callable applied to the RGB PIL image, or None.
            path: image directory, relative to BASE_DIR (or absolute).
            annotation_dir: directory holding the JSON files, relative to
                BASE_DIR (or absolute). When None, it is 'test_annotations'
                if the image directory's own name contains 'test', else
                'annotations'.
        """
        self.transforms = transforms
        self.path = os.path.join(BASE_DIR, path)
        self.imgs = sorted(os.listdir(self.path))
        if annotation_dir is None:
            # Decide from the image directory's own name, not the full path.
            # A BASE_DIR that happens to contain "test" must not redirect the
            # training images to the test labels.
            dir_name = os.path.basename(os.path.normpath(self.path))
            annotation_dir = 'test_annotations' if 'test' in dir_name else 'annotations'
        self.annotation_dir = os.path.join(BASE_DIR, annotation_dir)

    def __getitem__(self, idx):
        file_image = self.imgs[idx]
        # splitext handles any extension length ('.jpg', '.jpeg', '.webp').
        # This matches how write_annotation names the JSON files.
        file_label = os.path.splitext(file_image)[0] + '.json'
        img_path = os.path.join(self.path, file_image)
        label_path = os.path.join(self.annotation_dir, file_label)

        # The context manager closes the file handle; convert() returns a new image.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        target = generate_target(label_path)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.imgs)

# The only transform is PIL image -> tensor (torchvision ToTensor).
# No augmentation or resizing is done here.
data_transform = transforms.Compose([transforms.ToTensor()])

def collate_fn(batch):
    """Transpose a batch of (image, target) pairs into (images, targets).

    Detection samples may differ in image size and box count, so they are
    kept as tuples instead of being stacked into one tensor.
    Example: [(img1, t1), (img2, t2)] -> ((img1, img2), (t1, t2))
    """
    images_and_targets = zip(*batch)
    return tuple(images_and_targets)

訓練および評価データパイプラインを生成する。

dataset = Dataset(data_transform, 'images')
test_dataset = Dataset(data_transform, 'test_images')
# Shuffle the training set so each epoch sees the samples in a new order.
# Without it, every epoch repeats identical, URL-sorted batches. Evaluation
# order stays fixed. collate_fn keeps each batch as tuples of images and
# targets rather than stacked tensors.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=5, collate_fn=collate_fn)

# モデル設定

def get_model_instance_segmentation(num_classes):
    """Build a pretrained Faster R-CNN (ResNet-50 FPN) with a new box head.

    Despite the name, this is a detection model (boxes + labels), not a
    segmentation model. The COCO box predictor is replaced by a
    FastRCNNPredictor with ``num_classes`` outputs; the backbone and RPN keep
    their pretrained weights.

    NOTE(review): torchvision counts background as one of ``num_classes``.
    Confirm that callers pass (number of foreground classes + 1).
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
    return detector

Faster-RCNN モデルをロードする。

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): torchvision reserves label 0 for background, so num_classes is
# normally len(classes) + 1. With len(classes), class index 0 doubles as background.
model = get_model_instance_segmentation(len(classes))
model.to(device)
FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
...<省略>
# Sanity check that a GPU is visible; the run shown here printed True.
torch.cuda.is_available()
True
num_epochs = 100
# Optimize only the parameters that are trainable (requires_grad=True).
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# 学習 (訓練)
訓練を実行する。

for epoch in range(num_epochs):
    start = time.time()
    model.train()
    epoch_loss = 0.0
    for imgs, annotations in data_loader:
        imgs = [img.to(device) for img in imgs]
        annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
        # In train mode the model returns a dict of individual loss terms; sum them.
        loss_dict = model(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        # .item() turns the loss into a Python float. Accumulating the tensor
        # itself keeps every batch's autograd graph (and its GPU memory)
        # alive until the epoch ends.
        epoch_loss += losses.item()
    print(f'epoch : {epoch+1}, Loss : {epoch_loss}, time : {time.time() - start}')
epoch : 1, Loss : 246.2340850830078, time : 283.55497002601624
epoch : 2, Loss : 202.09490966796875, time : 284.3149654865265
epoch : 3, Loss : 167.4278564453125, time : 284.43429923057556
epoch : 4, Loss : 147.18411254882812, time : 284.58269715309143
epoch : 5, Loss : 135.09730529785156, time : 284.5432515144348
...<省略>

モデルを保存する。

# Save only the weights (state_dict) under BASE_DIR; the file name records the epoch count.
torch.save(model.state_dict(), os.path.join(BASE_DIR, f'model_{num_epochs}.pt'))

# 予測
保存したモデルをロードする。

# Load from the same BASE_DIR path that torch.save writes to, and map the
# tensors onto the current device so a GPU-saved checkpoint also loads on CPU.
model.load_state_dict(torch.load(os.path.join(BASE_DIR, f'model_{num_epochs}.pt'), map_location=device))
def make_prediction(model, img, threshold):
    """Run the detector in eval mode and keep only confident detections.

    Gradients are not disabled here; this file calls it inside torch.no_grad().

    Args:
        model: detection model, called as ``model(img)`` after ``model.eval()``.
        img: batch of image tensors, passed straight to the model.
        threshold: detections whose score is strictly greater than this are kept.

    Returns:
        The model's list of per-image dicts. Each dict's ``boxes``, ``labels``
        and ``scores`` are replaced, in place, by the kept subset.
    """
    model.eval()
    preds = model(img)
    for pred in preds:
        keep = pred['scores'] > threshold
        for key in ('boxes', 'labels', 'scores'):
            pred[key] = pred[key][keep]
    return preds
# One distinct color per class index. The colormap is resampled to
# len(class_indexes) entries, so it is sampled at exactly that many evenly
# spaced points. Sampling more points than entries (e.g. 50) would make the
# first len(class_indexes) colors, the only ones indexed by class, fall into
# the same one or two bins.
cmap = plt.get_cmap('jet', len(class_indexes))
colors = [cmap(i) for i in np.linspace(0, 1, len(class_indexes))]

def plot_image_from_output(img, annotation):
    """Show one image with its boxes and class names drawn on top.

    Args:
        img: image tensor in channel-first layout (permuted to HWC for imshow).
        annotation: dict with ``boxes`` (rows of xmin, ymin, xmax, ymax) and
            ``labels`` (indices into ``class_indexes`` and ``colors``). It can
            be ground truth or a prediction from make_prediction.
    """
    hwc_image = img.cpu().permute(1, 2, 0)

    _, axes = plt.subplots(1)
    axes.imshow(hwc_image)
    boxes = annotation["boxes"]
    for position, box in enumerate(boxes):
        xmin, ymin, xmax, ymax = box.cpu()
        label_id = annotation['labels'][position].cpu().numpy()
        name = class_indexes[label_id]
        edge_color = colors[label_id]
        outline = patches.Rectangle(
            (xmin, ymin), (xmax - xmin), (ymax - ymin),
            linewidth=1, edgecolor=edge_color, facecolor='none', label=name,
        )
        axes.add_patch(outline)
        # Class name drawn at the box's top-left corner on a colored background.
        axes.text(xmin, ymin, s=name, color='white',
                  verticalalignment='top', bbox={'color': edge_color, 'pad': 0})

    plt.show()
# Run the detector on the first evaluation batch only, without tracking gradients.
with torch.no_grad():
    imgs, annotations = next(iter(test_data_loader))
    imgs = [img.to(device) for img in imgs]
    pred = make_prediction(model, imgs, 0.5)
    print(pred)

# Compare ground truth and prediction for the first image of that batch.
_idx = 0
print("Target : ", annotations[_idx]['labels'])
plot_image_from_output(imgs[_idx], annotations[_idx])
print("Prediction : ", pred[_idx]['labels'])
plot_image_from_output(imgs[_idx], pred[_idx])

image.png

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?