threadpoolctl: Python helpers to limit the number of threads used in threadpool-backed native libraries for scientific computing.
Extend threadpoolctl with custom library controllers to support thread pool libraries that it does not handle out of the box. This makes it possible to integrate custom or newer thread pool implementations.
`LibController`: the abstract base class that defines the interface every library controller implements.
from abc import ABC, abstractmethod


class LibController(ABC):
    """Abstract base class for individual library controllers.

    Subclasses must define the class attributes below and implement the
    abstract methods to support a specific thread pool library
    implementation. Instantiating a class that leaves any abstract method
    unimplemented raises ``TypeError``.

    Class Attributes (required):
        user_api: str
            Standardized API name ('blas', 'openmp', or custom).
        internal_api: str
            Implementation-specific identifier (unique name).
        filename_prefixes: tuple[str, ...]
            Shared library filename prefixes used for detection.
        check_symbols: tuple[str, ...] (optional)
            Symbol names used to verify library compatibility.
    """

    def __init__(self, *, filepath=None, prefix=None, parent=None):
        """Initialize the library controller (not meant to be overridden).

        Args:
            filepath: str - Path to the shared library.
            prefix: str - Matched filename prefix.
            parent: ThreadpoolController - Parent controller instance.
        """

    def info(self):
        """Return library information as a dictionary.

        Returns:
            dict: Library info with the standard keys plus any additional
            attributes set by ``set_additional_attributes``.
        """

    @property
    def num_threads(self):
        """Current thread limit, exposed as a dynamic (read-only) property.

        Returns:
            int | None: Whatever ``get_num_threads()`` currently reports.
        """
        # Delegate so the property always reflects the live library state.
        return self.get_num_threads()

    @abstractmethod
    def get_num_threads(self):
        """Get the current maximum thread count.

        Must be implemented by subclasses.

        Returns:
            int | None: Current thread limit, None if unavailable.
        """

    @abstractmethod
    def set_num_threads(self, num_threads):
        """Set the maximum thread count.

        Must be implemented by subclasses.

        Args:
            num_threads: int - New thread limit.

        Returns:
            Any: Implementation-specific return value.
        """

    @abstractmethod
    def get_version(self):
        """Get the library version.

        Must be implemented by subclasses.

        Returns:
            str | None: Version string, None if unavailable.
        """

    def set_additional_attributes(self):
        """Set implementation-specific attributes.

        Called during initialization to set custom attributes that will be
        included in the ``info()`` dictionary. The default adds nothing;
        override it to add custom fields.
        """


# Register custom controllers with the threadpoolctl system.
def register(controller):
    """Register a new library controller class.

    Adds the controller to the global registry so that it is used during
    library discovery in ThreadpoolController instances created afterwards.

    Args:
        controller: type[LibController] - LibController subclass to register.
    """


from threadpoolctl import LibController, register
import ctypes
class MyCustomController(LibController):
    """Controller for the ``mycustomlib`` thread pool library.

    Each native entry point is looked up on ``self.dynlib`` on every call;
    when a symbol is missing, the corresponding method reports ``None``.
    """

    # Mandatory metadata used by threadpoolctl to detect and describe the library.
    user_api = "custom"  # Or "blas"/"openmp" if it fits those categories
    internal_api = "mycustomlib"
    filename_prefixes = ("libmycustom", "mycustomlib")

    # Optional: symbols used to check that a candidate library is compatible.
    check_symbols = (
        "mycustom_get_threads",
        "mycustom_set_threads",
        "mycustom_get_version",
    )

    def get_num_threads(self):
        """Return the thread count reported by the library, or None."""
        getter = getattr(self.dynlib, "mycustom_get_threads", None)
        return None if getter is None else getter()

    def set_num_threads(self, num_threads):
        """Forward *num_threads* to the library; return its result, or None."""
        setter = getattr(self.dynlib, "mycustom_set_threads", None)
        return None if setter is None else setter(num_threads)

    def get_version(self):
        """Return the library version as text, or None when unavailable."""
        reader = getattr(self.dynlib, "mycustom_get_version", None)
        if reader is None:
            return None
        # Declare a char* return so ctypes yields bytes (or None for NULL).
        reader.restype = ctypes.c_char_p
        raw = reader()
        return raw.decode('utf-8') if raw else None

    def set_additional_attributes(self):
        """Attach extra fields that should appear in info()."""
        self.custom_attribute = self._get_custom_info()

    def _get_custom_info(self):
        """Return the implementation-specific value exposed as custom_attribute."""
        return "custom_value"
# Register the controller so future ThreadpoolController instances use it.
register(MyCustomController)

from threadpoolctl import LibController, register
import ctypes
class CustomBLASController(LibController):
    """Controller for a custom BLAS implementation (``libcustomblas``)."""

    user_api = "blas"
    internal_api = "customblas"
    filename_prefixes = ("libcustomblas",)
    check_symbols = (
        "customblas_get_num_threads",
        "customblas_set_num_threads",
        "customblas_get_config",
    )

    def get_num_threads(self):
        """Return the library's thread count, mapping the -1 sentinel to 1."""
        getter = getattr(self.dynlib, "customblas_get_num_threads", None)
        if not getter:
            return None
        count = getter()
        # -1 means "sequential" (the OpenBLAS convention), i.e. one thread.
        if count == -1:
            return 1
        return count

    def set_num_threads(self, num_threads):
        """Forward the new limit to the library; None if the setter is missing."""
        setter = getattr(self.dynlib, "customblas_set_num_threads", None)
        if not setter:
            return None
        return setter(num_threads)

    def get_version(self):
        """Extract the version number from the library's config string."""
        reader = getattr(self.dynlib, "customblas_get_config", None)
        if not reader:
            return None
        reader.restype = ctypes.c_char_p
        raw_config = reader()
        if not raw_config:
            return None
        return self._parse_version(raw_config.decode('utf-8'))

    def set_additional_attributes(self):
        """Expose BLAS-specific details (threading layer, architecture) in info()."""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def _parse_version(self, config_str):
        """Return the first ``version X.Y.Z`` number in *config_str*, or None."""
        import re

        found = re.search(r'version\s+(\d+\.\d+\.\d+)', config_str, re.IGNORECASE)
        if found is None:
            return None
        return found.group(1)

    def _get_threading_layer(self):
        """Name the threading backend (placeholder for real detection logic)."""
        return "openmp"  # or "pthreads", "disabled", etc.

    def _get_architecture(self):
        """Report the target architecture (placeholder for real detection logic)."""
        return "x86_64"
register(CustomBLASController)

from threadpoolctl import LibController, register
import ctypes
class FlexibleController(LibController):
    """Controller that adapts to different versions of the ``libflex`` API.

    Capability detection is done in ``set_additional_attributes`` instead of
    an overridden ``__init__``. The base initializer is not meant to be
    overridden, and it calls ``set_additional_attributes`` while
    initializing. Attributes assigned after ``super().__init__()`` would
    therefore not exist yet when that hook runs.
    """

    user_api = "custom"
    internal_api = "flexlib"
    filename_prefixes = ("libflex",)

    def _detect_capabilities(self):
        """Record which optional entry points this build of the library exports."""
        self.has_v1_api = hasattr(self.dynlib, "flex_get_threads_v1")
        self.has_v2_api = hasattr(self.dynlib, "flex_get_threads_v2")
        self.has_version_api = hasattr(self.dynlib, "flex_version")

    def get_num_threads(self):
        """Return the thread count via the newest available API, else None."""
        # The has_*_api flags are set by set_additional_attributes() during init.
        if self.has_v2_api:
            return self.dynlib.flex_get_threads_v2()
        if self.has_v1_api:
            return self.dynlib.flex_get_threads_v1()
        return None

    def set_num_threads(self, num_threads):
        """Set the thread limit via the newest available API, else return None."""
        # Assumes each getter version ships with the matching setter.
        if self.has_v2_api:
            return self.dynlib.flex_set_threads_v2(num_threads)
        if self.has_v1_api:
            return self.dynlib.flex_set_threads_v1(num_threads)
        return None

    def get_version(self):
        """Return the library version string, or None when unavailable."""
        # Probe the symbol directly instead of reading has_version_api: the base
        # initializer may call get_version() before set_additional_attributes().
        version_func = getattr(self.dynlib, "flex_version", None)
        if version_func is None:
            return None
        version_func.restype = ctypes.c_char_p
        version_bytes = version_func()
        # A NULL char* comes back as None; don't call .decode() on it.
        return version_bytes.decode('utf-8') if version_bytes else None

    def set_additional_attributes(self):
        """Detect the available APIs and expose the selected version in info()."""
        self._detect_capabilities()
        if self.has_v2_api:
            self.api_version = "v2"
        elif self.has_v1_api:
            self.api_version = "v1"
        else:
            self.api_version = "unknown"
register(FlexibleController)

from threadpoolctl import ThreadpoolController, threadpool_info
# After registering custom controllers, they work like built-in ones.
controller = ThreadpoolController()

# Custom controllers appear in standard introspection.
info = threadpool_info()
custom_libs = [lib for lib in info if lib['user_api'] == 'custom']
print(f"Found {len(custom_libs)} custom thread pool libraries")

# Custom controllers work with selection and limiting.
custom_controller = controller.select(internal_api='mycustomlib')
# Presumably falsy when no loaded library matched the selection.
if custom_controller:
    with custom_controller.limit(limits=2):
        # Custom library limited to 2 threads.
        # `computation_using_custom_library` is a placeholder for your own workload.
        result = computation_using_custom_library()

from threadpoolctl import LibController, register
import ctypes
import logging

# Library code should log through `logging` rather than print to stdout.
_logger = logging.getLogger(__name__)


class RobustController(LibController):
    """Controller with comprehensive error handling.

    Failures while calling into the native library are logged as warnings
    and reported as ``None`` instead of propagating. An invalid thread count
    passed to ``set_num_threads`` is still rejected with ``ValueError``.
    """

    user_api = "custom"
    internal_api = "robustlib"
    filename_prefixes = ("librobust",)

    def get_num_threads(self):
        """Return the current thread count, or None if unavailable or invalid."""
        try:
            get_func = getattr(self.dynlib, "robust_get_threads", None)
            if get_func:
                result = get_func()
                # Only a positive integer is accepted as a thread count.
                if isinstance(result, int) and result > 0:
                    return result
        except (AttributeError, OSError, ValueError) as e:
            # Log the error but don't crash.
            _logger.warning(
                "Could not get thread count from %s: %s", self.internal_api, e
            )
        return None

    def set_num_threads(self, num_threads):
        """Set the thread limit in the library.

        Raises:
            ValueError: if *num_threads* is not an int >= 1.

        Returns:
            The native setter's result, or None if it is missing or failed.
        """
        if not isinstance(num_threads, int) or num_threads < 1:
            raise ValueError(f"Invalid thread count: {num_threads}")
        try:
            set_func = getattr(self.dynlib, "robust_set_threads", None)
            if set_func:
                return set_func(num_threads)
        except (AttributeError, OSError) as e:
            _logger.warning(
                "Could not set thread count in %s: %s", self.internal_api, e
            )
        return None

    def get_version(self):
        """Return the library version string, or None when unavailable."""
        try:
            version_func = getattr(self.dynlib, "robust_version", None)
            if version_func:
                version_func.restype = ctypes.c_char_p
                version_bytes = version_func()
                if version_bytes:
                    return version_bytes.decode('utf-8', errors='ignore')
        except Exception as e:
            # Deliberately broad: version lookup is best-effort metadata.
            _logger.warning(
                "Could not get version from %s: %s", self.internal_api, e
            )
        return None
register(RobustController)

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-threadpoolctl