# Source code for cv2ext.tracking._multi_tracker

# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
from __future__ import annotations

import atexit
import contextlib
import weakref
from queue import Empty, Queue
from threading import Event, Thread
from typing import TYPE_CHECKING

from ._interface import AbstractMultiTracker, AbstractTracker
from ._tracker_type import TrackerType

if TYPE_CHECKING:
    import numpy as np
    from typing_extensions import Self


class MultiTracker(AbstractMultiTracker):
    """Track several objects in a video by delegating to a serial or threaded backend."""

    def __init__(
        self: Self,
        tracker_type: TrackerType | type[AbstractTracker] = TrackerType.KCF,
        *,
        use_threads: bool | None = None,
    ) -> None:
        """
        Build the multi-object tracker and pick its backend.

        Parameters
        ----------
        tracker_type : TrackerType | type[AbstractTracker]
            Which single-object tracker to instantiate for every target.
        use_threads : bool, optional
            Selects the threaded backend when truthy and the serial backend
            otherwise. Leaving it as None is treated the same as True.

        """
        # None means "not specified", which defaults to the threaded backend.
        threaded = True if use_threads is None else use_threads
        backend_cls = _ThreadedMultiTracker if threaded else _SerialMultiTracker
        self._tracker: AbstractMultiTracker = backend_cls(tracker_type)

    def init(
        self: Self,
        image: np.ndarray,
        bboxes: list[tuple[int, int, int, int]],
    ) -> None:
        """
        Start tracking the given targets from the first frame.

        Parameters
        ----------
        image : np.ndarray
            The first frame of the video.
        bboxes : list[tuple[int, int, int, int]]
            One starting box per target, each given as (x1, y1, x2, y2).

        """
        # All work is forwarded to the backend chosen in __init__.
        self._tracker.init(image, bboxes)

    def update(
        self: Self,
        image: np.ndarray,
    ) -> list[tuple[bool, tuple[int, int, int, int]]]:
        """
        Advance every tracker by one frame.

        Parameters
        ----------
        image : np.ndarray
            The next frame of the video.

        Returns
        -------
        list[tuple[bool, tuple[int, int, int, int]]]
            For each target, a success flag paired with its updated box,
            given as (x1, y1, x2, y2).

        """
        results = self._tracker.update(image)
        return results
class _ThreadedMultiTracker(AbstractMultiTracker):
    """
    Threaded version of MultiTracker.

    Every target gets one worker thread that owns a single tracker instance.
    Frames are handed to a worker through a per-target input queue and the
    result is read back from a per-target output queue.
    """

    def __init__(self: Self, tracker_type: TrackerType | type[AbstractTracker]) -> None:
        """
        Create a new ThreadedMultiTracker object.

        Parameters
        ----------
        tracker_type : TrackerType | type[AbstractTracker]
            The type of tracker to use for tracking objects.

        """
        self._tracker_type = (
            tracker_type.value if isinstance(tracker_type, TrackerType) else tracker_type
        )
        # Never populated in this class: each worker thread owns its tracker.
        self._trackers: list[AbstractTracker] = []
        self._stop_event = Event()
        self._threads: list[Thread] = []
        self._in_queues: list[Queue[tuple[int, np.ndarray]]] = []
        self._out_queues: list[
            Queue[tuple[int, tuple[bool, tuple[int, int, int, int]]]]
        ] = []
        # Tag of the most recently submitted frame; lets update() recognise
        # and drop results that belong to an earlier, timed-out frame.
        self._frame_id = 0
        # Stop the workers when this object is garbage collected or at
        # interpreter exit. The callback and its arguments deliberately hold
        # no reference to self, otherwise the object could never be collected.
        # The same list object must stay bound to self._threads for this to
        # keep seeing the current workers (see _reset).
        self._finalizer = weakref.finalize(
            self,
            _ThreadedMultiTracker._stop_workers,
            self._stop_event,
            self._threads,
        )

    @staticmethod
    def _stop_workers(stop_event: Event, threads: list[Thread]) -> None:
        """Signal all workers to stop and wait for each of them to finish."""
        stop_event.set()
        for thread in threads:
            # join() raises RuntimeError when asked to join the calling
            # thread; skip that thread instead of failing the shutdown.
            with contextlib.suppress(RuntimeError):
                thread.join()

    def _reset(self: Self) -> None:
        """Stop every worker and return to the empty, uninitialized state."""
        self._stop_workers(self._stop_event, self._threads)
        # All old workers have exited, so the event can be reused.
        self._stop_event.clear()
        # Clear in place: the finalizer holds a reference to this exact list.
        self._threads.clear()
        self._trackers = []
        self._in_queues = []
        self._out_queues = []
        self._frame_id = 0

    @staticmethod
    def _thread_target(
        tracker_type: type[AbstractTracker],
        stop_event: Event,
        in_queue: Queue[tuple[int, np.ndarray]],
        out_queue: Queue[tuple[int, tuple[bool, tuple[int, int, int, int]]]],
        initial_image: np.ndarray,
        initial_bbox: tuple[int, int, int, int],
    ) -> None:
        """
        Run one worker: create a tracker, then serve frames until stopped.

        This is a static method on purpose, so a running worker does not keep
        the owning multi-tracker alive.

        Parameters
        ----------
        tracker_type : type[AbstractTracker]
            Called with no arguments to create this worker's tracker.
        stop_event : Event
            The worker exits once this event is set.
        in_queue : Queue[tuple[int, np.ndarray]]
            Source of (frame id, frame) pairs to process.
        out_queue : Queue[tuple[int, tuple[bool, tuple[int, int, int, int]]]]
            Destination for (frame id, tracker result) pairs.
        initial_image : np.ndarray
            The frame used to initialize the tracker.
        initial_bbox : tuple[int, int, int, int]
            The bounding box used to initialize the tracker.

        """
        tracker = tracker_type()
        tracker.init(initial_image, initial_bbox)
        while not stop_event.is_set():
            # Poll with a short timeout so the stop event is noticed promptly.
            # Only the queue read is guarded; tracker errors must not be
            # mistaken for an empty queue.
            try:
                frame_id, image = in_queue.get(timeout=0.1)
            except Empty:
                continue
            out_queue.put_nowait((frame_id, tracker.update(image)))

    def init(
        self: Self,
        image: np.ndarray,
        bboxes: list[tuple[int, int, int, int]],
    ) -> None:
        """
        Initialize the trackers with the initial bounding boxes.

        Any workers from a previous call are stopped first.

        Parameters
        ----------
        image : np.ndarray
            The first frame of the video.
        bboxes : list[tuple[int, int, int, int]]
            The initial bounding boxes of the targets.
            Each bbox is represented as (x1, y1, x2, y2).

        """
        self._reset()
        for bbox in bboxes:
            in_queue: Queue[tuple[int, np.ndarray]] = Queue()
            out_queue: Queue[tuple[int, tuple[bool, tuple[int, int, int, int]]]] = Queue()
            self._in_queues.append(in_queue)
            self._out_queues.append(out_queue)
            thread = Thread(
                target=self._thread_target,
                args=(self._tracker_type, self._stop_event, in_queue, out_queue, image, bbox),
                daemon=True,
            )
            thread.start()
            self._threads.append(thread)

    def update(
        self: Self,
        image: np.ndarray,
    ) -> list[tuple[bool, tuple[int, int, int, int]]]:
        """
        Update the trackers with the next frame of the video.

        Parameters
        ----------
        image : np.ndarray
            The next frame of the video.

        Returns
        -------
        list[tuple[bool, tuple[int, int, int, int]]]
            The updated success values and bounding boxes of the targets.
            Each bbox is represented as (x1, y1, x2, y2).

        Raises
        ------
        RuntimeError
            If a thread has been closed before an update call.
        RuntimeError
            If a queue times out fetching an update.

        """
        # Check every worker before handing out the frame, so a dead worker
        # cannot leave the others with a frame whose result is never read.
        if not all(thread.is_alive() for thread in self._threads):
            err_msg = "Thread closed, cannot update tracking."
            raise RuntimeError(err_msg)

        self._frame_id += 1
        frame_id = self._frame_id
        for in_queue in self._in_queues:
            in_queue.put((frame_id, image))

        results: list[tuple[bool, tuple[int, int, int, int]]] = []
        try:
            for out_queue in self._out_queues:
                result_id, data = out_queue.get(timeout=1.0)
                # A result left over from an earlier call that timed out
                # carries an older id; discard it and wait for the current one.
                while result_id != frame_id:
                    result_id, data = out_queue.get(timeout=1.0)
                results.append(data)
        except Empty as e:
            err_msg = "Could not get results from a thread, most likely crashed."
            raise RuntimeError(err_msg) from e
        return results


class _SerialMultiTracker(AbstractMultiTracker):
    """Serial version of MultiTracker: trackers run one after another in the caller's thread."""

    def __init__(self: Self, tracker_type: TrackerType | type[AbstractTracker]) -> None:
        """
        Create a new SerialMultiTracker object.

        Parameters
        ----------
        tracker_type : TrackerType | type[AbstractTracker]
            The type of tracker to use for tracking objects.

        """
        self._tracker_type = (
            tracker_type.value if isinstance(tracker_type, TrackerType) else tracker_type
        )
        self._trackers: list[AbstractTracker] = []

    def init(
        self: Self,
        image: np.ndarray,
        bboxes: list[tuple[int, int, int, int]],
    ) -> None:
        """
        Initialize the trackers with the initial bounding boxes.

        Any trackers from a previous call are replaced.

        Parameters
        ----------
        image : np.ndarray
            The first frame of the video.
        bboxes : list[tuple[int, int, int, int]]
            The initial bounding boxes of the targets.
            Each bbox is represented as (x1, y1, x2, y2).

        """
        # One fresh tracker per bounding box, in the same order as bboxes.
        self._trackers = [self._tracker_type() for _ in bboxes]
        for tracker, bbox in zip(self._trackers, bboxes):
            tracker.init(image, bbox)

    def update(
        self: Self,
        image: np.ndarray,
    ) -> list[tuple[bool, tuple[int, int, int, int]]]:
        """
        Update the trackers with the next frame of the video.

        Parameters
        ----------
        image : np.ndarray
            The next frame of the video.

        Returns
        -------
        list[tuple[bool, tuple[int, int, int, int]]]
            The updated success values and bounding boxes of the targets.
            Each bbox is represented as (x1, y1, x2, y2).

        """
        return [tracker.update(image) for tracker in self._trackers]