"""Object detection wrapper using YOLOv8 for cat detection.

This module provides a CatDetector class that wraps the Ultralytics YOLOv8
model for detecting cats in video frames. It uses the pre-trained YOLOv8n
(nano) model, which is lightweight enough for CPU inference on a Raspberry Pi.

The detector filters detections to only return cats (COCO class ID 15) at or
above a configurable confidence threshold.

Example:
    Basic usage::

        detector = CatDetector()
        detections = detector.detect_cats(frame, threshold=0.5)
        for det in detections:
            print(f"Cat detected with confidence {det.confidence}")
"""
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple
import cv2
import numpy as np
from ultralytics import YOLO
import config
logger = logging.getLogger(__name__)
[docs]
@dataclass()
class Detection:
    """A single cat detection: class label, confidence score and bounding box.

    Attributes:
        label: Class label of the detected object. CatDetector always sets
            this to "cat".
        confidence: Detection confidence score, from 0.0 to 1.0.
        bbox: Pixel coordinates ``(x1, y1, x2, y2)`` within the frame, where
            ``(x1, y1)`` is the top-left corner and ``(x2, y2)`` is the
            bottom-right corner.

    Example:
        >>> det = Detection(label="cat", confidence=0.85, bbox=(100, 50, 200, 150))
        >>> x1, y1, x2, y2 = det.bbox
        >>> print(f"Cat at ({x1}, {y1}) to ({x2}, {y2}) with {det.confidence:.0%} confidence")
        Cat at (100, 50) to (200, 150) with 85% confidence
    """

    label: str  # object class name
    confidence: float  # score between 0.0 and 1.0
    bbox: Tuple[float, float, float, float]  # (x1, y1, x2, y2) in pixels
[docs]
class CatDetector:
    """Wrapper for a YOLO model that reports only cat detections.

    Loads an Ultralytics YOLO model (by default the name returned by
    ``config.get_model_name()``) and filters its output to the class ID
    returned by ``config.get_cat_class_id()`` (COCO class 15 is "cat"), keeping
    only detections whose confidence is at or above the threshold.

    Attributes:
        model: Loaded YOLO model instance.
        device: Device string the model was moved to; it is also passed to
            every inference call (e.g. ``"cpu"`` or ``"cuda"``).
        imgsz: Model input size, read from ``config.get_model_size()``.

    Example:
        >>> detector = CatDetector()
        >>> detections = detector.detect_cats(frame, threshold=0.5)
        >>> if detections:
        ...     print(f"Found {len(detections)} cat(s)")
    """

    def __init__(self, model_path: Optional[str] = None, device: str = "cpu"):
        """Load the YOLO model and read the inference input size.

        Args:
            model_path: Path to the YOLO model file (downloaded if it does not
                exist). If None, the value from ``config.get_model_name()``
                is used.
            device: Device to load the model onto and to run inference on
                (e.g. ``"cpu"`` or ``"cuda"``).

        Raises:
            RuntimeError: If the model cannot be loaded or moved to ``device``,
                or the model size cannot be read from config. The original
                exception is chained as ``__cause__``.
        """
        if model_path is None:
            model_path = config.get_model_name()
        logger.info("Loading YOLO model: %s", model_path)
        # Remember the device so detect_cats() runs inference where the model
        # actually lives instead of forcing CPU.
        self.device = device
        try:
            self.model = YOLO(model_path)
            self.model.to(device)
            # Smaller input sizes run faster on CPU but may reduce accuracy;
            # adjust model_size in the config file.
            self.imgsz = config.get_model_size()
        except Exception as e:
            logger.error("Failed to load detection model: %s", e)
            raise RuntimeError(f"Failed to load detection model: {e}") from e
        logger.info(
            "Model loaded successfully. Using device: %s, input size: %s",
            device,
            self.imgsz,
        )

    def detect_cats(
        self, frame: np.ndarray, threshold: Optional[float] = None
    ) -> List[Detection]:
        """Detect cats in a single frame.

        Args:
            frame: BGR image as a numpy array (OpenCV format).
            threshold: Minimum confidence for a detection to be returned.
                If None, ``config.get_detection_threshold()`` is used.

        Returns:
            Detection objects for every cat whose confidence is at or above
            ``threshold``. An empty list is returned for a None/empty frame,
            when nothing is found, or when inference raises. In that last
            case the error is logged with its traceback and not re-raised.
        """
        if frame is None or frame.size == 0:
            return []
        if threshold is None:
            threshold = config.get_detection_threshold()

        try:
            cat_class_id = config.get_cat_class_id()
            # imgsz controls the input size: smaller values run faster on CPU
            # but may reduce accuracy (e.g. 416 for speed, 640 for accuracy).
            # ``classes`` asks Ultralytics to keep only the cat class during
            # post-processing; _cats_from_results still re-checks the class.
            results = self.model(
                frame,
                imgsz=self.imgsz,
                conf=threshold,
                device=self.device,
                classes=[cat_class_id],
                verbose=False,
            )
            return self._cats_from_results(results, cat_class_id, threshold)
        except Exception as e:
            # Best-effort: a failed frame must not stop the caller's video
            # loop, but keep the traceback for debugging.
            logger.exception("Error during detection: %s", e)
            return []

    @staticmethod
    def _cats_from_results(
        results, cat_class_id: int, threshold: float
    ) -> List[Detection]:
        """Convert Ultralytics results into cat Detection objects."""
        detections: List[Detection] = []
        for result in results:
            boxes = result.boxes
            if boxes is None:
                continue
            for box in boxes:
                if int(box.cls[0]) != cat_class_id:
                    continue
                confidence = float(box.conf[0])
                if confidence >= threshold:
                    # .cpu() is required before .numpy() when the model ran
                    # on a GPU device.
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    detections.append(
                        Detection(
                            label="cat",
                            confidence=confidence,
                            bbox=(float(x1), float(y1), float(x2), float(y2)),
                        )
                    )
        return detections