0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

1) Train a Lightweight Pet-Breed Classifier with PyTorch + TIMM (Mobile-Friendly)

Posted at

Summary: Huấn luyện mô hình phân loại giống chó/mèo nhẹ (≤10–20M tham số) bằng timm, tối ưu cho inference CPU/thiết bị yếu. Xuất ONNX để triển khai.

Mục lục

Vì sao cần “lightweight” cho e-commerce thú cưng

Chuẩn bị dữ liệu (cấu trúc thư mục, split, class balance)

Chọn backbone từ timm + augmentations

Training loop (mixed precision), metric (macro-F1)

Export ONNX + test nhanh trên CPU

Lưu ý thường gặp

Further reading

  1. Vì sao cần “lightweight”

Ảnh thú cưng đa dạng tư thế/nền; người dùng hay truy cập bằng điện thoại.

Model nhỏ → thời gian dự đoán nhanh, phù hợp CPU, cải thiện UX và SEO Core Web Vitals (LCP/INP gián tiếp).

Mục tiêu tham khảo: <80ms/ảnh trên CPU desktop, và tốc độ “chấp nhận được” trên máy cấu hình thấp.

  1. Chuẩn bị dữ liệu

Cấu trúc tối thiểu:

data/
train/
corgi/
poodle/
pomeranian/
...
val/
corgi/
poodle/
pomeranian/
...

Cân bằng lớp: nếu lệch, dùng class weights hoặc sampling.

Chuẩn hóa side-length (224–256) để ổn định tốc độ.

  1. Backbone + Augmentations

Gợi ý timm backbone: efficientnet_v2_s, mobilenetv3_large_100, convnext_tiny (nhẹ + hiệu quả).

import timm, torch, torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.optim import AdamW

N_CLASSES = 20 # ví dụ
model = timm.create_model("efficientnet_v2_s", pretrained=True, num_classes=N_CLASSES)

train_tfms = transforms.Compose([
transforms.Resize((256,256)),
transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.1,0.1,0.1,0.05),
transforms.ToTensor()
])

val_tfms = transforms.Compose([
transforms.Resize((256,256)),
transforms.CenterCrop(224),
transforms.ToTensor()
])

train_ds = datasets.ImageFolder("data/train", transform=train_tfms)
val_ds = datasets.ImageFolder("data/val", transform=val_tfms)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4)

  1. Training loop + Macro-F1
    from sklearn.metrics import f1_score
    import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

def evaluate():
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
for x, y in val_loader:
x, y = x.to(device), y.to(device)
logits = model(x)
pred = logits.argmax(1)
y_true.extend(y.cpu().numpy())
y_pred.extend(pred.cpu().numpy())
return f1_score(y_true, y_pred, average="macro")

best_f1 = 0.0
for epoch in range(10):
model.train()
for x, y in train_loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=(device=="cuda")):
logits = model(x)
loss = criterion(logits, y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
f1 = evaluate()
print(f"Epoch {epoch}: macro-F1={f1:.4f}")
if f1 > best_f1:
best_f1 = f1
torch.save(model.state_dict(), "best.pt")

  1. Export ONNX + test CPU
    model.load_state_dict(torch.load("best.pt", map_location=device))
    model.eval()
    dummy = torch.randn(1,3,224,224).to(device)
    torch.onnx.export(model, dummy, "petbreed.onnx",
    input_names=["input"], output_names=["logits"],
    opset_version=17)

Quick test với onnxruntime:

import onnxruntime as ort, numpy as np
import cv2, time

sess = ort.InferenceSession("petbreed.onnx", providers=["CPUExecutionProvider"])
def run(img_path):
img = cv2.imread(img_path)[:,:,::-1]
img = cv2.resize(img, (224,224)).astype(np.float32)/255.0
x = np.transpose(img, (2,0,1))[None]
t0 = time.time()
logits = sess.run(None, {"input": x})[0]
dt = (time.time()-t0)*1000
prob = np.exp(logits - logits.max())
prob = prob/prob.sum()
top = prob[0].argmax()
return top, dt

idx, ms = run("sample.jpg")
print("pred:", idx, "time(ms):", round(ms,2))

  1. Lưu ý

Tránh leakage: không để ảnh tương tự của cùng cá thể xuất hiện cả train và val.

Đa dạng nền/ánh sáng; nếu shop của bạn có studio nền trắng, hãy augment để tránh overfit.

Macro-F1 phản ánh cân bằng giữa các giống phổ biến/hiếm.

  1. Further reading

TIMM docs

ONNX Runtime docs

Ví dụ demo Gradio (tham khảo ở bài 2)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?