Source code for cv2ext.tracking.cv_trackers._boosting

# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import cv2

from cv2ext.tracking._interface import CVTrackerInterface

if TYPE_CHECKING:
    import numpy as np
    from typing_extensions import Self

_log = logging.getLogger(__name__)


[docs] class BoostingTracker(CVTrackerInterface): """A class for tracking objects in videos using the Boosting tracker.""" def __init__(self: Self) -> None: """ Create a new BoostingTracker object. Raises ------ ImportError If Boosting tracker creation function/class cannot be found. """ _log.debug("Creating a legacy tracker (Boosting).") try: try: tracker = cv2.legacy_TrackerBoosting.create() # type: ignore[attr-defined] except AttributeError: tracker = cv2.TrackerBoosting_create() # type: ignore[attr-defined] super().__init__(tracker) except AttributeError as e: err_msg = "Cannot get Boosting tracker from cv2. You may need to install opencv-contrib-python" _log.error(err_msg) raise ImportError from e
[docs] def init(self: Self, image: np.ndarray, bbox: tuple[int, int, int, int]) -> None: """ Initialize the tracker with an image and bounding box. Parameters ---------- image : np.ndarray The image to use for tracking. bbox : tuple[int, int, int, int] The bounding box of the object to track. Bounding box is format (x, y, x, y), where (x, y) is the top-left/bottom-right corner of the box. Raises ------ ValueError If the image is not 3-channel. """ if len(image.shape) == 2 or image.shape[2] == 1: err_msg = "Boosting tracker requires a 3-channel image." raise ValueError(err_msg) super()._init(image, bbox)
[docs] def update(self: Self, image: np.ndarray) -> tuple[bool, tuple[int, int, int, int]]: """ Update the tracker with a new image. Parameters ---------- image : np.ndarray The new image to use for tracking. Returns ------- bool Whether the update was successful. tuple[int, int, int, int] The bounding box of the tracked object. Bounding box is format (x, y, x, y), where (x, y) is the top-left/bottom-right corner of the box. Raises ------ ValueError If the image is not 3-channel. """ if len(image.shape) == 2 or image.shape[2] == 1: err_msg = "Boosting tracker requires a 3-channel image." raise ValueError(err_msg) return super()._update(image)