LoginSignup
1
1

世界を知ったYOLO YOLOWorld! 【Text2Detction 】

Last updated at Posted at 2024-02-21

テキストで指定すればオブジェクトを検出してくれる

物体検出の手間を大幅に省いてくれる。
今まではデータを集めて学習させる必要があった。
これが丸々いらない。
YoloWorldは、テキストがなにを指すか知っている。

person, guiter, drums, flower, beer, shoes, wall, shorts, denim

bands (1).jpg

しかもT4GPUで0.1秒と高速実行。

使い方

インストール

Colabで使えるようにしたインストール方法

git clone --recursive https://github.com/AILab-CVC/YOLO-World.git
cd YOLO-World
pip install torch==2.0.0 torchvision==0.15.1
pip install torch wheel -q
pip install -U openmim
mim install mmcv==2.0.1
pip install supervision
pip install mmyolo
pip install mmdet

wget https://huggingface.co/wondervictor/YOLO-World/resolve/main/yolo_world_l_clip_base_dual_vlpan_2e-3adamw_32xb16_100e_o365_goldg_cc3mlite_train_pretrained-7a5eea3b.pth

*公式のインストール方法はリポジトリを参照
*今回モデルはyolo_world_lモデルをダウンロードしている。ネットワークサイズの異なる幾つかのモデルが公式リポジトリで用意されているので参照

実行

python image_demo.py configs/pretrain/yolo_world_l_dual_vlpan_l2norm_2e-3_100e_4x8gpus_obj365v1_goldg_train_lvis_val.py yolo_world_l_clip_base_dual_vlpan_2e-3adamw_32xb16_100e_o365_goldg_cc3mlite_train_pretrained-7a5eea3b.pth input_dir 'person,dog,cat' --topk 100 --threshold 0.25 --output-dir out

引数は、順番に
config
checkpoint
input-dir
prompt
max_det
score_threshold
output-dir
です。

Pythonで実行

import cv2
import os
import os.path as osp
import torch

import supervision as sv

from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmengine.runner.amp import autocast
from mmengine.dataset import Compose
from mmengine.utils import ProgressBar
from mmyolo.registry import RUNNERS

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()


config = "configs/pretrain/yolo_world_l_dual_vlpan_l2norm_2e-3_100e_4x8gpus_obj365v1_goldg_train_lvis_val.py"
checkpoint = "yolo_world_l_clip_base_dual_vlpan_2e-3adamw_32xb16_100e_o365_goldg_cc3mlite_train_pretrained-7a5eea3b.pth"

cfg = Config.fromfile(config)
cfg.work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(config))[0])

cfg.load_from = checkpoint

if 'runner_type' not in cfg:
    runner = Runner.from_cfg(cfg)
else:
    runner = RUNNERS.build(cfg)

output_dir = "out"
image = "vegitables.jpg"
text = "tomato, garlic, onion, vegitable"
score_thr = 0.25
max_dets = 100
use_amp = False

# load text
if text.endswith('.txt'):
    with open(text) as f:
        lines = f.readlines()
    texts = [[t.rstrip('\r\n')] for t in lines] + [[' ']]
else:
    texts = [[t.strip()] for t in text.split(',')] + [[' ']]

output_dir = output_dir
if not osp.exists(output_dir):
    os.mkdir(output_dir)

runner.call_hook('before_run')
runner.load_or_resume()
pipeline = cfg.test_dataloader.dataset.pipeline
runner.pipeline = Compose(pipeline)
runner.model.eval()

if not osp.isfile(image):
    images = [
        osp.join(image, img) for img in os.listdir(image)
        if img.endswith('.png') or img.endswith('.jpg')
    ]
else:
    images = [image]

for image_path in images:

  data_info = dict(img_id=0, img_path=image_path, texts=texts)
  data_info = runner.pipeline(data_info)
  data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
                    data_samples=[data_info['data_samples']])

  with autocast(enabled=use_amp), torch.no_grad():
      output = runner.model.test_step(data_batch)[0]
      pred_instances = output.pred_instances
      pred_instances = pred_instances[
          pred_instances.scores.float() > score_thr]
  if len(pred_instances.scores) > max_dets:
      indices = pred_instances.scores.float().topk(max_dets)[1]
      pred_instances = pred_instances[indices]

  pred_instances = pred_instances.cpu().numpy()
  detections = sv.Detections(xyxy=pred_instances['bboxes'],
                              class_id=pred_instances['labels'],
                              confidence=pred_instances['scores'])

  labels = [
      f"{texts[class_id][0]} {confidence:0.2f}" for class_id, confidence in
      zip(detections.class_id, detections.confidence)
  ]

  # label images
  image = cv2.imread(image_path)
  image = BOUNDING_BOX_ANNOTATOR.annotate(image, detections)
  image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels)
  cv2.imwrite(osp.join(output_dir, osp.basename(image_path)), image)
tomato, onion, garlic, paprika, vegitables

f6abde47-0ee4-4cb6-91c6-c32f4e5a779f.jpeg

🐣


フリーランスエンジニアです。
AIについて色々記事を書いていますのでよかったらプロフィールを見てみてください。

もし以下のようなご要望をお持ちでしたらお気軽にご相談ください。
AIサービスを開発したい、ビジネスにAIを組み込んで効率化したい、AIを使ったスマホアプリを開発したい、
ARを使ったアプリケーションを作りたい、スマホアプリを作りたいけどどこに相談したらいいかわからない…

いずれも中間コストを省いたリーズナブルな価格でお請けできます。

お仕事のご相談はこちらまで
rockyshikoku@gmail.com

機械学習やAR技術を使ったアプリケーションを作っています。
機械学習/AR関連の情報を発信しています。

X
Medium
GitHub

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1