# Source code for duck.utils.threading.thread_manager

"""
`WorkerThreadManager` manages and monitors a pool of worker threads.
    
**Features:**
- Automatic restart of dead or unhealthy workers.
- Customizable health-check hooks per worker.
- Threaded non-blocking monitoring loop.
- Configurable logging and verbosity.
- Status inspection and graceful shutdown.
    
Example use cases:
- WSGI/ASGI server worker orchestration.
- ML/AI multi-threaded task runner watchdog.
- Long-running web backend with thread self-repair.

Usage Example:
```py
import random
import time

def sample_worker(idx, *args):
    print(f"Worker {idx} started...")
    
    while True:
        time.sleep(1)
        print(f"[Worker {idx}] Sleeping")
        
        # Simulate random crash; restart will occur
        if random.random() < 0.03:
            print(f"[Worker {idx}] Simulating crash")
            exit(1)

def health_check_fn(thread, idx):
    # Returns True if alive; override for custom checks
    return thread.is_alive()

manager = WorkerThreadManager(
    worker_fn=sample_worker,
    num_workers=4,
    args_fn=lambda idx: (...),
    worker_name_fn=lambda idx: f"duck-worker-{idx}",
    health_check_fn=health_check_fn,  # Or use a HeartbeatHealthCheck object.
    restart_timeout=2,
    enable_logs=True, verbose_logs=False,
    enable_monitoring=True,
    thread_stop_timeout=3,
)

try:
    manager.start()
    for _ in range(20):  # Monitor for a while
        print("Worker status:", manager.status())
        time.sleep(2)
finally:
    manager.stop()
```
"""
import os
import time
import logging
import threading

from typing import (
    Callable,
    Optional,
    Union,
    Iterable,
)

from duck.exceptions.all import SettingsError

try:
    from duck.logging import logger
except SettingsError:
    from duck.logging import console as logger


class HeartbeatUpdateNeverCalled(Exception):
    """
    Signals that `HeartbeatHealthCheck.check_health` found no recorded
    heartbeats at all, i.e. no worker has called `update_heartbeat` yet.
    """
class HeartbeatHealthCheck:
    """
    Thread health check using a heartbeat approach.

    Each worker periodically calls `update_heartbeat(idx)`. A worker is
    reported unhealthy once more than `heartbeat_timeout` seconds have
    elapsed since its last heartbeat.

    Example:
    ```py
    healthcheck = HeartbeatHealthCheck(...)

    def worker_fn(idx, healthcheck, ...):
        while True:
            healthcheck.update_heartbeat(idx)
            # Some tasks here
            ...
    ```
    """

    def __init__(self, heartbeat_timeout: float):
        """
        Initialize heartbeat health check.

        Args:
            heartbeat_timeout (float): Maximum number of seconds allowed since a
                worker's last heartbeat before it is considered unhealthy.
        """
        self.heartbeat_timeout = heartbeat_timeout
        # Maps worker index -> time.monotonic() timestamp of its last heartbeat.
        self._heartbeats = {}

    def update_heartbeat(self, idx: int):
        """
        Update last heartbeat.

        Args:
            idx (int): Index of the thread, usually provided to `worker_fn`.

        Raises:
            RuntimeError: If the function is called in the main thread rather than a child thread.
        """
        if threading.current_thread() is threading.main_thread():
            raise RuntimeError("This method must be used in a child thread, not main thread.")
        # Monotonic clock: unaffected by wall-clock adjustments (NTP, manual
        # changes) that would otherwise cause false timeouts or mask hung workers.
        self._heartbeats[idx] = time.monotonic()

    def check_health(self, thread: threading.Thread, idx: int) -> bool:
        """
        Check that the last heartbeat of worker `idx` has not timed out.

        A timed-out heartbeat may indicate an unhealthy (e.g. hung) thread.

        Args:
            thread (threading.Thread): The worker thread (unused by this check).
            idx (int): Index of the worker thread.

        Returns:
            bool: True if worker `idx` sent a heartbeat within `heartbeat_timeout`
                seconds, otherwise False (including when it never sent one).

        Raises:
            HeartbeatUpdateNeverCalled: If no heartbeat has been recorded for any
                worker. This guards against using this approach without ever
                calling `update_heartbeat`. In a thread loop, the heartbeat must
                be updated before handling any tasks.

        Example:
        ```py
        healthcheck = HeartbeatHealthCheck(...)

        def worker_fn(idx, healthcheck, ...):
            while True:
                healthcheck.update_heartbeat(idx)
                # Some tasks here
                ...
        ```
        """
        if not self._heartbeats:
            raise HeartbeatUpdateNeverCalled("Heartbeats are empty, meaning you may not be calling `update_heartbeat` in your child thread.")

        last_beat = self._heartbeats.get(idx)
        if last_beat is None:
            # This particular worker has never reported a heartbeat.
            return False

        # Healthy as long as the last heartbeat is recent enough.
        return time.monotonic() - last_beat <= self.heartbeat_timeout

    def __call__(self, thread: threading.Thread, idx: int) -> bool:
        """
        Alias for `check_health`, so this object can be passed directly as
        `health_check_fn` to `WorkerThreadManager`.

        Returns:
            bool: True if worker `idx` sent a heartbeat within `heartbeat_timeout`
                seconds, otherwise False.

        Raises:
            HeartbeatUpdateNeverCalled: If no heartbeat has been recorded for any worker.

        Example:
        ```py
        healthcheck = HeartbeatHealthCheck(...)
        manager = WorkerThreadManager(
            health_check_fn=healthcheck,
            ...
        )

        def worker_fn(idx, healthcheck, ...):
            while True:
                healthcheck.update_heartbeat(idx)
                # Some tasks here
                ...
        ```
        """
        return self.check_health(thread, idx)
class WorkerThreadManager:
    """
    WorkerThreadManager manages and monitors a pool of worker threads.

    **Features:**
    - Automatic restart of dead or unhealthy workers.
    - Customizable health-check hooks per worker.
    - Threaded non-blocking monitoring loop.
    - Configurable logging and verbosity.
    - Status inspection and graceful shutdown.

    Example use cases:
    - WSGI/ASGI server worker orchestration.
    - ML/AI multi-thread task runner watchdog.
    - Long-running web backend with thread self-repair.
    """

    def __init__(
        self,
        worker_fn: Callable,
        num_workers: int,
        args_fn: Optional[Callable[[int], tuple]] = None,
        worker_name_fn: Optional[Callable[[int], str]] = None,
        health_check_fn: Optional[Union[Callable[[threading.Thread, int], bool], "HeartbeatHealthCheck"]] = None,
        restart_timeout: Union[int, float] = 5,
        enable_logs: bool = True,
        verbose_logs: bool = True,
        enable_monitoring: bool = True,
        thread_stop_timeout: Optional[float] = 5.0,
        daemon: bool = False,
    ):
        """
        Args:
            worker_fn (Callable): Function executed by each worker thread. It receives
                the worker index as first positional argument, then the
                `HeartbeatHealthCheck` object (only when one is used as
                `health_check_fn`), then the extra args produced by `args_fn`.
            num_workers (int): Number of worker threads to start.
            args_fn (Optional[Callable]): Callable (idx) => tuple of extra args per worker.
                A non-iterable return value is passed as a single extra argument;
                `None` means no extra arguments.
            worker_name_fn (Optional[Callable]): Callable (idx) => str; worker thread name.
                Defaults to `worker-{idx}`.
            health_check_fn (Optional[Union[Callable[[threading.Thread, int], bool], HeartbeatHealthCheck]]):
                Callable (thread, idx) => bool that must return True if the worker is
                healthy, False otherwise. Only called for threads that are alive.
                You can supply a `HeartbeatHealthCheck` object instead to use the
                heartbeat health check.
            restart_timeout (int|float): Seconds the monitor waits after restarting a worker
                before checking the next one.
            enable_logs (bool): Enable info/warning logging.
            verbose_logs (bool): Also log full exception traces (requires `enable_logs`).
            enable_monitoring (bool): Start the monitor thread in `start()`.
            thread_stop_timeout (Optional[float]): Maximum seconds to wait for a worker to stop.
                Passed to `join()`. With None, `stop()` waits indefinitely, while a
                restart falls back to 5 seconds so the monitor cannot hang.
            daemon (bool): Whether worker threads are daemon threads. Defaults to False.
        """
        self.worker_fn = worker_fn
        self.num_workers = num_workers
        self.args_fn = args_fn or (lambda idx: ())
        self.worker_name_fn = worker_name_fn or (lambda idx: f"worker-{idx}")
        self.health_check_fn = health_check_fn
        self.restart_timeout = restart_timeout
        self.enable_logs = enable_logs
        self.verbose_logs = verbose_logs
        self.enable_monitoring = enable_monitoring
        self.thread_stop_timeout = thread_stop_timeout
        self.daemon = daemon
        self.worker_threads = []
        self.worker_locks = [threading.Lock() for _ in range(num_workers)]
        self.running = False
        self.monitor_thread = None
        # Set by stop(); lets the monitor's waits end immediately instead of
        # outliving the short join timeout used for the monitor in stop().
        self._stop_event = threading.Event()

        def worker_fn_wrapper(*args, **kwargs):
            self.worker_fn(*args, **kwargs)

        # Assign wrapper; will be called for starting child thread
        self.worker_fn_wrapper = worker_fn_wrapper

    def _build_worker_args(self, idx: int) -> tuple:
        """Build the positional args passed to `worker_fn` for worker `idx`."""
        extra = self.args_fn(idx)
        # Always include index in args
        args = (idx, *extra) if isinstance(extra, Iterable) else (idx, extra)
        if len(args) == 2 and args[1] is None:
            # args_fn returned None: the worker only receives its index.
            args = (idx,)
        if isinstance(self.health_check_fn, HeartbeatHealthCheck):
            # Pass the HeartbeatHealthCheck object right after the index so the
            # worker can call `update_heartbeat(idx)`.
            args = (args[0], self.health_check_fn, *args[1:])
        return args

    def _spawn_worker(self, idx: int) -> threading.Thread:
        """Create, start and return the worker thread for index `idx`."""
        thread = threading.Thread(
            target=self.worker_fn_wrapper,
            args=self._build_worker_args(idx),
            name=self.worker_name_fn(idx),
            daemon=self.daemon,
        )
        thread.start()
        return thread

    def _wait(self, seconds: float) -> bool:
        """Sleep up to `seconds`; returns True early if `stop()` was called."""
        return self._stop_event.wait(seconds)

    def start(self):
        """
        Start worker threads and, if enabled, the non-blocking monitor loop.

        Worker threads honor `daemon`; the monitor thread is always a daemon thread.
        """
        self.running = True
        self._stop_event.clear()
        self.worker_threads = []

        for i in range(self.num_workers):
            self.worker_threads.append(self._spawn_worker(i))

        if self.enable_monitoring:
            self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
            self.monitor_thread.start()

        if self.enable_logs:
            logger.log(
                f"Started {self.num_workers} worker {'threads' if self.num_workers != 1 else 'thread'}; Monitoring: {'ON' if self.enable_monitoring else 'OFF'}",
                level=logger.DEBUG,
            )

    def stop(
        self,
        wait: bool = True,
        monitor_stop_timeout: float = 0.5,
        no_logging: bool = False,
    ):
        """
        Stop all worker threads and the monitoring thread.

        Worker threads are not interrupted; they are only joined, so `worker_fn`
        must return on its own for a worker to actually stop.

        Args:
            wait (bool): Whether to wait for threads to finish stopping. Defaults to True.
            monitor_stop_timeout (float): Timeout for waiting on the monitor thread.
            no_logging (bool): If True, suppress the log messages emitted by this call.
        """
        self.running = False
        # Wake the monitor from any pending wait so it exits promptly.
        self._stop_event.set()

        # Stop monitoring thread first so it cannot restart workers being stopped.
        if wait and self.monitor_thread and self.monitor_thread.is_alive():
            self.monitor_thread.join(timeout=monitor_stop_timeout)

        if wait:
            for t in self.worker_threads:
                t.join(timeout=self.thread_stop_timeout)
                if self.enable_logs and t.is_alive() and not no_logging:
                    logger.log(
                        f"Worker thread {t.name} did not shut down gracefully.",
                        level=logger.WARNING,
                    )

        if self.enable_logs and not no_logging:
            logger.log(
                "All workers and monitor stopped." if wait else "Stopped worker thread manager.",
                level=logger.INFO,
                custom_color=logger.Fore.MAGENTA,
            )

    def _restart_worker(self, idx: int):
        """
        Restart the worker thread at index `idx`.

        If the old thread is still alive (e.g. it failed its health check), it is
        joined for up to `thread_stop_timeout` seconds (5 if that is None). Python
        threads cannot be killed, so an old thread still running after that keeps
        running alongside its replacement.
        """
        join_timeout = 5 if self.thread_stop_timeout is None else self.thread_stop_timeout

        with self.worker_locks[idx]:
            old_t = self.worker_threads[idx]
            if old_t.is_alive():
                old_t.join(timeout=join_timeout)

            # Start and update worker threads
            new_t = self._spawn_worker(idx)
            self.worker_threads[idx] = new_t

            # Log something if logs enabled.
            if self.enable_logs:
                logger.log(
                    f"Restarted worker thread {new_t.name} \n",
                    level=logger.WARNING,
                )

    def _monitor_loop(self):
        """
        Monitor thread: checks worker health/liveness and restarts unhealthy/dead workers.

        Runs until `running` is False. Non-blocking for the main thread. Waits are
        interrupted by `stop()`, and no worker is restarted once it was called.
        """
        # Give freshly started workers a moment before the first scan.
        self._wait(2)
        heartbeat_never_called_counter = 0

        while self.running:
            try:
                for idx, t in enumerate(list(self.worker_threads)):
                    if not self.running:
                        # stop() was called mid-scan; never restart workers after it.
                        break

                    healthy = t.is_alive()
                    if healthy and self.health_check_fn:
                        try:
                            healthy = self.health_check_fn(t, idx)
                        except HeartbeatUpdateNeverCalled as e:
                            # Skip logging the first occurrence: workers may simply
                            # not have sent their first heartbeat yet.
                            if self.enable_logs and heartbeat_never_called_counter > 0:
                                logger.log(f"Exception during health_check_fn: {e}", level=logger.WARNING)
                                if self.verbose_logs:
                                    logger.log_exception(e)
                            heartbeat_never_called_counter += 1
                            # Wait one heartbeat period; custom callables may not
                            # define `heartbeat_timeout`, so fall back to 2s.
                            self._wait(getattr(self.health_check_fn, "heartbeat_timeout", 2))
                            continue
                        except (KeyboardInterrupt, BrokenPipeError):
                            # Thread might be terminated; abandon this scan.
                            break
                        except Exception as e:
                            if self.enable_logs:
                                logger.log(f"Exception during health_check_fn: {e}", level=logger.WARNING)
                                if self.verbose_logs:
                                    logger.log_exception(e)
                            self._wait(2)
                            continue

                    # Thread is not healthy
                    if not healthy:
                        if self.enable_logs:
                            logger.log(
                                f"Detected unhealthy/dead worker {t.name}, restarting...",
                                level=logger.WARNING,
                            )
                        self._restart_worker(idx)
                        self._wait(self.restart_timeout)
            except Exception as e:
                if self.enable_logs:
                    logger.log(f"Error in monitor loop: {e}", level=logger.WARNING)
                    if self.verbose_logs:
                        logger.log_exception(e)

            # Pause between scans. Kept outside the try so that a persistent error
            # cannot turn this loop into a busy spin.
            self._wait(2)

    def status(self):
        """
        Return the status of all worker threads.

        Returns:
            list[dict]: One dict per worker with keys "name" (thread name) and
                "alive" (whether the thread is currently alive).
        """
        return [{"name": t.name, "alive": t.is_alive()} for t in self.worker_threads]