1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

vikit-learn [vision toolkit]

Posted at

Vikit-learnは、Pythonを使用して開発されたコンピュータビジョン処理ツールキットで、ディープラーニング技術に基づいています。

このパッケージは、現実のタスクを処理できる一連の使いやすいツールを提供することを目的としています。

このプロジェクトはまだ活発に構築および開発中ですので、どうぞご期待ください!

インストール

依存関係

  • Python>=3.10
  • matplotlib>=3.7.5
  • torch>=2.1.2
  • torchvision>=0.16.2
  • torchmetrics>=1.3.2
  • lightning-utilities>=0.11.2
  • faster-coco-eval>=1.5.4
  • pycocotools>=2.0.7

pipを使用してインストール

pip install git+https://github.com/bxt-kk/vikit-learn.git

使用方法

モデルのトレーニング

# `pytorch` と `vklearn` をインポート
import torch
from torch.utils.data import DataLoader
from vklearn.trainer.trainer import Trainer
from vklearn.trainer.tasks import Detection
from vklearn.models.trimnetx import TrimNetX as TRBNetX
from vklearn.datasets.oxford_iiit_pet import OxfordIIITPet


device = torch.device('cuda:0')
# TrbnetXモデルを作成
model = TRBNetX(
    num_classes=2,
    anchors=[(a, a) for a in [21, 63, 189]],
)

# TRBNetXからデフォルトの `transforms` を取得
train_transforms, test_transforms = model.get_transforms()

# データセットを作成
dataset_root = '???/OxfordIIITPet'
dataset_type = 'detection'
train_data = OxfordIIITPet(
    dataset_root,
    split='trainval',
    target_types=dataset_type,
    transforms=train_transforms)
test_data = OxfordIIITPet(
    dataset_root,
    split='trainval',
    target_types=dataset_type,
    transforms=test_transforms)

# DataLoaderを作成
batch_size = 16

train_loader = DataLoader(
    train_data, batch_size, shuffle=True, collate_fn=model.collate_fn, num_workers=4)
test_loader = DataLoader(
    test_data, batch_size, shuffle=True, collate_fn=model.collate_fn, num_workers=4)

# オブジェクト検出タスクを構築
task = Detection(
    model, device,
    loss_options={'weights': dict(conf=0.5, bbox=1.5, clss=0.5)},
)

# トレーニングタスクを指定してトレーナーのパラメータを設定し、トレーナーを構築
trainer = Trainer(
    task,
    output='/tmp/catdog',
    checkpoint=None,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=5,
    lr=1e-3,
    show_step=50,
    save_epoch=5)

# トレーナーを初期化し、トレーニングを実行
trainer.initialize()
trainer.fit()

トレーニングが完了すると、/tmp/logs/ ディレクトリにモデルのトレーニング結果の可視化画像が保存されます:





私が設計した focal-boost 損失関数に基づいて、極めて低い正のサンプル比率でのタスクでもモデルをうまくトレーニングできます。

モデルの使用

トレーニングされたモデルを以下のようにオブジェクト検出に使用できます:

# `pytorch` と `vklearn` をインポート
import torch
from vklearn.models.trimnetx import TrimNetX


device = torch.device('cpu')
# `state` オブジェクトからモデルをロード
state = torch.load('logs/catdog-4.pt', map_location='cpu')
model:TrimNetX = TrimNetX.load_from_state(state)
model.eval().to(device)

import matplotlib.pyplot as plt
from PIL import Image
from glob import glob

img = Image.open('???/cat.jpeg')
# 結果を検出および表示
with torch.no_grad():
    objs = model.detect(img, iou_thresh=0.5, conf_thresh=0.7)
print(len(objs), objs)
fig, ax = plt.subplots()
ax.imshow(img)
for obj in objs:
    print('label:', obj['label'])
    # if obj['label'] != 0: continue
    x1, y1, x2, y2 = obj['box']
    ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, color='yellow', fill=False))
plt.show()

以下は例です:







1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?