# Source code for duck.utils.threading.patch

"""
Production-grade monkey-patch module for Python's threading.Thread.  

Enhancements:
- Tracks parent thread object (not just ident).
- Returns parent Thread object via get_parent_thread().
- Uses a weakref-based registry, so registry entries never keep Thread objects alive.
- Supports patching already-created Thread instances (before start()).
- Preserves subclassed run() methods.
- Hooks for pre-run and post-run execution.
- Fully idempotent. Safe for large production systems.
- Automatic cleanup after thread finishes.
"""

import functools
import gc
import logging
import threading
import weakref

from typing import (
    Callable,
    Optional,
    Dict,
    Any,
    Union,
)


# Registry of running patched threads, keyed by thread ident. Each value maps
# "parent" to a weak reference to the parent thread (or None) and "thread" to a
# weak reference to the thread itself. Entries are added when a wrapped run()
# begins and removed when it ends.
thread_info: Dict[int, Dict[str, Any]] = {}
# Flipped to True by patch_threading() once Thread.__init__ has been replaced.
_is_patched = False

# Cached original class methods (captured at import time, before any patching
# done by this module).
_original_init = threading.Thread.__init__


class PatchNotApplied(Exception):
    """
    Raised when a parent-thread lookup (`get_parent` / `get_parent_thread`)
    is attempted before the threading module has been patched.
    """
[docs] def _wrap_run(self, original_run, pre_hook, post_hook): """ Returns a wrapped run() method that executes: - pre_hook() - original run() - post_hook() - registry cleanup """ @functools.wraps(original_run) def wrapped(): ident = threading.get_ident() # register metadata now that thread started and ident exists thread_info[ident] = { "parent": weakref.ref(self._parent_thread) if hasattr(self, "_parent_thread") else None, "thread": weakref.ref(self), } # pre-hook if pre_hook: try: pre_hook(self) except Exception: pass try: return original_run() finally: # post-hook if post_hook: try: post_hook(self) except Exception: pass # cleanup thread_info.pop(ident, None) return wrapped
def patch_threading(
    *,
    pre_hook: Optional[Callable[[threading.Thread], None]] = None,
    post_hook: Optional[Callable[[threading.Thread], None]] = None,
    patch_existing_threads: bool = False,
) -> None:
    """
    Monkey-patches threading.Thread so that:
    - New threads automatically track parent thread objects.
    - `.run()` is wrapped at instance level.
    - Subclass overrides continue to work.
    - Optionally patch existing Thread objects created before patching.
    - Automatic cleanup in registry.

    Args:
        pre_hook (Optional): runs before each thread's original run().
        post_hook (Optional): runs after each thread's run() completes.
        patch_existing_threads (bool):
            If True, already-created threads that have NOT started will
            get their `.run()` patched as well.

    Notes:
        - Idempotent: calling twice does nothing (the hooks and flags of a
          second call are ignored).
        - Back-patching cannot discover the real parent of an already-created
          thread; the thread calling `patch_threading` is recorded instead.
        - Unstarted threads are located by scanning the objects tracked by the
          garbage collector, because `threading.enumerate()` only lists
          threads that are currently alive.
    """
    global _is_patched

    if _is_patched:
        return

    # Looked up defensively: this is a private name of the threading module.
    dummy_thread_type = getattr(threading, "_DummyThread", None)

    def patched_init(self, *args, **kwargs):
        # threading.current_thread() builds a _DummyThread when called from a
        # thread the threading module did not create, and that constructor
        # calls Thread.__init__ (this function) before registering itself.
        # Asking for the current thread here would recurse without bound, so
        # such placeholder objects are initialised untouched.
        if dummy_thread_type is not None and isinstance(self, dummy_thread_type):
            _original_init(self, *args, **kwargs)
            return

        self._parent_thread = threading.current_thread()  # store the actual object

        # original init
        _original_init(self, *args, **kwargs)

        # capture user-defined run() (may be overridden) and replace it with
        # the wrapped version on the instance
        self.run = _wrap_run(self, self.run, pre_hook, post_hook)

    # apply patch
    threading.Thread.__init__ = patched_init
    _is_patched = True

    if not patch_existing_threads:
        return

    # optionally patch threads created before this call
    patching_thread = threading.current_thread()
    for candidate in gc.get_objects():
        # type() rather than isinstance(): proxy objects can raise from their
        # __class__ lookup, and arbitrary GC-tracked objects are visited here.
        if not issubclass(type(candidate), threading.Thread):
            continue
        # ident stays None until start(); the default also skips instances
        # whose Thread.__init__ has not run yet (they raise AttributeError).
        if getattr(candidate, "ident", 0) is not None:
            continue
        if hasattr(candidate, "_parent_thread"):
            continue
        candidate._parent_thread = patching_thread
        candidate.run = _wrap_run(candidate, candidate.run, pre_hook, post_hook)
def get_parent_thread(thread_or_ident: Union[int, threading.Thread]) -> Optional[threading.Thread]:
    """
    Returns the actual parent Thread object of a given thread.

    The lookup goes through the `thread_info` registry, which only holds
    entries for patched threads while their run() is executing.

    Args:
        thread_or_ident: A thread object or its ident.

    Returns:
        Thread | None: The parent thread, or None when the thread is not in
        the registry, has no recorded parent, or the parent object has
        already been garbage collected.

    Raises:
        PatchNotApplied: if thread module wasn't patched yet.
    """
    if not _is_patched:
        raise PatchNotApplied('threading module has not been patched yet. Did you forget to use "patch_threading".')

    if isinstance(thread_or_ident, int):
        ident = thread_or_ident
        expected_thread = None
    else:
        ident = getattr(thread_or_ident, "ident", None)
        expected_thread = thread_or_ident

    if ident is None:
        return None

    info = thread_info.get(ident)
    if not info:
        return None

    if expected_thread is not None:
        # Thread idents may be recycled once a thread exits, so the entry
        # found under this ident can belong to a newer, unrelated thread.
        # Only trust it when it refers to the very object that was passed in.
        thread_ref = info.get("thread")
        if thread_ref is not None and thread_ref() is not expected_thread:
            return None

    parent_ref = info.get("parent")
    return parent_ref() if parent_ref else None
# Alias kept for backwards compatibility with the previous API.
get_parent = get_parent_thread