diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index c8e76a90..94576ee3 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,10 +1,13 @@ +import re from collections.abc import Mapping from functools import lru_cache -from typing import Any, NamedTuple, Sequence, Tuple, Union +from inspect import signature +from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union from warnings import warn from . import _array_module as xp from ._array_module import _UndefinedStub +from .stubs import name_to_func from .typing import DataType, ScalarType __all__ = [ @@ -242,67 +245,31 @@ def result_type(*dtypes: DataType): return result -func_in_dtypes = { - # elementwise - "abs": numeric_dtypes, - "acos": float_dtypes, - "acosh": float_dtypes, - "add": numeric_dtypes, - "asin": float_dtypes, - "asinh": float_dtypes, - "atan": float_dtypes, - "atan2": float_dtypes, - "atanh": float_dtypes, - "bitwise_and": bool_and_all_int_dtypes, - "bitwise_invert": bool_and_all_int_dtypes, - "bitwise_left_shift": all_int_dtypes, - "bitwise_or": bool_and_all_int_dtypes, - "bitwise_right_shift": all_int_dtypes, - "bitwise_xor": bool_and_all_int_dtypes, - "ceil": numeric_dtypes, - "cos": float_dtypes, - "cosh": float_dtypes, - "divide": float_dtypes, - "equal": all_dtypes, - "exp": float_dtypes, - "expm1": float_dtypes, - "floor": numeric_dtypes, - "floor_divide": numeric_dtypes, - "greater": numeric_dtypes, - "greater_equal": numeric_dtypes, - "isfinite": numeric_dtypes, - "isinf": numeric_dtypes, - "isnan": numeric_dtypes, - "less": numeric_dtypes, - "less_equal": numeric_dtypes, - "log": float_dtypes, - "logaddexp": float_dtypes, - "log10": float_dtypes, - "log1p": float_dtypes, - "log2": float_dtypes, - "logical_and": (xp.bool,), - "logical_not": (xp.bool,), - "logical_or": (xp.bool,), - "logical_xor": (xp.bool,), - "multiply": numeric_dtypes, - "negative": numeric_dtypes, - "not_equal": all_dtypes, - "positive": numeric_dtypes, - "pow": numeric_dtypes, - "remainder": numeric_dtypes, - "round": numeric_dtypes, - "sign": numeric_dtypes, - "sin": float_dtypes, - "sinh": float_dtypes, - "sqrt": float_dtypes, - "square": numeric_dtypes, - "subtract": numeric_dtypes, - "tan": float_dtypes, - "tanh": float_dtypes, - "trunc": numeric_dtypes, - # searching - "where": all_dtypes, +r_alias = re.compile("[aA]lias") +r_in_dtypes = re.compile("x1?: array\n.+have an? (.+) data type.") +r_int_note = re.compile( + "If one or both of the input arrays have integer data types, " + "the result is implementation-dependent" +) +category_to_dtypes = { + "boolean": (xp.bool,), + "integer": all_int_dtypes, + "floating-point": float_dtypes, + "numeric": numeric_dtypes, + "integer or boolean": bool_and_all_int_dtypes, } +func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {} +for name, func in name_to_func.items(): + if m := r_in_dtypes.search(func.__doc__): + dtype_category = m.group(1) + if dtype_category == "numeric" and r_int_note.search(func.__doc__): + dtype_category = "floating-point" + dtypes = category_to_dtypes[dtype_category] + func_in_dtypes[name] = dtypes + elif any("x" in name for name in signature(func).parameters.keys()): + func_in_dtypes[name] = all_dtypes +# See https://github.com/data-apis/array-api/pull/413 +func_in_dtypes["expm1"] = float_dtypes func_returns_bool = { @@ -365,6 +332,8 @@ def result_type(*dtypes: DataType): "trunc": False, # searching "where": False, + # linalg + "matmul": False, } @@ -408,7 +377,7 @@ def result_type(*dtypes: DataType): "__gt__": "greater", "__le__": "less_equal", "__lt__": "less", - # '__matmul__': 'matmul', # TODO: support matmul + "__matmul__": "matmul", "__mod__": "remainder", "__mul__": "multiply", "__ne__": "not_equal", @@ -440,6 +409,14 @@ def result_type(*dtypes: DataType): func_returns_bool[iop] = func_returns_bool[op] +func_in_dtypes["__bool__"] = (xp.bool,) +func_in_dtypes["__int__"] = all_int_dtypes +func_in_dtypes["__index__"] = all_int_dtypes +func_in_dtypes["__float__"] = float_dtypes +func_in_dtypes["from_dlpack"] = numeric_dtypes +func_in_dtypes["__dlpack__"] = numeric_dtypes + + @lru_cache def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: f_types = [] diff --git a/array_api_tests/meta/test_signatures.py b/array_api_tests/meta/test_signatures.py new file mode 100644 index 00000000..2efe1881 --- /dev/null +++ b/array_api_tests/meta/test_signatures.py @@ -0,0 +1,67 @@ +from inspect import Parameter, Signature, signature + +import pytest + +from ..test_signatures import _test_inspectable_func + + +def stub(foo, /, bar=None, *, baz=None): + pass + + +stub_sig = signature(stub) + + +@pytest.mark.parametrize( + "sig", + [ + Signature( + [ + Parameter("foo", Parameter.POSITIONAL_ONLY), + Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD), + Parameter("baz", Parameter.KEYWORD_ONLY), + ] + ), + Signature( + [ + Parameter("foo", Parameter.POSITIONAL_ONLY), + Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD), + Parameter("baz", Parameter.POSITIONAL_OR_KEYWORD), + ] + ), + Signature( + [ + Parameter("foo", Parameter.POSITIONAL_ONLY), + Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD), + Parameter("qux", Parameter.KEYWORD_ONLY), + Parameter("baz", Parameter.KEYWORD_ONLY), + ] + ), + ], +) +def test_good_sig_passes(sig): + _test_inspectable_func(sig, stub_sig) + + +@pytest.mark.parametrize( + "sig", + [ + Signature( + [ + Parameter("foo", Parameter.POSITIONAL_ONLY), + Parameter("bar", Parameter.POSITIONAL_ONLY), + Parameter("baz", Parameter.KEYWORD_ONLY), + ] + ), + Signature( + [ + Parameter("foo", Parameter.POSITIONAL_ONLY), + Parameter("bar", Parameter.KEYWORD_ONLY), + Parameter("baz", Parameter.KEYWORD_ONLY), + ] + ), + ], +) +def test_raises_on_bad_sig(sig): + with pytest.raises(AssertionError): + _test_inspectable_func(sig, stub_sig) diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py index 1ff1e1b6..35cc885f 100644 --- a/array_api_tests/stubs.py +++ b/array_api_tests/stubs.py @@ -40,18 +40,29 @@ if name.endswith("_functions"): category = name.replace("_functions", "") objects = [getattr(mod, name) for name in mod.__all__] - assert all(isinstance(o, FunctionType) for o in objects) + assert all(isinstance(o, FunctionType) for o in objects) # sanity check category_to_funcs[category] = objects +all_funcs = [] +for funcs in [array_methods, *category_to_funcs.values()]: + all_funcs.extend(funcs) +name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs} + EXTENSIONS: str = ["linalg"] extension_to_funcs: Dict[str, List[FunctionType]] = {} for ext in EXTENSIONS: mod = name_to_mod[ext] objects = [getattr(mod, name) for name in mod.__all__] - assert all(isinstance(o, FunctionType) for o in objects) - extension_to_funcs[ext] = objects + assert all(isinstance(o, FunctionType) for o in objects) # sanity check + funcs = [] + for func in objects: + if "Alias" in func.__doc__: + funcs.append(name_to_func[func.__name__]) + else: + funcs.append(func) + extension_to_funcs[ext] = funcs -all_funcs = [] -for funcs in [array_methods, *category_to_funcs.values(), *extension_to_funcs.values()]: - all_funcs.extend(funcs) -name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs} +for funcs in extension_to_funcs.values(): + for func in funcs: + if func.__name__ not in name_to_func.keys(): + name_to_func[func.__name__] = func diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index 2e197ee9..2db804b1 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -1,282 +1,254 @@ -import inspect -from itertools import chain +""" +Tests for function/method signatures compliance -import pytest +We're not interested in being 100% strict - instead we focus on areas which +could affect interop, e.g. with -from ._array_module import mod, mod_name, ones, eye, float64, bool, int64, _UndefinedStub -from .pytest_helpers import raises, doesnt_raise -from . import dtype_helpers as dh + def add(x1, x2, /): + ... -from . import stubs +x1 and x2 don't need to be pos-only for the purposes of interoperability, but with + def squeeze(x, /, axis): + ... -def extension_module(name) -> bool: - for funcs in stubs.extension_to_funcs.values(): - for func in funcs: - if name == func.__name__: - return True - else: - return False +axis has to be pos-or-keyword to support both styles + >>> squeeze(x, 0) + ... + >>> squeeze(x, axis=0) + ... -params = [] -for name in [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]: - if name in ["where", "expand_dims", "reshape"]: - params.append(pytest.param(name, marks=pytest.mark.skip(reason="faulty test"))) - else: - params.append(name) - - -for ext, name in [(ext, f.__name__) for ext, funcs in stubs.extension_to_funcs.items() for f in funcs]: - params.append(pytest.param(name, marks=pytest.mark.xp_extension(ext))) - - -def array_method(name) -> bool: - return name in [f.__name__ for f in stubs.array_methods] - -def function_category(name) -> str: - for category, funcs in chain(stubs.category_to_funcs.items(), stubs.extension_to_funcs.items()): - for func in funcs: - if name == func.__name__: - return category - -def example_argument(arg, func_name, dtype): - """ - Get an example argument for the argument arg for the function func_name - - The full tests for function behavior is in other files. We just need to - have an example input for each argument name that should work so that we - can check if the argument is implemented at all. - - """ - # Note: for keyword arguments that have a default, this should be - # different from the default, as the default argument is tested separately - # (it can have the same behavior as the default, just not literally the - # same value). - known_args = dict( - api_version='2021.1', - arrays=(ones((1, 3, 3), dtype=dtype), ones((1, 3, 3), dtype=dtype)), - # These cannot be the same as each other, which is why all our test - # arrays have to have at least 3 dimensions. - axis1=2, - axis2=2, - axis=1, - axes=(2, 1, 0), - copy=True, - correction=1.0, - descending=True, - # TODO: This will only work on the NumPy implementation. The exact - # value of the device keyword will vary across implementations, so we - # need some way to infer it or for libraries to specify a list of - # valid devices. - device='cpu', - dtype=float64, - endpoint=False, - fill_value=1.0, - from_=int64, - full_matrices=False, - k=1, - keepdims=True, - key=(0, 0), - indexing='ij', - mode='complete', - n=2, - n_cols=1, - n_rows=1, - num=2, - offset=1, - ord=1, - obj = [[[1, 1, 1], [1, 1, 1], [1, 1, 1]]], - other=ones((3, 3), dtype=dtype), - return_counts=True, - return_index=True, - return_inverse=True, - rtol=1e-10, - self=ones((3, 3), dtype=dtype), - shape=(1, 3, 3), - shift=1, - sorted=False, - stable=False, - start=0, - step=2, - stop=1, - # TODO: Update this to be non-default. See the comment on "device" above. - stream=None, - to=float64, - type=float64, - upper=True, - value=0, - x1=ones((1, 3, 3), dtype=dtype), - x2=ones((1, 3, 3), dtype=dtype), - x=ones((1, 3, 3), dtype=dtype), - ) - if not isinstance(bool, _UndefinedStub): - known_args['condition'] = ones((1, 3, 3), dtype=bool), - - if arg in known_args: - # Special cases: - - # squeeze() requires an axis of size 1, but other functions such as - # cross() require axes of size >1 - if func_name == 'squeeze' and arg == 'axis': - return 0 - # ones() is not invertible - # finfo requires a float dtype and iinfo requires an int dtype - elif func_name == 'iinfo' and arg == 'type': - return int64 - # tensordot args must be contractible with each other - elif func_name == 'tensordot' and arg == 'x2': - return ones((3, 3, 1), dtype=dtype) - # tensordot "axes" is either a number representing the number of - # contractible axes or a 2-tuple or axes - elif func_name == 'tensordot' and arg == 'axes': - return 1 - # The inputs to outer() must be 1-dimensional - elif func_name == 'outer' and arg in ['x1', 'x2']: - return ones((3,), dtype=dtype) - # Linear algebra functions tend to error if the input isn't "nice" as - # a matrix - elif arg.startswith('x') and func_name in [f.__name__ for f in stubs.extension_to_funcs["linalg"]]: - return eye(3) - return known_args[arg] - else: - raise RuntimeError(f"Don't know how to test argument {arg}. Please update test_signatures.py") - -@pytest.mark.parametrize('name', params) -def test_has_names(name): - if extension_module(name): - ext = next( - ext for ext, funcs in stubs.extension_to_funcs.items() - if name in [f.__name__ for f in funcs] +""" +from inspect import Parameter, Signature, signature +from types import FunctionType +from typing import Any, Callable, Dict, List, Literal, get_args + +import pytest +from hypothesis import given, note, settings +from hypothesis import strategies as st +from hypothesis.strategies import DataObject + +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import xps +from ._array_module import _UndefinedStub +from ._array_module import mod as xp +from .stubs import array_methods, category_to_funcs, extension_to_funcs +from .typing import Array, DataType + +pytestmark = pytest.mark.ci + +ParameterKind = Literal[ + Parameter.POSITIONAL_ONLY, + Parameter.VAR_POSITIONAL, + Parameter.POSITIONAL_OR_KEYWORD, + Parameter.KEYWORD_ONLY, + Parameter.VAR_KEYWORD, +] +ALL_KINDS = get_args(ParameterKind) +VAR_KINDS = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) +kind_to_str: Dict[ParameterKind, str] = { + Parameter.POSITIONAL_OR_KEYWORD: "pos or kw argument", + Parameter.POSITIONAL_ONLY: "pos-only argument", + Parameter.KEYWORD_ONLY: "keyword-only argument", + Parameter.VAR_POSITIONAL: "star-args (i.e. *args) argument", + Parameter.VAR_KEYWORD: "star-kwargs (i.e. **kwargs) argument", +} + + +def _test_inspectable_func(sig: Signature, stub_sig: Signature): + params = list(sig.parameters.values()) + stub_params = list(stub_sig.parameters.values()) + + non_kwonly_stub_params = [ + p for p in stub_params if p.kind != Parameter.KEYWORD_ONLY + ] + # sanity check + assert non_kwonly_stub_params == stub_params[: len(non_kwonly_stub_params)] + # We're not interested if the array module has additional arguments, so we + # only iterate through the arguments listed in the spec. + for i, stub_param in enumerate(non_kwonly_stub_params): + assert ( + len(params) >= i + 1 + ), f"Argument '{stub_param.name}' missing from signature" + param = params[i] + + # We're not interested in the name if it isn't actually used + if stub_param.kind not in [Parameter.POSITIONAL_ONLY, *VAR_KINDS]: + assert ( + param.name == stub_param.name + ), f"Expected argument '{param.name}' to be named '{stub_param.name}'" + + if stub_param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]: + f_stub_kind = kind_to_str[stub_param.kind] + assert param.kind == stub_param.kind, ( + f"{param.name} is a {kind_to_str[param.kind]}, " + f"but should be a {f_stub_kind}" + ) + + kwonly_stub_params = stub_params[len(non_kwonly_stub_params) :] + for stub_param in kwonly_stub_params: + assert ( + stub_param.name in sig.parameters.keys() + ), f"Argument '{stub_param.name}' missing from signature" + param = next(p for p in params if p.name == stub_param.name) + assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD,], ( + f"{param.name} is a {kind_to_str[param.kind]}, " + f"but should be a {f_stub_kind} " + f"(or at least a {kind_to_str[ParameterKind.POSITIONAL_OR_KEYWORD]})" ) - ext_mod = getattr(mod, ext) - assert hasattr(ext_mod, name), f"{mod_name} is missing the {function_category(name)} extension function {name}()" - elif array_method(name): - arr = ones((1, 1)) - if name not in [f.__name__ for f in stubs.array_methods]: - assert hasattr(arr, name), f"The array object is missing the attribute {name}" - else: - assert hasattr(arr, name), f"The array object is missing the method {name}()" - else: - assert hasattr(mod, name), f"{mod_name} is missing the {function_category(name)} function {name}()" - -@pytest.mark.parametrize('name', params) -def test_function_positional_args(name): - # Note: We can't actually test that positional arguments are - # positional-only, as that would require knowing the argument name and - # checking that it can't be used as a keyword argument. But argument name - # inspection does not work for most array library functions that are not - # written in pure Python (e.g., it won't work for numpy ufuncs). - - if extension_module(name): - return - - dtype = None - if (name.startswith('__i') and name not in ['__int__', '__invert__', '__index__'] - or name.startswith('__r') and name != '__rshift__'): - n = f'__{name[3:]}' - else: - n = name - in_dtypes = dh.func_in_dtypes.get(n, dh.float_dtypes) - if bool in in_dtypes: - dtype = bool - elif all(d in in_dtypes for d in dh.all_int_dtypes): - dtype = int64 - - if array_method(name): - if name == '__bool__': - _mod = ones((), dtype=bool) - elif name in ['__int__', '__index__']: - _mod = ones((), dtype=int64) - elif name == '__float__': - _mod = ones((), dtype=float64) - else: - _mod = example_argument('self', name, dtype) - elif '.' in name: - extension_module_name, name = name.split('.') - _mod = getattr(mod, extension_module_name) - else: - _mod = mod - stub_func = stubs.name_to_func[name] - - if not hasattr(_mod, name): - pytest.skip(f"{mod_name} does not have {name}(), skipping.") - if stub_func is None: - # TODO: Can we make this skip the parameterization entirely? - pytest.skip(f"{name} is not a function, skipping.") - mod_func = getattr(_mod, name) - argspec = inspect.getfullargspec(stub_func) - func_args = argspec.args - if func_args[:1] == ['self']: - func_args = func_args[1:] - nargs = [len(func_args)] - if argspec.defaults: - # The actual default values are checked in the specific tests - nargs.extend([len(func_args) - i for i in range(1, len(argspec.defaults) + 1)]) - - args = [example_argument(arg, name, dtype) for arg in func_args] - if not args: - args = [example_argument('x', name, dtype)] + + +def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]: + if func_name in dh.func_in_dtypes.keys(): + dtypes = dh.func_in_dtypes[func_name] + if hh.FILTER_UNDEFINED_DTYPES: + dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] + return st.sampled_from(dtypes) else: - # Duplicate the last positional argument for the n+1 test. - args = args + [args[-1]] - - kwonlydefaults = argspec.kwonlydefaults or {} - required_kwargs = {arg: example_argument(arg, name, dtype) for arg in argspec.kwonlyargs if arg not in kwonlydefaults} - - for n in range(nargs[0]+2): - if name == 'result_type' and n == 0: - # This case is not encoded in the signature, but isn't allowed. - continue - if n in nargs: - doesnt_raise(lambda: mod_func(*args[:n], **required_kwargs)) - elif argspec.varargs: - pass + return xps.scalar_dtypes() + + +def make_pretty_func(func_name: str, *args: Any, **kwargs: Any): + f_sig = f"{func_name}(" + f_sig += ", ".join(str(a) for a in args) + if len(kwargs) != 0: + if len(args) != 0: + f_sig += ", " + f_sig += ", ".join(f"{k}={v}" for k, v in kwargs.items()) + f_sig += ")" + return f_sig + + +matrixy_funcs: List[FunctionType] = [ + *category_to_funcs["linear_algebra"], + *extension_to_funcs["linalg"], +] +matrixy_names: List[str] = [f.__name__ for f in matrixy_funcs] +matrixy_names += ["__matmul__", "triu", "tril"] + + +@given(data=st.data()) +@settings(max_examples=1) +def _test_uninspectable_func( + func_name: str, func: Callable, stub_sig: Signature, array: Array, data: DataObject +): + skip_msg = ( + f"Signature for {func_name}() is not inspectable " + "and is too troublesome to test for otherwise" + ) + if func_name in [ + # 0d shapes + "__bool__", + "__int__", + "__index__", + "__float__", + # x2 elements must be >=0 + "pow", + "bitwise_left_shift", + "bitwise_right_shift", + # axis default invalid with 0d shapes + "sort", + # shape requirements + *matrixy_names, + ]: + pytest.skip(skip_msg) + + param_to_value: Dict[Parameter, Any] = {} + for param in stub_sig.parameters.values(): + if param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]: + pytest.skip( + skip_msg + f" (because '{param.name}' is a {kind_to_str[param.kind]})" + ) + elif param.default != Parameter.empty: + value = param.default + elif param.name in ["x", "x1"]: + dtypes = get_dtypes_strategy(func_name) + value = data.draw( + xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label=param.name + ) + elif param.name in ["x2", "other"]: + if param.name == "x2": + assert "x1" in [p.name for p in param_to_value.keys()] # sanity check + orig = next(v for p, v in param_to_value.items() if p.name == "x1") + else: + assert array is not None # sanity check + orig = array + value = data.draw( + xps.arrays(dtype=orig.dtype, shape=orig.shape), label=param.name + ) else: - # NumPy ufuncs raise ValueError instead of TypeError - raises((TypeError, ValueError), lambda: mod_func(*args[:n]), f"{name}() should not accept {n} positional arguments") - -@pytest.mark.parametrize('name', params) -def test_function_keyword_only_args(name): - if extension_module(name): - return - - if array_method(name): - _mod = ones((1, 1)) - elif '.' in name: - extension_module_name, name = name.split('.') - _mod = getattr(mod, extension_module_name) - else: - _mod = mod - stub_func = stubs.name_to_func[name] - - if not hasattr(_mod, name): - pytest.skip(f"{mod_name} does not have {name}(), skipping.") - if stub_func is None: - # TODO: Can we make this skip the parameterization entirely? - pytest.skip(f"{name} is not a function, skipping.") - mod_func = getattr(_mod, name) - argspec = inspect.getfullargspec(stub_func) - args = argspec.args - if args[:1] == ['self']: - args = args[1:] - kwonlyargs = argspec.kwonlyargs - kwonlydefaults = argspec.kwonlydefaults or {} - dtype = None - - args = [example_argument(arg, name, dtype) for arg in args] - - for arg in kwonlyargs: - value = example_argument(arg, name, dtype) - # The "only" part of keyword-only is tested by the positional test above. - doesnt_raise(lambda: mod_func(*args, **{arg: value}), - f"{name}() should accept the keyword-only argument {arg!r}") - - # Make sure the default is accepted. These tests are not granular - # enough to test that the default is actually the default, i.e., gives - # the same value if the keyword isn't passed. That is tested in the - # specific function tests. - if arg in kwonlydefaults: - default_value = kwonlydefaults[arg] - doesnt_raise(lambda: mod_func(*args, **{arg: default_value}), - f"{name}() should accept the default value {default_value!r} for the keyword-only argument {arg!r}") + pytest.skip( + skip_msg + f" (because no default was found for argument {param.name})" + ) + param_to_value[param] = value + + args: List[Any] = [ + v for p, v in param_to_value.items() if p.kind == Parameter.POSITIONAL_ONLY + ] + kwargs: Dict[str, Any] = { + p.name: v for p, v in param_to_value.items() if p.kind == Parameter.KEYWORD_ONLY + } + f_func = make_pretty_func(func_name, *args, **kwargs) + note(f"trying {f_func}") + func(*args, **kwargs) + + +def _test_func_signature(func: Callable, stub: FunctionType, array=None): + stub_sig = signature(stub) + # If testing against array, ignore 'self' arg in stub as it won't be present + # in func (which should be a method). + if array is not None: + stub_params = list(stub_sig.parameters.values()) + del stub_params[0] + stub_sig = Signature( + parameters=stub_params, return_annotation=stub_sig.return_annotation + ) + + try: + sig = signature(func) + _test_inspectable_func(sig, stub_sig) + except ValueError: + _test_uninspectable_func(stub.__name__, func, stub_sig, array) + + +@pytest.mark.parametrize( + "stub", + [s for stubs in category_to_funcs.values() for s in stubs], + ids=lambda f: f.__name__, +) +def test_func_signature(stub: FunctionType): + assert hasattr(xp, stub.__name__), f"{stub.__name__} not found in array module" + func = getattr(xp, stub.__name__) + _test_func_signature(func, stub) + + +extension_and_stub_params = [] +for ext, stubs in extension_to_funcs.items(): + for stub in stubs: + p = pytest.param( + ext, stub, id=f"{ext}.{stub.__name__}", marks=pytest.mark.xp_extension(ext) + ) + extension_and_stub_params.append(p) + + +@pytest.mark.parametrize("extension, stub", extension_and_stub_params) +def test_extension_func_signature(extension: str, stub: FunctionType): + mod = getattr(xp, extension) + assert hasattr( + mod, stub.__name__ + ), f"{stub.__name__} not found in {extension} extension" + func = getattr(mod, stub.__name__) + _test_func_signature(func, stub) + + +@pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__) +@given(st.data()) +@settings(max_examples=1) +def test_array_method_signature(stub: FunctionType, data: DataObject): + dtypes = get_dtypes_strategy(stub.__name__) + x = data.draw(xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label="x") + assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}" + method = getattr(x, stub.__name__) + _test_func_signature(method, stub, array=x)