"""
Module housing decorators that are used to validate the arguments and return values of
functions.
"""
import inspect
import warnings
from functools import wraps
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
from macrosynergy import PYTHON_3_8_OR_LATER
if PYTHON_3_8_OR_LATER:
from typing import get_args, get_origin
else:
from typing_extensions import get_args, get_origin
from inspect import signature
import pandas as pd
import numpy as np
from packaging import version
try:
from macrosynergy import __version__ as MSY_VERSION
except ImportError:
try:
from setup import VERSION as MSY_VERSION
except ImportError:
MSY_VERSION = "0.0.0"
[docs]def deprecate(
new_func: Callable,
deprecate_version: str,
remove_after: str = None,
message: str = None,
macrosynergy_package_version: str = MSY_VERSION,
):
"""
Decorator for deprecating a function.
Parameters
----------
new_func : callable
The function that replaces the old one.
deprecate_version : str
The version in which the old function is deprecated.
remove_after : str
The version in which the old function is removed.
message : str
The message to display when the old function is called. This message must
contain the following format strings: "{old_method}", "{deprecate_version}", and
"{new_method}". If None, the default message is used.
macrosynergy_package_version : str
The version of the macrosynergy package. This is used to determine if the
deprecation warning should be shown.
Returns
-------
callable
The decorated function.
"""
def decorator(
old_func,
new_func=new_func,
deprecate_version=deprecate_version,
remove_after=remove_after,
message=message,
macrosynergy_package_version=macrosynergy_package_version,
):
# if the message is none, use the default message
if message is None:
message = (
"{old_method} was deprecated in version {deprecate_version} and will be "
"removed in version. Use {new_method} instead."
)
# else if the message does not have "{old_method}" in it, fail
else:
if any(
[
fs not in message
for fs in ["{old_method}", "{deprecate_version}", "{new_method}"]
]
):
raise ValueError(
"The message must contain the following format strings: "
"'{old_method}', '{deprecate_version}', and '{new_method}'."
)
# if remove_after is not None, check the version
if remove_after is not None:
try:
version.parse(remove_after)
except version.InvalidVersion as e:
raise ValueError(
f"The version in which the function is deprecated ({remove_after}) "
f"must be a valid version string."
) from e
if version.parse(deprecate_version) > version.parse(remove_after):
raise ValueError(
f"The version in which the old function will be removed "
f"({remove_after}) "
f"must be greater than the version in which it is deprecated "
f"({deprecate_version})."
)
else:
remove_after = MSY_VERSION
@wraps(old_func)
# This will ensure the old function retains its name and other properties.
def wrapper(*args, **kwargs):
if version.parse(deprecate_version) < version.parse(
macrosynergy_package_version
):
warnings.warn(
message.format(
old_method=old_func.__name__,
new_method=new_func.__name__,
deprecate_version=deprecate_version,
),
FutureWarning,
)
return old_func(*args, **kwargs)
# Update the signature and docstring of the old function to match the new one.
wrapper.__signature__ = signature(new_func)
wrapper.__doc__ = new_func.__doc__
return wrapper
return decorator
[docs]def is_matching_subscripted_type(value: Any, type_hint: Type[Any]) -> bool:
"""
Implementation of `insinstance()` for type-hints imported from the `typing` module,
and for subscripted types (e.g. `List[int]`, `Tuple[str, int]`, etc.).
Parameters
----------
value : Any
The value to check.
type_hint : Type[Any]
The type hint to check against.
Returns
-------
bool
True if the value is of the type hint, False otherwise.
"""
origin = get_origin(type_hint)
args = get_args(type_hint)
# handling lists
if origin in [list, List]:
if not isinstance(value, list):
return False
return all(isinstance(item, args[0]) for item in value)
# tuples
if origin in [tuple, Tuple]:
if not isinstance(value, tuple) or len(value) != len(args):
return False
# don't switch order of get_origin and is_matching_subscripted_type, is
# short-circuiting
return all(
[
(get_origin(expected) and is_matching_subscripted_type(item, expected))
or isinstance(item, expected)
for item, expected in zip(value, args)
]
)
# dicts
if origin in [dict, Dict]:
if not isinstance(value, dict):
return False
key_type, value_type = args
return all(
[
(get_origin(key_type) and is_matching_subscripted_type(k, key_type))
or isinstance(k, key_type)
or (isinstance(k, key_type) and isinstance(v, value_type))
for k, v in value.items()
]
)
# unions and optionals
if origin is Union:
for possible_type in args:
if get_origin(possible_type): # is subscripted
if is_matching_subscripted_type(value, possible_type):
return True
elif isinstance(value, possible_type):
return True
return False
return False
[docs]def get_expected_type(arg_type_hint: Type[Any]) -> List[str]:
"""
Based on the type hint, return a list of strings that represent the type hint -
including any nested type hints.
Parameters
----------
arg_type_hint : Type[Any]
The type hint to get the expected types for.
Returns
-------
List[str]
A list of strings that represent the type hint.
"""
origin = get_origin(arg_type_hint)
args = get_args(arg_type_hint)
# handling lists
if origin in [list, List]:
return [f"List[{get_expected_type(args[0])[0]}]"]
# tuples
if origin in [tuple, Tuple]:
return [f"Tuple[{', '.join(get_expected_type(arg) for arg in args)}]"]
# dicts
if origin in [dict, Dict]:
return [f"Dict[{', '.join(get_expected_type(arg) for arg in args)}]"]
# unions and optionals
if origin in [Union, Optional]:
# get a flat list of all the expected types
expected_types: List[str] = []
for possible_type in args:
if get_origin(possible_type):
expected_types.extend(get_expected_type(possible_type))
else:
expected_types.append(str(possible_type))
return expected_types
return [str(arg_type_hint)]
[docs]def argvalidation(func: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorator for validating the arguments and return value of a function.
Parameters
----------
func : Callable[..., Any]
The function to validate.
Returns
-------
Callable[..., Any]
The decorated function.
"""
def format_expected_type(expected_types: List[Any]) -> str:
# format the expected types to read nicely, and to remove 'typing.' from the string
if isinstance(expected_types, tuple):
expected_types = list(expected_types)
for i, et in enumerate(expected_types):
if str(et).startswith("typing."):
expected_types[i] = str(et).replace("typing.", "")
if et is type(None):
expected_types[i] = "None"
ret_string = (
f"{', '.join([f'`{t}`' for t in expected_types[:-1]])}, "
f"or `{expected_types[-1]}`"
)
if len(expected_types) == 1:
return f"`{expected_types[0]}`"
elif len(expected_types) == 2:
return f"`{expected_types[0]}` or `{expected_types[1]}`"
else:
return ret_string
@wraps(func)
def validation_wrapper(*args: Any, **kwargs: Any) -> Any:
func_sig: inspect.Signature = inspect.signature(func)
func_params: Dict[str, inspect.Parameter] = func_sig.parameters
func_annotations: Dict[str, Any] = func_sig.return_annotation
func_args: Dict[str, Any] = inspect.getcallargs(func, *args, **kwargs)
# validate the arguments
for arg_name, arg_value in func_args.items():
if arg_name in func_params:
arg_type: Type[Any] = func_params[arg_name].annotation
if arg_type is not inspect._empty:
origin = get_origin(arg_type)
if origin: # Handling subscripted types
# replace 'float' with 'typng.Union[float, int]' to make life easier
if not is_matching_subscripted_type(arg_value, arg_type):
exp_types: str = format_expected_type(get_args(arg_type))
raise TypeError(
f"Argument `{arg_name}` must be of type {exp_types}, "
f"not `{type(arg_value).__name__}` (with value "
f"`{arg_value}`)."
)
else: # For simple, non-generic types
if not isinstance(arg_value, arg_type):
raise TypeError(
f"Argument `{arg_name}` must be of type `{arg_type}`, "
f"not `{type(arg_value).__name__}` (with value "
f"`{arg_value}`)."
)
# validate the return value
return_value: Any = func(*args, **kwargs)
if func_annotations is not inspect._empty:
origin = get_origin(func_annotations)
if (
origin
and (not is_matching_subscripted_type(return_value, func_annotations))
) or (not origin and not isinstance(return_value, func_annotations)):
exp_types: str = format_expected_type(get_args(func_annotations))
raise warnings.warn(
f"Return value of `{func.__name__}` is not of type "
f"`{func_annotations}`, but of type `{type(return_value)}`."
)
return return_value
return validation_wrapper
[docs]def argcopy(func: Callable) -> Callable:
"""
Decorator for applying a "pass-by-value" method to the arguments of a function.
Parameters
----------
func : Callable
The function to copy arguments for.
Returns
-------
Callable
The decorated function.
"""
@wraps(func)
def copy_wrapper(*args, **kwargs):
copy_types = (
list,
dict,
np.ndarray,
pd.Series,
set,
)
new_args: List[Tuple[Any, ...]] = []
for arg in args:
if isinstance(arg, copy_types) or issubclass(type(arg), copy_types):
new_args.append(arg.copy())
else:
new_args.append(arg)
new_kwargs: Dict[str, Any] = {}
for key, value in kwargs.items():
if isinstance(value, copy_types) or issubclass(type(value), copy_types):
new_kwargs[key] = value.copy()
else:
new_kwargs[key] = value
return func(*new_args, **new_kwargs)
return copy_wrapper