Vikit-learnは、Pythonを使用して開発されたコンピュータビジョン処理ツールキットで、ディープラーニング技術に基づいています。
このパッケージは、現実のタスクを処理できる一連の使いやすいツールを提供することを目的としています。
このプロジェクトはまだ活発に構築および開発中ですので、どうぞご期待ください!
インストール
依存関係
- Python>=3.10
- matplotlib>=3.7.5
- torch>=2.1.2
- torchvision>=0.16.2
- torchmetrics>=1.3.2
- lightning-utilities>=0.11.2
- faster-coco-eval>=1.5.4
- pycocotools>=2.0.7
pipを使用してインストール
pip install git+https://github.com/bxt-kk/vikit-learn.git
使用方法
モデルのトレーニング
# `pytorch` と `vklearn` をインポート
import torch
from torch.utils.data import DataLoader
from vklearn.trainer.trainer import Trainer
from vklearn.trainer.tasks import Detection
from vklearn.models.trimnetx import TrimNetX as TRBNetX
from vklearn.datasets.oxford_iiit_pet import OxfordIIITPet
device = torch.device('cuda:0')
# TrbnetXモデルを作成
model = TRBNetX(
num_classes=2,
anchors=[(a, a) for a in [21, 63, 189]],
)
# TRBNetXからデフォルトの `transforms` を取得
train_transforms, test_transforms = model.get_transforms()
# データセットを作成
dataset_root = '???/OxfordIIITPet'
dataset_type = 'detection'
train_data = OxfordIIITPet(
dataset_root,
split='trainval',
target_types=dataset_type,
transforms=train_transforms)
test_data = OxfordIIITPet(
dataset_root,
split='trainval',
target_types=dataset_type,
transforms=test_transforms)
# DataLoaderを作成
batch_size = 16
train_loader = DataLoader(
train_data, batch_size, shuffle=True, collate_fn=model.collate_fn, num_workers=4)
test_loader = DataLoader(
test_data, batch_size, shuffle=True, collate_fn=model.collate_fn, num_workers=4)
# オブジェクト検出タスクを構築
task = Detection(
model, device,
loss_options={'weights': dict(conf=0.5, bbox=1.5, clss=0.5)},
)
# トレーニングタスクを指定してトレーナーのパラメータを設定し、トレーナーを構築
trainer = Trainer(
task,
output='/tmp/catdog',
checkpoint=None,
train_loader=train_loader,
test_loader=test_loader,
epochs=5,
lr=1e-3,
show_step=50,
save_epoch=5)
# トレーナーを初期化し、トレーニングを実行
trainer.initialize()
trainer.fit()
トレーニングが完了すると、/tmp/logs/
ディレクトリにモデルのトレーニング結果の可視化画像が保存されます:
私が設計した focal-boost
損失関数に基づいて、極めて低い正のサンプル比率でのタスクでもモデルをうまくトレーニングできます。
モデルの使用
トレーニングされたモデルを以下のようにオブジェクト検出に使用できます:
# `pytorch` と `vklearn` をインポート
import torch
from vklearn.models.trimnetx import TrimNetX
device = torch.device('cpu')
# `state` オブジェクトからモデルをロード
state = torch.load('logs/catdog-4.pt', map_location='cpu')
model:TrimNetX = TrimNetX.load_from_state(state)
model.eval().to(device)
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob
img = Image.open('???/cat.jpeg')
# 結果を検出および表示
with torch.no_grad():
objs = model.detect(img, iou_thresh=0.5, conf_thresh=0.7)
print(len(objs), objs)
fig, ax = plt.subplots()
ax.imshow(img)
for obj in objs:
print('label:', obj['label'])
# if obj['label'] != 0: continue
x1, y1, x2, y2 = obj['box']
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, color='yellow', fill=False))
plt.show()
以下は例です: