LoginSignup
6
11

More than 3 years have passed since last update.

darknetでyoloのデータセット作成を半自動化する。

Last updated at Posted at 2019-06-02

データセット作るのめんどくさい人

私です。
データセットを作っていると心が荒んでいくのを感じます。

というわけで

darknetを使ってyoloのフォーマットで検出結果をテキストに出力します。
yoloのフォーマットとは

クラスid BoundingBoxのxの中心座標 BoundingBoxのyの中心座標 BoundingBoxの幅 BoundingBoxの高さ

です。
txtファイルで保存されます。

使い方

darknetをクローンする。

git clone https://github.com/pjreddie/darknet.git

そしてmake
必要に応じてMakefileでGPU、cudnnを有効化してください。

cd darknet
make

そしてpythonフォルダに以下のコードを入れてください。

make_dataset.py
from ctypes import *
import math
import random
import sys

def sample(probs):
    s = sum(probs)
    probs = [a/s for a in probs]
    r = random.uniform(0, 1)
    for i in range(len(probs)):
        r = r - probs[i]
        if r <= 0:
            return i
    return len(probs)-1

def c_array(ctype, values):
    arr = (ctype*len(values))()
    arr[:] = values
    return arr

class BOX(Structure):
    _fields_ = [("x", c_float),
                ("y", c_float),
                ("w", c_float),
                ("h", c_float)]

class DETECTION(Structure):
    _fields_ = [("bbox", BOX),
                ("classes", c_int),
                ("prob", POINTER(c_float)),
                ("mask", POINTER(c_float)),
                ("objectness", c_float),
                ("sort_class", c_int)]


class IMAGE(Structure):
    _fields_ = [("w", c_int),
                ("h", c_int),
                ("c", c_int),
                ("data", POINTER(c_float))]

class METADATA(Structure):
    _fields_ = [("classes", c_int),
                ("names", POINTER(c_char_p))]



lib = CDLL("libdarknet.so", RTLD_GLOBAL)
#lib = CDLL("libdarknet.so", RTLD_GLOBAL)
lib.network_width.argtypes = [c_void_p]
lib.network_width.restype = c_int
lib.network_height.argtypes = [c_void_p]
lib.network_height.restype = c_int

predict = lib.network_predict
predict.argtypes = [c_void_p, POINTER(c_float)]
predict.restype = POINTER(c_float)

set_gpu = lib.cuda_set_device
set_gpu.argtypes = [c_int]

make_image = lib.make_image
make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE

get_network_boxes = lib.get_network_boxes
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int)]
get_network_boxes.restype = POINTER(DETECTION)

make_network_boxes = lib.make_network_boxes
make_network_boxes.argtypes = [c_void_p]
make_network_boxes.restype = POINTER(DETECTION)

free_detections = lib.free_detections
free_detections.argtypes = [POINTER(DETECTION), c_int]

free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]

network_predict = lib.network_predict
network_predict.argtypes = [c_void_p, POINTER(c_float)]

reset_rnn = lib.reset_rnn
reset_rnn.argtypes = [c_void_p]

load_net = lib.load_network
load_net.argtypes = [c_char_p, c_char_p, c_int]
load_net.restype = c_void_p

do_nms_obj = lib.do_nms_obj
do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

do_nms_sort = lib.do_nms_sort
do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

free_image = lib.free_image
free_image.argtypes = [IMAGE]

letterbox_image = lib.letterbox_image
letterbox_image.argtypes = [IMAGE, c_int, c_int]
letterbox_image.restype = IMAGE

load_meta = lib.get_metadata
lib.get_metadata.argtypes = [c_char_p]
lib.get_metadata.restype = METADATA

load_image = lib.load_image_color
load_image.argtypes = [c_char_p, c_int, c_int]
load_image.restype = IMAGE

rgbgr_image = lib.rgbgr_image
rgbgr_image.argtypes = [IMAGE]

predict_image = lib.network_predict_image
predict_image.argtypes = [c_void_p, IMAGE]
predict_image.restype = POINTER(c_float)

def classify(net, meta, im):
    out = predict_image(net, im)
    res = []
    for i in range(meta.classes):
        res.append((meta.names[i], out[i]))
    res = sorted(res, key=lambda x: -x[1])
    return res

def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
    im = load_image(image, 0, 0)
    num = c_int(0)
    pnum = pointer(num)
    predict_image(net, im)
    dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
    num = pnum[0]
    if (nms): do_nms_obj(dets, num, meta.classes, nms);

    res = []
    for j in range(num):
        for i in range(meta.classes):
            if dets[j].prob[i] > 0:
                b = dets[j].bbox
                res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
    res = sorted(res, key=lambda x: -x[1])
    free_image(im)
    free_detections(dets, num)
    return res

def usage():
    print("Usage: python python/make_dataset.py .cfg .weights .data folder")

if __name__ == "__main__":
    args = sys.argv
    if args[1] == "-h" or args[1] == "--help":
        usage()        
        sys.exit()
    net = load_net(args[1], args[2], 0)
    meta = load_meta(args[3])
    import glob,os,cv2
    files = sorted(glob.glob(args[4] + '*.jpg'))
    i = 0
    for pic in files:
        picture = cv2.imread(pic)
        r = detect(net, meta, pic)
        print pic
        path = pic[:-3] + "txt"
        f = open(path, 'w')
        for result in r:
            picture_height = picture.shape[0]
            picture_width = picture.shape[1]
            x = result[2][0] / picture_width
            y = result[2][1] / picture_height
            w = result[2][2] /  picture_width
            h = result[2][3] /  picture_height
            offset = 0
            for name in meta.names:
                if result[0] == name:
                    break
                offset += 1
            """
            print("result: " + str(result[0]) + " = " + str(offset))
            print("x: " + str(x))
            print("y: " + str(y))
            print("w: " + str(w))
            print("h: " + str(h))
            """
            f.write(str(offset) + " " + str(x) + " " + str(y) + " " + str(w) + " " + str(h) + "\n")
        f.close()

そしてdarknetのディレクトリから

python python/make_dataset.py .cfgファイルのパス 重みのパス .dataファイルのパス 画像のフォルダ

を起動すると画像のフォルダの中に検出結果が保存されます。
結果を見るならlabelImg(https://github.com/tzutalin/labelImg)
などで画像フォルダをみると検出結果がアノテーションデータとして保存されているのがわかります。

アノテーションは辛い

データセット作成→学習→このコードで検出結果の書き出し→labelImgで編集→学習を繰り返せば、効率upにつながると思います。
アノテーションは辛いのでこれで誰かを心の荒みから救えたら幸いです。
質問あったらコメント欄まで。

参考サイト

darknet
https://github.com/pjreddie/darknet
labelImg
https://github.com/tzutalin/labelImg

6
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
11