Summary: Huấn luyện mô hình phân loại giống chó/mèo nhẹ (≤10–20M tham số) bằng timm, tối ưu cho inference CPU/thiết bị yếu. Xuất ONNX để triển khai.
Mục lục
Vì sao cần “lightweight” cho e-commerce thú cưng
Chuẩn bị dữ liệu (cấu trúc thư mục, split, class balance)
Chọn backbone từ timm + augmentations
Training loop (mixed precision), metric (macro-F1)
Export ONNX + test nhanh trên CPU
Lưu ý thường gặp
Further reading
- Vì sao cần “lightweight”
Ảnh thú cưng đa dạng tư thế/nền; người dùng hay truy cập bằng điện thoại.
Model nhỏ → thời gian dự đoán nhanh, phù hợp CPU, cải thiện UX và SEO Core Web Vitals (LCP/INP gián tiếp).
Mục tiêu tham khảo: <80ms/ảnh trên CPU desktop, và tốc độ “chấp nhận được” trên máy cấu hình thấp.
- Chuẩn bị dữ liệu
Cấu trúc tối thiểu:
data/
train/
corgi/
poodle/
pomeranian/
...
val/
corgi/
poodle/
pomeranian/
...
Cân bằng lớp: nếu lệch, dùng class weights hoặc sampling.
Chuẩn hóa side-length (224–256) để ổn định tốc độ.
- Backbone + Augmentations
Gợi ý timm backbone: efficientnet_v2_s, mobilenetv3_large_100, convnext_tiny (nhẹ + hiệu quả).
import timm, torch, torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.optim import AdamW
N_CLASSES = 20 # ví dụ
model = timm.create_model("efficientnet_v2_s", pretrained=True, num_classes=N_CLASSES)
train_tfms = transforms.Compose([
transforms.Resize((256,256)),
transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.1,0.1,0.1,0.05),
transforms.ToTensor()
])
val_tfms = transforms.Compose([
transforms.Resize((256,256)),
transforms.CenterCrop(224),
transforms.ToTensor()
])
train_ds = datasets.ImageFolder("data/train", transform=train_tfms)
val_ds = datasets.ImageFolder("data/val", transform=val_tfms)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4)
- Training loop + Macro-F1
from sklearn.metrics import f1_score
import numpy as np
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))
def evaluate():
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
for x, y in val_loader:
x, y = x.to(device), y.to(device)
logits = model(x)
pred = logits.argmax(1)
y_true.extend(y.cpu().numpy())
y_pred.extend(pred.cpu().numpy())
return f1_score(y_true, y_pred, average="macro")
best_f1 = 0.0
for epoch in range(10):
model.train()
for x, y in train_loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=(device=="cuda")):
logits = model(x)
loss = criterion(logits, y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
f1 = evaluate()
print(f"Epoch {epoch}: macro-F1={f1:.4f}")
if f1 > best_f1:
best_f1 = f1
torch.save(model.state_dict(), "best.pt")
- Export ONNX + test CPU
model.load_state_dict(torch.load("best.pt", map_location=device))
model.eval()
dummy = torch.randn(1,3,224,224).to(device)
torch.onnx.export(model, dummy, "petbreed.onnx",
input_names=["input"], output_names=["logits"],
opset_version=17)
Quick test với onnxruntime:
import onnxruntime as ort, numpy as np
import cv2, time
sess = ort.InferenceSession("petbreed.onnx", providers=["CPUExecutionProvider"])
def run(img_path):
img = cv2.imread(img_path)[:,:,::-1]
img = cv2.resize(img, (224,224)).astype(np.float32)/255.0
x = np.transpose(img, (2,0,1))[None]
t0 = time.time()
logits = sess.run(None, {"input": x})[0]
dt = (time.time()-t0)*1000
prob = np.exp(logits - logits.max())
prob = prob/prob.sum()
top = prob[0].argmax()
return top, dt
idx, ms = run("sample.jpg")
print("pred:", idx, "time(ms):", round(ms,2))
- Lưu ý
Tránh leakage: không để ảnh tương tự của cùng cá thể xuất hiện cả train và val.
Đa dạng nền/ánh sáng; nếu shop của bạn có studio nền trắng, hãy augment để tránh overfit.
Macro-F1 phản ánh cân bằng giữa các giống phổ biến/hiếm.
- Further reading
TIMM docs
ONNX Runtime docs
Ví dụ demo Gradio (tham khảo ở bài 2)