From 8bd7bc076a5b84c9fb8fac363a93c4a04da79330 Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Mon, 17 Aug 2020 16:58:07 +0300 Subject: [PATCH 01/12] add automatic generation of type checks for overload_list --- numba_typing/overload_list.py | 236 ++++++++++++++++++++++++++++++++++ 1 file changed, 236 insertions(+) create mode 100644 numba_typing/overload_list.py diff --git a/numba_typing/overload_list.py b/numba_typing/overload_list.py new file mode 100644 index 000000000..e1fd2d91b --- /dev/null +++ b/numba_typing/overload_list.py @@ -0,0 +1,236 @@ +import numpy +import numba +from numba import types +from numba import typeof +from numba.extending import overload +from type_annotations import product_annotations, get_func_annotations +from numba import njit +import typing +from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning +import warnings +from numba.typed import List, Dict + +warnings.simplefilter('ignore', category=NumbaDeprecationWarning) +warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) + + +def overload_list(orig_func): + def overload_inner(ovld_list): + def wrapper(a, b=0): + func_list = ovld_list() + sig_list = [] + for func in func_list: + sig_list.append((product_annotations( + get_func_annotations(func)), func)) + result = choose_func_by_sig(sig_list, a=a, b=b) + + if result is None: + raise numba.TypingError(f'Unsupported types a={a}, b={b}') + + result.__annotations__.clear() + return result + + return overload(orig_func)(wrapper) + + return overload_inner + + +T = typing.TypeVar('T') +K = typing.TypeVar('K') + + +class TypeChecker: + def __init__(self): + self._types_dict = {int: check_int_type, float: check_float_type, bool: check_bool_type, + str: check_str_type, list: check_list_type, + tuple: check_tuple_type, dict: check_dict_type, T: check_T_type} + self._typevars_dict = {} + + def add_type_check(self, type_check, func): + self._types_dict[type_check] = func + + def _is_generic(self, p_obj): + if isinstance(p_obj, typing._GenericAlias): + return True + + if isinstance(p_obj, typing._SpecialForm): + return p_obj not in {typing.Any} + + return False + + def _get_origin(self, p_obj): + return p_obj.__origin__ + + def match(self, p_type, n_type): + if p_type == typing.Any: + return True + elif self._is_generic(p_type): + origin_type = self._get_origin(p_type) + if origin_type == typing.Generic: + return self.match_generic(p_type, n_type) + else: + check = self._types_dict.get(origin_type) + elif isinstance(p_type, typing.TypeVar): + return self.match_typevar(p_type, n_type) + else: + check = self._types_dict.get(p_type) + if not check: + raise ValueError(f'A check for the {p_type} was not found') + # fix + return check(n_type) if check.__code__.co_argcount < 2 else check(self, p_type, n_type) + + def match_typevar(self, p_type, n_type): + if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): + self._typevars_dict[p_type] = n_type + return True + return self._typevars_dict.get(p_type) == n_type + + def match_generic(self, p_type, n_type): + res = False + for arg in p_type.__args__: + res = res or self.match(arg, n_type) + return res + + +def check_int_type(n_type): + return isinstance(n_type, types.Integer) + + +def check_T_type(n_type): + return True + + +def check_float_type(n_type): + return isinstance(n_type, types.Float) + + +def check_bool_type(n_type): + return isinstance(n_type, types.Boolean) + + +def check_str_type(n_type): + return isinstance(n_type, types.UnicodeType) + + +def check_list_type(self, p_type, n_type): + res = isinstance(n_type, types.List) or isinstance(n_type, types.ListType) + if isinstance(p_type, type): + return res + else: + return res and self.match(p_type.__args__[0], n_type.dtype) # fix + + +def check_tuple_type(self, p_type, n_type): + res = False + if isinstance(n_type, types.Tuple): + res = True + if isinstance(p_type, type): + return res + for p_val, n_val in zip(p_type.__args__, n_type.key): + res = res and self.match(p_val, n_val) + if isinstance(n_type, types.UniTuple): + res = True + if isinstance(p_type, type): + return res + for p_val in p_type.__args__: + res = res and self.match(p_val, n_type.key[0]) + return res + + +def check_dict_type(self, p_type, n_type): + res = False + if isinstance(n_type, types.DictType): + res = True + if isinstance(p_type, type): + return res + for p_val, n_val in zip(p_type.__args__, n_type.keyvalue_type): + res = res and self.match(p_val, n_val) + return res + + +def choose_func_by_sig(sig_list, **kwargs): + checker = TypeChecker() + for sig in sig_list: # sig = (Signature,func) + for param in sig[0].parameters: # param = {'a':int,'b':int} + full_match = True + for name, typ in kwargs.items(): # name,type = 'a',int64 + if isinstance(typ, types.Literal): + + full_match = full_match and checker.match( + param[name], typ.literal_type) + + if sig[0].defaults.get(name, False): + full_match = full_match and sig[0].defaults[name] == typ.literal_value + else: + # full_match = True + full_match = full_match and checker.match(param[name], typ) + + if not full_match: + break + if full_match: + return sig[1] + + return None + + +def foo(a, b=0): + ... + + +@overload_list(foo) +def foo_ovld_list(): + + def foo_int(a: int, b: int = 0): + return a+b + + def foo_float(a: float, b: float = 0): + return a*b + + def foo_bool(a: bool, b: int = 0): + return ('bool', a) + + def foo_str(a: str, b: int = 0): + return ('str', a) + + def foo_list(a: typing.List[int], b: int = 0): + return ('list', a) + + def foo_tuple(a: typing.Tuple[float], b: int = 0): + return ('tuple', a) + + def foo_dict(a: typing.Dict[str, int], b: int = 0): + return ('dict', a) + + def foo_any(a: typing.Any, b: int = 0): + return('any', a) + + def foo_union(a: typing.Union[int, str], b: int = 0): + return('union', a) + + def foo_optional(a: typing.Optional[int], b: int = 0): + return('optional', a) + + def foo_typevar(a: typing.List[typing.List[int]], b: int = 0): + return('typevar', a) + + def foo_typevars(a: T, b: K = 0): + return('TypeVars', a, b) + + def foo_generic(a: typing.Generic[T], b: typing.Generic[T] = 0): + return('Generic', a, b) + + # return foo_int,foo_float, foo_bool, foo_str, foo_list, foo_tuple, foo_dict, foo_any + return foo_list, foo_typevar + + +if __name__ == '__main__': + + @njit + def myfunc(a, b=0): + return foo(a, b) + + V = List() + V.append(List([1, 2])) + F = 7 + + print(myfunc(V, F)) \ No newline at end of file From d3f4a5d3043552ad9392009f8a07a0b5bf231d48 Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Mon, 17 Aug 2020 18:33:20 +0300 Subject: [PATCH 02/12] add exception handling to the match function and correct handling of the generic type --- numba_typing/overload_list.py | 49 +++++++++++++++++------------------ 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/numba_typing/overload_list.py b/numba_typing/overload_list.py index e1fd2d91b..c576df5bf 100644 --- a/numba_typing/overload_list.py +++ b/numba_typing/overload_list.py @@ -62,22 +62,21 @@ def _get_origin(self, p_obj): return p_obj.__origin__ def match(self, p_type, n_type): - if p_type == typing.Any: - return True - elif self._is_generic(p_type): - origin_type = self._get_origin(p_type) - if origin_type == typing.Generic: - return self.match_generic(p_type, n_type) + try: + if p_type == typing.Any: + return True + elif self._is_generic(p_type): + origin_type = self._get_origin(p_type) + if origin_type == typing.Generic: + return self.match_generic(p_type, n_type) + else: + return self._types_dict[origin_type](self, p_type, n_type) + elif isinstance(p_type, typing.TypeVar): + return self.match_typevar(p_type, n_type) else: - check = self._types_dict.get(origin_type) - elif isinstance(p_type, typing.TypeVar): - return self.match_typevar(p_type, n_type) - else: - check = self._types_dict.get(p_type) - if not check: - raise ValueError(f'A check for the {p_type} was not found') - # fix - return check(n_type) if check.__code__.co_argcount < 2 else check(self, p_type, n_type) + return self._types_dict[p_type](n_type) + except KeyError: + print((f'A check for the {p_type} was not found')) def match_typevar(self, p_type, n_type): if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): @@ -86,9 +85,9 @@ def match_typevar(self, p_type, n_type): return self._typevars_dict.get(p_type) == n_type def match_generic(self, p_type, n_type): - res = False + res = True for arg in p_type.__args__: - res = res or self.match(arg, n_type) + res = res and self.match(arg, n_type) return res @@ -117,7 +116,7 @@ def check_list_type(self, p_type, n_type): if isinstance(p_type, type): return res else: - return res and self.match(p_type.__args__[0], n_type.dtype) # fix + return res and self.match(p_type.__args__[0], n_type.dtype) def check_tuple_type(self, p_type, n_type): @@ -162,7 +161,6 @@ def choose_func_by_sig(sig_list, **kwargs): if sig[0].defaults.get(name, False): full_match = full_match and sig[0].defaults[name] == typ.literal_value else: - # full_match = True full_match = full_match and checker.match(param[name], typ) if not full_match: @@ -210,17 +208,17 @@ def foo_union(a: typing.Union[int, str], b: int = 0): def foo_optional(a: typing.Optional[int], b: int = 0): return('optional', a) - def foo_typevar(a: typing.List[typing.List[int]], b: int = 0): + def foo_list_in_list(a: typing.List[typing.List[int]], b: int = 0): return('typevar', a) def foo_typevars(a: T, b: K = 0): return('TypeVars', a, b) - def foo_generic(a: typing.Generic[T], b: typing.Generic[T] = 0): + def foo_generic(a: typing.Generic[T], b: typing.Generic[K, T] = 0): return('Generic', a, b) # return foo_int,foo_float, foo_bool, foo_str, foo_list, foo_tuple, foo_dict, foo_any - return foo_list, foo_typevar + return foo_list, foo_generic if __name__ == '__main__': @@ -229,8 +227,9 @@ def foo_generic(a: typing.Generic[T], b: typing.Generic[T] = 0): def myfunc(a, b=0): return foo(a, b) - V = List() - V.append(List([1, 2])) + # V = List() + # V.append(List([1, 2])) + V = 5.0 F = 7 - print(myfunc(V, F)) \ No newline at end of file + print(myfunc(V, F)) From b54da175c968f315c5ae2ca8a2e066ca00acde5c Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Fri, 4 Sep 2020 12:23:49 +0300 Subject: [PATCH 03/12] add unittests --- numba_typing/overload_list.py | 93 +++------------- numba_typing/test_overload_list.py | 170 +++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+), 80 deletions(-) create mode 100644 numba_typing/test_overload_list.py diff --git a/numba_typing/overload_list.py b/numba_typing/overload_list.py index c576df5bf..6b0a86f91 100644 --- a/numba_typing/overload_list.py +++ b/numba_typing/overload_list.py @@ -9,6 +9,8 @@ from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning import warnings from numba.typed import List, Dict +from inspect import getfullargspec + warnings.simplefilter('ignore', category=NumbaDeprecationWarning) warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) @@ -16,34 +18,31 @@ def overload_list(orig_func): def overload_inner(ovld_list): - def wrapper(a, b=0): + def wrapper(*args): func_list = ovld_list() sig_list = [] for func in func_list: sig_list.append((product_annotations( get_func_annotations(func)), func)) - result = choose_func_by_sig(sig_list, a=a, b=b) + param = getfullargspec(orig_func).args + kwargs = {name: typ for name, typ in zip(param, args)} + result = choose_func_by_sig(sig_list, **kwargs) if result is None: raise numba.TypingError(f'Unsupported types a={a}, b={b}') - result.__annotations__.clear() return result - return overload(orig_func)(wrapper) + return overload(orig_func, strict=False)(wrapper) return overload_inner -T = typing.TypeVar('T') -K = typing.TypeVar('K') - - class TypeChecker: def __init__(self): self._types_dict = {int: check_int_type, float: check_float_type, bool: check_bool_type, str: check_str_type, list: check_list_type, - tuple: check_tuple_type, dict: check_dict_type, T: check_T_type} + tuple: check_tuple_type, dict: check_dict_type} self._typevars_dict = {} def add_type_check(self, type_check, func): @@ -74,6 +73,8 @@ def match(self, p_type, n_type): elif isinstance(p_type, typing.TypeVar): return self.match_typevar(p_type, n_type) else: + if p_type in (list, tuple): + return self._types_dict[p_type](self, p_type, n_type) return self._types_dict[p_type](n_type) except KeyError: print((f'A check for the {p_type} was not found')) @@ -95,10 +96,6 @@ def check_int_type(n_type): return isinstance(n_type, types.Integer) -def check_T_type(n_type): - return True - - def check_float_type(n_type): return isinstance(n_type, types.Float) @@ -148,8 +145,8 @@ def check_dict_type(self, p_type, n_type): def choose_func_by_sig(sig_list, **kwargs): - checker = TypeChecker() for sig in sig_list: # sig = (Signature,func) + checker = TypeChecker() for param in sig[0].parameters: # param = {'a':int,'b':int} full_match = True for name, typ in kwargs.items(): # name,type = 'a',int64 @@ -165,71 +162,7 @@ def choose_func_by_sig(sig_list, **kwargs): if not full_match: break - if full_match: - return sig[1] + if full_match: + return sig[1] return None - - -def foo(a, b=0): - ... - - -@overload_list(foo) -def foo_ovld_list(): - - def foo_int(a: int, b: int = 0): - return a+b - - def foo_float(a: float, b: float = 0): - return a*b - - def foo_bool(a: bool, b: int = 0): - return ('bool', a) - - def foo_str(a: str, b: int = 0): - return ('str', a) - - def foo_list(a: typing.List[int], b: int = 0): - return ('list', a) - - def foo_tuple(a: typing.Tuple[float], b: int = 0): - return ('tuple', a) - - def foo_dict(a: typing.Dict[str, int], b: int = 0): - return ('dict', a) - - def foo_any(a: typing.Any, b: int = 0): - return('any', a) - - def foo_union(a: typing.Union[int, str], b: int = 0): - return('union', a) - - def foo_optional(a: typing.Optional[int], b: int = 0): - return('optional', a) - - def foo_list_in_list(a: typing.List[typing.List[int]], b: int = 0): - return('typevar', a) - - def foo_typevars(a: T, b: K = 0): - return('TypeVars', a, b) - - def foo_generic(a: typing.Generic[T], b: typing.Generic[K, T] = 0): - return('Generic', a, b) - - # return foo_int,foo_float, foo_bool, foo_str, foo_list, foo_tuple, foo_dict, foo_any - return foo_list, foo_generic - - -if __name__ == '__main__': - - @njit - def myfunc(a, b=0): - return foo(a, b) - - # V = List() - # V.append(List([1, 2])) - V = 5.0 - F = 7 - - print(myfunc(V, F)) diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py new file mode 100644 index 000000000..faa623141 --- /dev/null +++ b/numba_typing/test_overload_list.py @@ -0,0 +1,170 @@ +import overload_list +from overload_list import List, Dict +from overload_list import types +import unittest +import typing +from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning +import warnings + + +def func(a, b): + ... + + +T = typing.TypeVar('T') +K = typing.TypeVar('K') + + +@overload_list.overload_list(func) +def foo_ovld_list(): + + def foo_int(a: int, b: int = 0): + return ('int', a, b) + + def foo_float(a: float, b: float = 0.0): + return ('float', a, b) + + def foo_bool(a: bool, b: bool = False): + return ('bool', a, b) + + def foo_str(a: str, b: str = '0'): + return ('str', a, b) + + def foo_list(a: typing.List[int], b: list = [0, 0, 0]): + return ('list', a, b) + + def foo_tuple(a: typing.Tuple[int, float], b: tuple = (0, 0)): + return ('tuple', a, b) + + def foo_dict(a: typing.Dict[str, int], b: typing.Dict[int, bool] = {0: False}): + return ('dict', a, b) + + # def foo_any(a: typing.Any, b: typing.Any = None): + # return('any', a, b) + + def foo_union(a: typing.Union[int, str], b: typing.Union[float, bool] = None): + return('union', a, b) + + def foo_optional(a: typing.Optional[float], b: typing.Optional[str] = None): + return('optional', a, b) + + def foo_list_in_list(a: typing.List[typing.List[int]], b: typing.List[typing.List[typing.List[float]]] = [[[0.0, 0.0]]]): + return('list_in_list', a, b) + + def foo_tuple_in_tuple(a: typing.Tuple[typing.Tuple[int]], b: typing.Tuple[typing.Tuple[typing.Tuple[float]]] = ((0.0, 0.0))): + return('tuple_in_tuple', a, b) + + def foo_typevars_T_T(a: T, b: T): + return('TypeVars_TT', a, b) + + def foo_typevars_T_K(a: T, b: K): + return('TypeVars_TK', a, b) + + def foo_typevars_list_T(a: typing.List[T], b: T): + return('TypeVars_ListT', a, b) + + def foo_typevars_list_dict(a: typing.List[T], b: typing.Dict[K, T]): + return('TypeVars_List_Dict', a, b) + + def foo_typevars_list_dict_list(a: typing.List[T], b: typing.Dict[K, typing.List[T]]): + return('TypeVars_List_Dict_List', a, b) + + return foo_int, foo_float, foo_bool, foo_str, foo_list, foo_tuple, foo_dict, foo_union, foo_optional, foo_list_in_list, foo_tuple_in_tuple, foo_typevars_T_T, foo_typevars_list_T, foo_typevars_list_dict, foo_typevars_list_dict_list, foo_typevars_T_K + + +@overload_list.njit +def jit_func(a, b): + return func(a, b) + + +class TestOverloadList(unittest.TestCase): + maxDiff = None + + def test_myfunc_int_type(self): + self.assertEqual(jit_func(1, 2), ('int', 1, 2)) + + def test_myfunc_float_type(self): + self.assertEqual(jit_func(1.0, 2.0), ('float', 1.0, 2.0)) + + def test_myfunc_bool_type(self): + self.assertEqual(jit_func(True, True), ('bool', True, True)) + + def test_myfunc_str_type(self): + self.assertEqual(jit_func('qwe', 'qaz'), ('str', 'qwe', 'qaz')) + + def test_myfunc_list_type(self): + self.assertEqual(jit_func([1, 2], [3, 4]), ('list', [1, 2], [3, 4])) + + def test_myfunc_List_typed(self): + + warnings.simplefilter('ignore', category=NumbaDeprecationWarning) + warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) + L = List([1, 2, 3]) + self.assertEqual(jit_func(L, [3, 4]), ('list', L, [3, 4])) + + def test_myfunc_tuple_type(self): + self.assertEqual(jit_func((1, 2.0), ('3', False)), ('tuple', (1, 2.0), ('3', False))) + + def test_myfunc_dict_type(self): + D = Dict.empty(key_type=types.unicode_type, value_type=types.int64) + D_1 = Dict.empty(key_type=types.int64, value_type=types.boolean) + D['qwe'] = 1 + D['qaz'] = 2 + D_1[1] = True + D_1[0] = False + self.assertEqual(jit_func(D, D_1), ('dict', D, D_1)) + + # def test_myfunc_any_typing(self): + # self.assertEqual(jit_func((1,2.0),['qaz','qwe']), ('any',(1,2.0),['qaz','qwe'])) + + def test_myfunc_union_typing_int_bool(self): + self.assertEqual(jit_func(1, False), ('union', 1, False)) + + def test_myfunc_union_typing_str_bool(self): + self.assertEqual(jit_func('qwe', False), ('union', 'qwe', False)) + + def test_myfunc_union_typing_int_float(self): + self.assertEqual(jit_func(1, 2.0), ('union', 1, 2.0)) + + def test_myfunc_union_typing_str_float(self): + self.assertEqual(jit_func('qwe', 2.0), ('union', 'qwe', 2.0)) + + def test_myfunc_optional_typing(self): + self.assertEqual(jit_func(1.0, 'qwe'), ('optional', 1.0, 'qwe')) + + def test_myfunc_list_in_list_type(self): + L_int = List([List([1, 2])]) + L_float = List([List([List([3.0, 4.0])])]) + + self.assertEqual(jit_func(L_int, L_float), ('list_in_list', L_int, L_float)) + + def test_myfunc_tuple_in_tuple(self): + self.assertEqual(jit_func(((1, 2),), (((3.0, 4.0),),)), ('tuple_in_tuple', ((1, 2),), (((3.0, 4.0),),))) + + def test_myfunc_typevar_T_T(self): + self.assertEqual(jit_func(((1, 2),), ((3, 4),)), ('TypeVars_TT', ((1, 2),), ((3, 4),))) + + def test_myfunc_typevar_T_K(self): + self.assertEqual(jit_func(1.0, 2), ('TypeVars_TK', 1.0, 2)) + + def test_myfunc_typevar_List_T(self): + L_int = List([1, 2]) + self.assertEqual(jit_func(L_int, 2), ('TypeVars_ListT', L_int, 2)) + + def test_myfunc_typevar_List_Dict(self): + D = Dict.empty(key_type=types.unicode_type, value_type=types.int64) + D['qwe'] = 0 + L_int = List([1, 2]) + self.assertEqual(jit_func(L_int, D), ('TypeVars_List_Dict', L_int, D)) + + def test_myfunc_typevar_List_Dict_List(self): + list_type = types.ListType(types.int64) + D = Dict.empty(key_type=types.unicode_type, value_type=list_type) + D['qwe'] = List([3, 4, 5]) + L_int = List([1, 2]) + + self.assertEqual(jit_func(L_int, D), ('TypeVars_List_Dict_List', L_int, D)) + + +if __name__ == "__main__": + unittest.main() From 05c745ba75619286d6efa6f640a8d4ae1d93ec9d Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Wed, 9 Sep 2020 01:19:50 +0300 Subject: [PATCH 04/12] add default value processing and tests --- numba_typing/overload_list.py | 135 ++++++++++++++++------------- numba_typing/test_overload_list.py | 132 ++++++++++++++++++++++++++-- 2 files changed, 200 insertions(+), 67 deletions(-) diff --git a/numba_typing/overload_list.py b/numba_typing/overload_list.py index 6b0a86f91..7a7d0ceaa 100644 --- a/numba_typing/overload_list.py +++ b/numba_typing/overload_list.py @@ -24,9 +24,13 @@ def wrapper(*args): for func in func_list: sig_list.append((product_annotations( get_func_annotations(func)), func)) - param = getfullargspec(orig_func).args - kwargs = {name: typ for name, typ in zip(param, args)} - result = choose_func_by_sig(sig_list, **kwargs) + args_orig_func = getfullargspec(orig_func) + values_dict = {name: typ for name, typ in zip(args_orig_func.args, args)} + defaults_dict = {} + if args_orig_func.defaults: + defaults_dict = {name: value for name, value in zip( + args_orig_func.args[::-1], args_orig_func.defaults[::-1])} + result = choose_func_by_sig(sig_list, values_dict, defaults_dict) if result is None: raise numba.TypingError(f'Unsupported types a={a}, b={b}') @@ -38,60 +42,6 @@ def wrapper(*args): return overload_inner -class TypeChecker: - def __init__(self): - self._types_dict = {int: check_int_type, float: check_float_type, bool: check_bool_type, - str: check_str_type, list: check_list_type, - tuple: check_tuple_type, dict: check_dict_type} - self._typevars_dict = {} - - def add_type_check(self, type_check, func): - self._types_dict[type_check] = func - - def _is_generic(self, p_obj): - if isinstance(p_obj, typing._GenericAlias): - return True - - if isinstance(p_obj, typing._SpecialForm): - return p_obj not in {typing.Any} - - return False - - def _get_origin(self, p_obj): - return p_obj.__origin__ - - def match(self, p_type, n_type): - try: - if p_type == typing.Any: - return True - elif self._is_generic(p_type): - origin_type = self._get_origin(p_type) - if origin_type == typing.Generic: - return self.match_generic(p_type, n_type) - else: - return self._types_dict[origin_type](self, p_type, n_type) - elif isinstance(p_type, typing.TypeVar): - return self.match_typevar(p_type, n_type) - else: - if p_type in (list, tuple): - return self._types_dict[p_type](self, p_type, n_type) - return self._types_dict[p_type](n_type) - except KeyError: - print((f'A check for the {p_type} was not found')) - - def match_typevar(self, p_type, n_type): - if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): - self._typevars_dict[p_type] = n_type - return True - return self._typevars_dict.get(p_type) == n_type - - def match_generic(self, p_type, n_type): - res = True - for arg in p_type.__args__: - res = res and self.match(arg, n_type) - return res - - def check_int_type(n_type): return isinstance(n_type, types.Integer) @@ -144,12 +94,71 @@ def check_dict_type(self, p_type, n_type): return res -def choose_func_by_sig(sig_list, **kwargs): +class TypeChecker: + + _types_dict = {int: check_int_type, float: check_float_type, bool: check_bool_type, + str: check_str_type, list: check_list_type, + tuple: check_tuple_type, dict: check_dict_type} + + def __init__(self): + self._typevars_dict = {} + + def clear_typevars_dict(self): + self._typevars_dict.clear() + + def add_type_check(self, type_check, func): + self._types_dict[type_check] = func + + def _is_generic(self, p_obj): + if isinstance(p_obj, typing._GenericAlias): + return True + + if isinstance(p_obj, typing._SpecialForm): + return p_obj not in {typing.Any} + + return False + + def _get_origin(self, p_obj): + return p_obj.__origin__ + + def match(self, p_type, n_type): + try: + if p_type == typing.Any: + return True + elif self._is_generic(p_type): + origin_type = self._get_origin(p_type) + if origin_type == typing.Generic: + return self.match_generic(p_type, n_type) + else: + return self._types_dict[origin_type](self, p_type, n_type) + elif isinstance(p_type, typing.TypeVar): + return self.match_typevar(p_type, n_type) + else: + if p_type in (list, tuple): + return self._types_dict[p_type](self, p_type, n_type) + return self._types_dict[p_type](n_type) + except KeyError: + print((f'A check for the {p_type} was not found. {n_type}')) + + def match_typevar(self, p_type, n_type): + if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): + self._typevars_dict[p_type] = n_type + return True + return self._typevars_dict.get(p_type) == n_type + + def match_generic(self, p_type, n_type): + res = True + for arg in p_type.__args__: + res = res and self.match(arg, n_type) + return res + + +def choose_func_by_sig(sig_list, values_dict, defaults_dict={}): + checker = TypeChecker() for sig in sig_list: # sig = (Signature,func) - checker = TypeChecker() for param in sig[0].parameters: # param = {'a':int,'b':int} full_match = True - for name, typ in kwargs.items(): # name,type = 'a',int64 + for name, typ in values_dict.items(): # name,type = 'a',int64 if isinstance(typ, types.Literal): full_match = full_match and checker.match( @@ -162,6 +171,12 @@ def choose_func_by_sig(sig_list, **kwargs): if not full_match: break + + for name, val in defaults_dict.items(): + if sig[0].defaults.get(name) != None: + full_match = full_match and sig[0].defaults[name] == val + + checker.clear_typevars_dict() if full_match: return sig[1] diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py index faa623141..286547975 100644 --- a/numba_typing/test_overload_list.py +++ b/numba_typing/test_overload_list.py @@ -45,13 +45,15 @@ def foo_dict(a: typing.Dict[str, int], b: typing.Dict[int, bool] = {0: False}): def foo_union(a: typing.Union[int, str], b: typing.Union[float, bool] = None): return('union', a, b) - def foo_optional(a: typing.Optional[float], b: typing.Optional[str] = None): - return('optional', a, b) + # def foo_optional(a: typing.Optional[float], b: typing.Optional[str] = None): + # return('optional', a, b) - def foo_list_in_list(a: typing.List[typing.List[int]], b: typing.List[typing.List[typing.List[float]]] = [[[0.0, 0.0]]]): + def foo_list_in_list(a: typing.List[typing.List[int]], + b: typing.List[typing.List[typing.List[float]]] = [[[0.0, 0.0]]]): return('list_in_list', a, b) - def foo_tuple_in_tuple(a: typing.Tuple[typing.Tuple[int]], b: typing.Tuple[typing.Tuple[typing.Tuple[float]]] = ((0.0, 0.0))): + def foo_tuple_in_tuple(a: typing.Tuple[typing.Tuple[int]], + b: typing.Tuple[typing.Tuple[typing.Tuple[float]]] = ((0.0, 0.0))): return('tuple_in_tuple', a, b) def foo_typevars_T_T(a: T, b: T): @@ -69,7 +71,9 @@ def foo_typevars_list_dict(a: typing.List[T], b: typing.Dict[K, T]): def foo_typevars_list_dict_list(a: typing.List[T], b: typing.Dict[K, typing.List[T]]): return('TypeVars_List_Dict_List', a, b) - return foo_int, foo_float, foo_bool, foo_str, foo_list, foo_tuple, foo_dict, foo_union, foo_optional, foo_list_in_list, foo_tuple_in_tuple, foo_typevars_T_T, foo_typevars_list_T, foo_typevars_list_dict, foo_typevars_list_dict_list, foo_typevars_T_K + return foo_int, foo_float, foo_bool, foo_str, foo_list, foo_tuple, foo_dict, foo_union,\ + foo_list_in_list, foo_tuple_in_tuple, foo_typevars_T_T, foo_typevars_list_T,\ + foo_typevars_list_dict, foo_typevars_list_dict_list, foo_typevars_T_K @overload_list.njit @@ -77,6 +81,120 @@ def jit_func(a, b): return func(a, b) +class TestOverloadListDefault(unittest.TestCase): + maxDiff = None + + def test_myfunc_int_type_default(self): + def foo(a, b=0): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo_int(a: int, b: int = 0): + return ('int', a, b) + + return (foo_int,) + + @overload_list.njit + def jit_func(a): + return foo(a) + + self.assertEqual(jit_func(1), ('int', 1, 0)) + + def test_myfunc_float_type_default(self): + def foo(a, b=0.0): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo_float(a: float, b: float = 0.0): + return ('float', a, b) + + return (foo_float,) + + @overload_list.njit + def jit_func(a): + return foo(a) + + self.assertEqual(jit_func(1.0), ('float', 1.0, 0.0)) + + def test_myfunc_bool_type_default(self): + def foo(a, b=False): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo_bool(a: bool, b: bool = False): + return ('bool', a, b) + + return (foo_bool,) + + @overload_list.njit + def jit_func(a): + return foo(a) + + self.assertEqual(jit_func(True), ('bool', True, False)) + + def test_myfunc_str_type_default(self): + def foo(a, b='0'): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo_str(a: str, b: str = '0'): + return ('str', a, b) + + return (foo_str,) + + @overload_list.njit + def jit_func(a): + return foo(a) + + self.assertEqual(jit_func('qwe'), ('str', 'qwe', '0')) + + # def test_myfunc_list_type_default(self): + # warnings.simplefilter('ignore', category=NumbaDeprecationWarning) + # warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) + # L = List([0, 0]) + # def foo(a,b=L): + # ... + # @overload_list.overload_list(foo) + # def foo_ovld_list(): + + # def foo_list(a: typing.List[int], b: typing.List[int] = [0,0]): + # return ('list', a, b) + + # return (foo_list,) + + # @overload_list.njit + # def jit_func(a): + # return foo(a) + + # self.assertEqual(jit_func([1,2]), ('list',[1,2],L)) + + def test_myfunc_tuple_type_default(self): + def foo(a, b=(0, 0)): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo_tuple(a: tuple, b: tuple = (0, 0)): + return ('tuple', a, b) + + return (foo_tuple,) + + @overload_list.njit + def jit_func(a): + return foo(a) + + self.assertEqual(jit_func((1, 2)), ('tuple', (1, 2), (0, 0))) + + class TestOverloadList(unittest.TestCase): maxDiff = None @@ -129,8 +247,8 @@ def test_myfunc_union_typing_int_float(self): def test_myfunc_union_typing_str_float(self): self.assertEqual(jit_func('qwe', 2.0), ('union', 'qwe', 2.0)) - def test_myfunc_optional_typing(self): - self.assertEqual(jit_func(1.0, 'qwe'), ('optional', 1.0, 'qwe')) + # def test_myfunc_optional_typing(self): + # self.assertEqual(jit_func(1.0, 'qwe'), ('optional', 1.0, 'qwe')) def test_myfunc_list_in_list_type(self): L_int = List([List([1, 2])]) From b7446cafb606cc946fbd5b4a7b4e1f82f8b48d53 Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Wed, 9 Sep 2020 17:46:30 +0300 Subject: [PATCH 05/12] fix comments --- numba_typing/overload_list.py | 37 ++++++++++++++---------------- numba_typing/test_overload_list.py | 16 ++++++------- 2 files changed, 25 insertions(+), 28 deletions(-) diff --git a/numba_typing/overload_list.py b/numba_typing/overload_list.py index 7a7d0ceaa..13f386e19 100644 --- a/numba_typing/overload_list.py +++ b/numba_typing/overload_list.py @@ -1,10 +1,7 @@ -import numpy import numba from numba import types -from numba import typeof from numba.extending import overload from type_annotations import product_annotations, get_func_annotations -from numba import njit import typing from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning import warnings @@ -12,10 +9,6 @@ from inspect import getfullargspec -warnings.simplefilter('ignore', category=NumbaDeprecationWarning) -warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) - - def overload_list(orig_func): def overload_inner(ovld_list): def wrapper(*args): @@ -106,10 +99,12 @@ def __init__(self): def clear_typevars_dict(self): self._typevars_dict.clear() - def add_type_check(self, type_check, func): - self._types_dict[type_check] = func + @classmethod + def add_type_check(cls, type_check, func): + cls._types_dict[type_check] = func - def _is_generic(self, p_obj): + @staticmethod + def _is_generic(p_obj): if isinstance(p_obj, typing._GenericAlias): return True @@ -118,7 +113,8 @@ def _is_generic(self, p_obj): return False - def _get_origin(self, p_obj): + @staticmethod + def _get_origin(p_obj): return p_obj.__origin__ def match(self, p_type, n_type): @@ -138,7 +134,8 @@ def match(self, p_type, n_type): return self._types_dict[p_type](self, p_type, n_type) return self._types_dict[p_type](n_type) except KeyError: - print((f'A check for the {p_type} was not found. {n_type}')) + print((f'A check for the {p_type} was not found.')) + return None def match_typevar(self, p_type, n_type): if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): @@ -153,10 +150,10 @@ def match_generic(self, p_type, n_type): return res -def choose_func_by_sig(sig_list, values_dict, defaults_dict={}): +def choose_func_by_sig(sig_list, values_dict, defaults_dict): checker = TypeChecker() - for sig in sig_list: # sig = (Signature,func) - for param in sig[0].parameters: # param = {'a':int,'b':int} + for sig, func in sig_list: # sig = (Signature,func) + for param in sig.parameters: # param = {'a':int,'b':int} full_match = True for name, typ in values_dict.items(): # name,type = 'a',int64 if isinstance(typ, types.Literal): @@ -164,8 +161,8 @@ def choose_func_by_sig(sig_list, values_dict, defaults_dict={}): full_match = full_match and checker.match( param[name], typ.literal_type) - if sig[0].defaults.get(name, False): - full_match = full_match and sig[0].defaults[name] == typ.literal_value + if sig.defaults.get(name, False): + full_match = full_match and sig.defaults[name] == typ.literal_value else: full_match = full_match and checker.match(param[name], typ) @@ -173,11 +170,11 @@ def choose_func_by_sig(sig_list, values_dict, defaults_dict={}): break for name, val in defaults_dict.items(): - if sig[0].defaults.get(name) != None: - full_match = full_match and sig[0].defaults[name] == val + if not sig.defaults.get(name) is None: + full_match = full_match and sig.defaults[name] == val checker.clear_typevars_dict() if full_match: - return sig[1] + return func return None diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py index 286547975..6f75c7d4e 100644 --- a/numba_typing/test_overload_list.py +++ b/numba_typing/test_overload_list.py @@ -3,7 +3,7 @@ from overload_list import types import unittest import typing -from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning +from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning, njit import warnings @@ -76,7 +76,7 @@ def foo_typevars_list_dict_list(a: typing.List[T], b: typing.Dict[K, typing.List foo_typevars_list_dict, foo_typevars_list_dict_list, foo_typevars_T_K -@overload_list.njit +@njit def jit_func(a, b): return func(a, b) @@ -96,7 +96,7 @@ def foo_int(a: int, b: int = 0): return (foo_int,) - @overload_list.njit + @njit def jit_func(a): return foo(a) @@ -114,7 +114,7 @@ def foo_float(a: float, b: float = 0.0): return (foo_float,) - @overload_list.njit + @njit def jit_func(a): return foo(a) @@ -132,7 +132,7 @@ def foo_bool(a: bool, b: bool = False): return (foo_bool,) - @overload_list.njit + @njit def jit_func(a): return foo(a) @@ -150,7 +150,7 @@ def foo_str(a: str, b: str = '0'): return (foo_str,) - @overload_list.njit + @njit def jit_func(a): return foo(a) @@ -170,7 +170,7 @@ def jit_func(a): # return (foo_list,) - # @overload_list.njit + # @njit # def jit_func(a): # return foo(a) @@ -188,7 +188,7 @@ def foo_tuple(a: tuple, b: tuple = (0, 0)): return (foo_tuple,) - @overload_list.njit + @njit def jit_func(a): return foo(a) From 09289bfb1acd0d5d25d525596c1b4ab866f1707d Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Mon, 14 Sep 2020 14:27:33 +0300 Subject: [PATCH 06/12] fix comments and add new tests --- numba_typing/overload_list.py | 100 ++++++++++++++--------------- numba_typing/test_overload_list.py | 81 ++++++++++++----------- 2 files changed, 95 insertions(+), 86 deletions(-) diff --git a/numba_typing/overload_list.py b/numba_typing/overload_list.py index 13f386e19..61070845d 100644 --- a/numba_typing/overload_list.py +++ b/numba_typing/overload_list.py @@ -3,8 +3,6 @@ from numba.extending import overload from type_annotations import product_annotations, get_func_annotations import typing -from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning -import warnings from numba.typed import List, Dict from inspect import getfullargspec @@ -26,7 +24,7 @@ def wrapper(*args): result = choose_func_by_sig(sig_list, values_dict, defaults_dict) if result is None: - raise numba.TypingError(f'Unsupported types a={a}, b={b}') + raise TypeError(f'Unsupported types a={a}, b={b}') return result @@ -52,28 +50,27 @@ def check_str_type(n_type): def check_list_type(self, p_type, n_type): - res = isinstance(n_type, types.List) or isinstance(n_type, types.ListType) - if isinstance(p_type, type): + res = isinstance(n_type, (types.List, types.ListType)) + if p_type == list: return res else: return res and self.match(p_type.__args__[0], n_type.dtype) def check_tuple_type(self, p_type, n_type): - res = False - if isinstance(n_type, types.Tuple): - res = True - if isinstance(p_type, type): - return res - for p_val, n_val in zip(p_type.__args__, n_type.key): - res = res and self.match(p_val, n_val) - if isinstance(n_type, types.UniTuple): - res = True - if isinstance(p_type, type): - return res - for p_val in p_type.__args__: - res = res and self.match(p_val, n_type.key[0]) - return res + if not isinstance(n_type, (types.Tuple, types.UniTuple)): + return False + try: + if len(p_type.__args__) != len(n_type.types): + return False + except AttributeError: # if p_type == tuple + return True + + for p_val, n_val in zip(p_type.__args__, n_type.types): + if not self.match(p_val, n_val): + return False + + return True def check_dict_type(self, p_type, n_type): @@ -89,9 +86,7 @@ def check_dict_type(self, p_type, n_type): class TypeChecker: - _types_dict = {int: check_int_type, float: check_float_type, bool: check_bool_type, - str: check_str_type, list: check_list_type, - tuple: check_tuple_type, dict: check_dict_type} + _types_dict: dict = {} def __init__(self): self._typevars_dict = {} @@ -118,60 +113,65 @@ def _get_origin(p_obj): return p_obj.__origin__ def match(self, p_type, n_type): + if p_type == typing.Any: + return True try: - if p_type == typing.Any: - return True - elif self._is_generic(p_type): + if self._is_generic(p_type): origin_type = self._get_origin(p_type) if origin_type == typing.Generic: return self.match_generic(p_type, n_type) - else: - return self._types_dict[origin_type](self, p_type, n_type) - elif isinstance(p_type, typing.TypeVar): + + return self._types_dict[origin_type](self, p_type, n_type) + + if isinstance(p_type, typing.TypeVar): return self.match_typevar(p_type, n_type) - else: - if p_type in (list, tuple): - return self._types_dict[p_type](self, p_type, n_type) - return self._types_dict[p_type](n_type) + + if p_type in (list, tuple, dict): + return self._types_dict[p_type](self, p_type, n_type) + + return self._types_dict[p_type](n_type) + except KeyError: - print((f'A check for the {p_type} was not found.')) - return None + raise TypeError(f'A check for the {p_type} was not found.') def match_typevar(self, p_type, n_type): - if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): + if isinstance(n_type, types.List): + n_type = types.ListType(n_type.dtype) + if not self._typevars_dict.get(p_type): self._typevars_dict[p_type] = n_type return True return self._typevars_dict.get(p_type) == n_type def match_generic(self, p_type, n_type): - res = True - for arg in p_type.__args__: - res = res and self.match(arg, n_type) - return res + raise SystemError + + +TypeChecker.add_type_check(int, check_int_type) +TypeChecker.add_type_check(float, check_float_type) +TypeChecker.add_type_check(str, check_str_type) +TypeChecker.add_type_check(bool, check_bool_type) +TypeChecker.add_type_check(list, check_list_type) +TypeChecker.add_type_check(tuple, check_tuple_type) +TypeChecker.add_type_check(dict, check_dict_type) def choose_func_by_sig(sig_list, values_dict, defaults_dict): checker = TypeChecker() for sig, func in sig_list: # sig = (Signature,func) for param in sig.parameters: # param = {'a':int,'b':int} - full_match = True for name, typ in values_dict.items(): # name,type = 'a',int64 if isinstance(typ, types.Literal): + typ = typ.literal_type - full_match = full_match and checker.match( - param[name], typ.literal_type) - - if sig.defaults.get(name, False): - full_match = full_match and sig.defaults[name] == typ.literal_value - else: - full_match = full_match and checker.match(param[name], typ) + full_match = checker.match(param[name], typ) if not full_match: break - for name, val in defaults_dict.items(): - if not sig.defaults.get(name) is None: - full_match = full_match and sig.defaults[name] == val + if len(param) != len(values_dict.items()): + for name, val in defaults_dict.items(): + if not sig.defaults.get(name) is None: + full_match = full_match and sig.defaults[name] == val checker.clear_typevars_dict() if full_match: diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py index 6f75c7d4e..ed45af752 100644 --- a/numba_typing/test_overload_list.py +++ b/numba_typing/test_overload_list.py @@ -3,7 +3,7 @@ from overload_list import types import unittest import typing -from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning, njit +from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning, njit, core import warnings @@ -39,21 +39,15 @@ def foo_tuple(a: typing.Tuple[int, float], b: tuple = (0, 0)): def foo_dict(a: typing.Dict[str, int], b: typing.Dict[int, bool] = {0: False}): return ('dict', a, b) - # def foo_any(a: typing.Any, b: typing.Any = None): - # return('any', a, b) - def foo_union(a: typing.Union[int, str], b: typing.Union[float, bool] = None): return('union', a, b) - # def foo_optional(a: typing.Optional[float], b: typing.Optional[str] = None): - # return('optional', a, b) - def foo_list_in_list(a: typing.List[typing.List[int]], b: typing.List[typing.List[typing.List[float]]] = [[[0.0, 0.0]]]): return('list_in_list', a, b) - def foo_tuple_in_tuple(a: typing.Tuple[typing.Tuple[int]], - b: typing.Tuple[typing.Tuple[typing.Tuple[float]]] = ((0.0, 0.0))): + def foo_tuple_in_tuple(a: typing.Tuple[typing.Tuple[int, int]], + b: typing.Tuple[typing.Tuple[typing.Tuple[float, float]]] = ((0.0, 0.0))): return('tuple_in_tuple', a, b) def foo_typevars_T_T(a: T, b: T): @@ -84,6 +78,24 @@ def jit_func(a, b): class TestOverloadListDefault(unittest.TestCase): maxDiff = None + def test_myfunc_literal_type_default(self): + def foo(a, b=0): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo_int_literal(a: int, b: int = 0): + return ('literal', a, b) + + return (foo_int_literal,) + + @njit + def jit_func(a): + return foo(a, 2) + + self.assertEqual(jit_func(1), ('literal', 1, 2)) + def test_myfunc_int_type_default(self): def foo(a, b=0): ... @@ -156,26 +168,6 @@ def jit_func(a): self.assertEqual(jit_func('qwe'), ('str', 'qwe', '0')) - # def test_myfunc_list_type_default(self): - # warnings.simplefilter('ignore', category=NumbaDeprecationWarning) - # warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) - # L = List([0, 0]) - # def foo(a,b=L): - # ... - # @overload_list.overload_list(foo) - # def foo_ovld_list(): - - # def foo_list(a: typing.List[int], b: typing.List[int] = [0,0]): - # return ('list', a, b) - - # return (foo_list,) - - # @njit - # def jit_func(a): - # return foo(a) - - # self.assertEqual(jit_func([1,2]), ('list',[1,2],L)) - def test_myfunc_tuple_type_default(self): def foo(a, b=(0, 0)): ... @@ -194,6 +186,24 @@ def jit_func(a): self.assertEqual(jit_func((1, 2)), ('tuple', (1, 2), (0, 0))) + def test_myfunc_tuple_type_error(self): + def foo(a, b=(0, 0)): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo_tuple(a: typing.Tuple[int, int], b: tuple = (0, 0)): + return ('tuple_', a, b) + + return (foo_tuple,) + + @njit + def jit_func(a, b): + return foo(a, b) + + self.assertRaises(core.errors.TypingError, jit_func, (1, 2, 3), ('3', False)) + class TestOverloadList(unittest.TestCase): maxDiff = None @@ -214,7 +224,6 @@ def test_myfunc_list_type(self): self.assertEqual(jit_func([1, 2], [3, 4]), ('list', [1, 2], [3, 4])) def test_myfunc_List_typed(self): - warnings.simplefilter('ignore', category=NumbaDeprecationWarning) warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) L = List([1, 2, 3]) @@ -232,9 +241,6 @@ def test_myfunc_dict_type(self): D_1[0] = False self.assertEqual(jit_func(D, D_1), ('dict', D, D_1)) - # def test_myfunc_any_typing(self): - # self.assertEqual(jit_func((1,2.0),['qaz','qwe']), ('any',(1,2.0),['qaz','qwe'])) - def test_myfunc_union_typing_int_bool(self): self.assertEqual(jit_func(1, False), ('union', 1, False)) @@ -247,9 +253,6 @@ def test_myfunc_union_typing_int_float(self): def test_myfunc_union_typing_str_float(self): self.assertEqual(jit_func('qwe', 2.0), ('union', 'qwe', 2.0)) - # def test_myfunc_optional_typing(self): - # self.assertEqual(jit_func(1.0, 'qwe'), ('optional', 1.0, 'qwe')) - def test_myfunc_list_in_list_type(self): L_int = List([List([1, 2])]) L_float = List([List([List([3.0, 4.0])])]) @@ -262,6 +265,12 @@ def test_myfunc_tuple_in_tuple(self): def test_myfunc_typevar_T_T(self): self.assertEqual(jit_func(((1, 2),), ((3, 4),)), ('TypeVars_TT', ((1, 2),), ((3, 4),))) + def test_myfunc_typevar_T_T_list(self): + warnings.simplefilter('ignore', category=NumbaDeprecationWarning) + warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) + V = List([1.0, 2.0]) + self.assertEqual(jit_func(V, [3.0, 4.0]), ('TypeVars_TT', V, [3.0, 4.0])) + def test_myfunc_typevar_T_K(self): self.assertEqual(jit_func(1.0, 2), ('TypeVars_TK', 1.0, 2)) From 1cb60da757ef2fccd52f3944aa5d2b238a5228fd Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Tue, 15 Sep 2020 19:57:55 +0300 Subject: [PATCH 07/12] remove full_match and add a signature matching check --- numba_typing/overload_list.py | 49 +++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/numba_typing/overload_list.py b/numba_typing/overload_list.py index 61070845d..ec42b02f1 100644 --- a/numba_typing/overload_list.py +++ b/numba_typing/overload_list.py @@ -21,7 +21,8 @@ def wrapper(*args): if args_orig_func.defaults: defaults_dict = {name: value for name, value in zip( args_orig_func.args[::-1], args_orig_func.defaults[::-1])} - result = choose_func_by_sig(sig_list, values_dict, defaults_dict) + if valid_signature(sig_list, values_dict, defaults_dict): + result = choose_func_by_sig(sig_list, values_dict) if result is None: raise TypeError(f'Unsupported types a={a}, b={b}') @@ -33,6 +34,22 @@ def wrapper(*args): return overload_inner +def valid_signature(list_signature, values_dict, defaults_dict): + def check_defaults(sig_def): + for name, val in defaults_dict.items(): + if sig_def.get(name) is None: + raise AttributeError(f'{name} does not match the signature of the function passed to overload_list') + if not sig_def[name] == val: + raise ValueError(f'The default arguments are not equal: {name}: {val} != {sig_def[name]}') + + for sig, _ in list_signature: + for param in sig.parameters: + if len(param) != len(values_dict.items()): + check_defaults(sig.defaults) + + return True + + def check_int_type(n_type): return isinstance(n_type, types.Integer) @@ -155,26 +172,20 @@ def match_generic(self, p_type, n_type): TypeChecker.add_type_check(dict, check_dict_type) -def choose_func_by_sig(sig_list, values_dict, defaults_dict): - checker = TypeChecker() - for sig, func in sig_list: # sig = (Signature,func) - for param in sig.parameters: # param = {'a':int,'b':int} - for name, typ in values_dict.items(): # name,type = 'a',int64 - if isinstance(typ, types.Literal): - typ = typ.literal_type - - full_match = checker.match(param[name], typ) +def choose_func_by_sig(sig_list, values_dict): + def check_signature(sig_params, types_dict): + checker = TypeChecker() + for name, typ in types_dict.items(): # name,type = 'a',int64 + if isinstance(typ, types.Literal): + typ = typ.literal_type + if not checker.match(sig_params[name], typ): + return False - if not full_match: - break - - if len(param) != len(values_dict.items()): - for name, val in defaults_dict.items(): - if not sig.defaults.get(name) is None: - full_match = full_match and sig.defaults[name] == val + return True - checker.clear_typevars_dict() - if full_match: + for sig, func in sig_list: # sig = (Signature,func) + for param in sig.parameters: # param = {'a':int,'b':int} + if check_signature(param, values_dict): return func return None From 967ae2989499eb0c229ed061b2cededad47d64c0 Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Wed, 23 Sep 2020 02:43:46 +0300 Subject: [PATCH 08/12] add a test cases generator function --- numba_typing/test_overload_list.py | 335 ++++++++--------------------- 1 file changed, 92 insertions(+), 243 deletions(-) diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py index ed45af752..702f80a3c 100644 --- a/numba_typing/test_overload_list.py +++ b/numba_typing/test_overload_list.py @@ -3,79 +3,14 @@ from overload_list import types import unittest import typing -from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning, njit, core -import warnings - - -def func(a, b): - ... +from numba import njit, core T = typing.TypeVar('T') K = typing.TypeVar('K') -@overload_list.overload_list(func) -def foo_ovld_list(): - - def foo_int(a: int, b: int = 0): - return ('int', a, b) - - def foo_float(a: float, b: float = 0.0): - return ('float', a, b) - - def foo_bool(a: bool, b: bool = False): - return ('bool', a, b) - - def foo_str(a: str, b: str = '0'): - return ('str', a, b) - - def foo_list(a: typing.List[int], b: list = [0, 0, 0]): - return ('list', a, b) - - def foo_tuple(a: typing.Tuple[int, float], b: tuple = (0, 0)): - return ('tuple', a, b) - - def foo_dict(a: typing.Dict[str, int], b: typing.Dict[int, bool] = {0: False}): - return ('dict', a, b) - - def foo_union(a: typing.Union[int, str], b: typing.Union[float, bool] = None): - return('union', a, b) - - def foo_list_in_list(a: typing.List[typing.List[int]], - b: typing.List[typing.List[typing.List[float]]] = [[[0.0, 0.0]]]): - return('list_in_list', a, b) - - def foo_tuple_in_tuple(a: typing.Tuple[typing.Tuple[int, int]], - b: typing.Tuple[typing.Tuple[typing.Tuple[float, float]]] = ((0.0, 0.0))): - return('tuple_in_tuple', a, b) - - def foo_typevars_T_T(a: T, b: T): - return('TypeVars_TT', a, b) - - def foo_typevars_T_K(a: T, b: K): - return('TypeVars_TK', a, b) - - def foo_typevars_list_T(a: typing.List[T], b: T): - return('TypeVars_ListT', a, b) - - def foo_typevars_list_dict(a: typing.List[T], b: typing.Dict[K, T]): - return('TypeVars_List_Dict', a, b) - - def foo_typevars_list_dict_list(a: typing.List[T], b: typing.Dict[K, typing.List[T]]): - return('TypeVars_List_Dict_List', a, b) - - return foo_int, foo_float, foo_bool, foo_str, foo_list, foo_tuple, foo_dict, foo_union,\ - foo_list_in_list, foo_tuple_in_tuple, foo_typevars_T_T, foo_typevars_list_T,\ - foo_typevars_list_dict, foo_typevars_list_dict_list, foo_typevars_T_K - - -@njit -def jit_func(a, b): - return func(a, b) - - -class TestOverloadListDefault(unittest.TestCase): +class TestOverloadList(unittest.TestCase): maxDiff = None def test_myfunc_literal_type_default(self): @@ -96,96 +31,6 @@ def jit_func(a): self.assertEqual(jit_func(1), ('literal', 1, 2)) - def test_myfunc_int_type_default(self): - def foo(a, b=0): - ... - - @overload_list.overload_list(foo) - def foo_ovld_list(): - - def foo_int(a: int, b: int = 0): - return ('int', a, b) - - return (foo_int,) - - @njit - def jit_func(a): - return foo(a) - - self.assertEqual(jit_func(1), ('int', 1, 0)) - - def test_myfunc_float_type_default(self): - def foo(a, b=0.0): - ... - - @overload_list.overload_list(foo) - def foo_ovld_list(): - - def foo_float(a: float, b: float = 0.0): - return ('float', a, b) - - return (foo_float,) - - @njit - def jit_func(a): - return foo(a) - - self.assertEqual(jit_func(1.0), ('float', 1.0, 0.0)) - - def test_myfunc_bool_type_default(self): - def foo(a, b=False): - ... - - @overload_list.overload_list(foo) - def foo_ovld_list(): - - def foo_bool(a: bool, b: bool = False): - return ('bool', a, b) - - return (foo_bool,) - - @njit - def jit_func(a): - return foo(a) - - self.assertEqual(jit_func(True), ('bool', True, False)) - - def test_myfunc_str_type_default(self): - def foo(a, b='0'): - ... - - @overload_list.overload_list(foo) - def foo_ovld_list(): - - def foo_str(a: str, b: str = '0'): - return ('str', a, b) - - return (foo_str,) - - @njit - def jit_func(a): - return foo(a) - - self.assertEqual(jit_func('qwe'), ('str', 'qwe', '0')) - - def test_myfunc_tuple_type_default(self): - def foo(a, b=(0, 0)): - ... - - @overload_list.overload_list(foo) - def foo_ovld_list(): - - def foo_tuple(a: tuple, b: tuple = (0, 0)): - return ('tuple', a, b) - - return (foo_tuple,) - - @njit - def jit_func(a): - return foo(a) - - self.assertEqual(jit_func((1, 2)), ('tuple', (1, 2), (0, 0))) - def test_myfunc_tuple_type_error(self): def foo(a, b=(0, 0)): ... @@ -205,92 +50,96 @@ def jit_func(a, b): self.assertRaises(core.errors.TypingError, jit_func, (1, 2, 3), ('3', False)) -class TestOverloadList(unittest.TestCase): - maxDiff = None - - def test_myfunc_int_type(self): - self.assertEqual(jit_func(1, 2), ('int', 1, 2)) - - def test_myfunc_float_type(self): - self.assertEqual(jit_func(1.0, 2.0), ('float', 1.0, 2.0)) - - def test_myfunc_bool_type(self): - self.assertEqual(jit_func(True, True), ('bool', True, True)) - - def test_myfunc_str_type(self): - self.assertEqual(jit_func('qwe', 'qaz'), ('str', 'qwe', 'qaz')) - - def test_myfunc_list_type(self): - self.assertEqual(jit_func([1, 2], [3, 4]), ('list', [1, 2], [3, 4])) - - def test_myfunc_List_typed(self): - warnings.simplefilter('ignore', category=NumbaDeprecationWarning) - warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) - L = List([1, 2, 3]) - self.assertEqual(jit_func(L, [3, 4]), ('list', L, [3, 4])) - - def test_myfunc_tuple_type(self): - self.assertEqual(jit_func((1, 2.0), ('3', False)), ('tuple', (1, 2.0), ('3', False))) - - def test_myfunc_dict_type(self): - D = Dict.empty(key_type=types.unicode_type, value_type=types.int64) - D_1 = Dict.empty(key_type=types.int64, value_type=types.boolean) - D['qwe'] = 1 - D['qaz'] = 2 - D_1[1] = True - D_1[0] = False - self.assertEqual(jit_func(D, D_1), ('dict', D, D_1)) - - def test_myfunc_union_typing_int_bool(self): - self.assertEqual(jit_func(1, False), ('union', 1, False)) - - def test_myfunc_union_typing_str_bool(self): - self.assertEqual(jit_func('qwe', False), ('union', 'qwe', False)) - - def test_myfunc_union_typing_int_float(self): - self.assertEqual(jit_func(1, 2.0), ('union', 1, 2.0)) - - def test_myfunc_union_typing_str_float(self): - self.assertEqual(jit_func('qwe', 2.0), ('union', 'qwe', 2.0)) - - def test_myfunc_list_in_list_type(self): - L_int = List([List([1, 2])]) - L_float = List([List([List([3.0, 4.0])])]) - - self.assertEqual(jit_func(L_int, L_float), ('list_in_list', L_int, L_float)) - - def test_myfunc_tuple_in_tuple(self): - self.assertEqual(jit_func(((1, 2),), (((3.0, 4.0),),)), ('tuple_in_tuple', ((1, 2),), (((3.0, 4.0),),))) - - def test_myfunc_typevar_T_T(self): - self.assertEqual(jit_func(((1, 2),), ((3, 4),)), ('TypeVars_TT', ((1, 2),), ((3, 4),))) - - def test_myfunc_typevar_T_T_list(self): - warnings.simplefilter('ignore', category=NumbaDeprecationWarning) - warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) - V = List([1.0, 2.0]) - self.assertEqual(jit_func(V, [3.0, 4.0]), ('TypeVars_TT', V, [3.0, 4.0])) - - def test_myfunc_typevar_T_K(self): - self.assertEqual(jit_func(1.0, 2), ('TypeVars_TK', 1.0, 2)) - - def test_myfunc_typevar_List_T(self): - L_int = List([1, 2]) - self.assertEqual(jit_func(L_int, 2), ('TypeVars_ListT', L_int, 2)) - - def test_myfunc_typevar_List_Dict(self): - D = Dict.empty(key_type=types.unicode_type, value_type=types.int64) - D['qwe'] = 0 - L_int = List([1, 2]) - self.assertEqual(jit_func(L_int, D), ('TypeVars_List_Dict', L_int, D)) - - def test_myfunc_typevar_List_Dict_List(self): - list_type = types.ListType(types.int64) - D = Dict.empty(key_type=types.unicode_type, value_type=list_type) - D['qwe'] = List([3, 4, 5]) - L_int = List([1, 2]) - - self.assertEqual(jit_func(L_int, D), ('TypeVars_List_Dict_List', L_int, D)) +def generator_test(func_name, param, values_dict, defaults_dict={}): + + def check_type(typ): + if isinstance(typ, type): + return typ.__name__ + return typ + + value_keys = ", ".join("{}".format(key) for key in values_dict.keys()) + defaults_keys = ", ".join("{}".format(key) for key in defaults_dict.keys()) + value_str = ", ".join("{}: {}".format(key, check_type(val)) for key, val in values_dict.items()) + defaults_str = ", ".join("{} = {}".format(key, val) if not isinstance( + val, str) else "{} = '{}'".format(key, val) for key, val in defaults_dict.items()) + defaults_str_type = ", ".join("{}: {} = {}".format(key, check_type(type(val)), val) if not isinstance(val, str) + else "{}: {} = '{}'".format(key, check_type(type(val)), val) + for key, val in defaults_dict.items()) + value_type = ", ".join("{}".format(val) for val in values_dict.values()) + defaults_type = ", ".join("{}".format(type(val)) for val in defaults_dict.values()) + param_qwe = ", ".join("{}".format(i) for i in param) + test = f""" +def test_myfunc_{func_name}_type_default(self): + def foo({value_keys},{defaults_str}): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo_{func_name}({value_str},{defaults_str_type}): + return ("{value_type}","{defaults_type}") + + return (foo_{func_name},) + + @njit + def jit_func({value_keys},{defaults_str}): + return foo({value_keys},{defaults_keys}) + + self.assertEqual(jit_func({param_qwe}), ("{value_type}", "{defaults_type}")) +""" + loc = {} + exec(test, globals(), loc) + return loc + + +L = List([1, 2, 3]) +L_int = List([List([1, 2])]) +L_float = List([List([List([3.0, 4.0])])]) +L_f = List([1.0, 2.0]) +D = Dict.empty(key_type=types.unicode_type, value_type=types.int64) +D_1 = Dict.empty(key_type=types.int64, value_type=types.boolean) +D['qwe'] = 1 +D['qaz'] = 2 +D_1[1] = True +D_1[0] = False +list_type = types.ListType(types.int64) +D_list = Dict.empty(key_type=types.unicode_type, value_type=list_type) +D_list['qwe'] = List([3, 4, 5]) +str_1 = 'qwe' +str_2 = 'qaz' +test_cases = [('int', [1, 2], {'a': int, 'b': int}), ('float', [1.0, 2.0], {'a': float, 'b': float}), + ('bool', [True, True], {'a': bool, 'b': bool}), ('str', ['str_1', 'str_2'], {'a': str, 'b': str}), + ('list', [[1, 2], [3, 4]], {'a': typing.List[int], 'b':list}), + ('List_typed', [L, [3, 4]], {'a': typing.List[int], 'b':list}), + ('tuple', [(1, 2.0), ('3', False)], {'a': typing.Tuple[int, float], 'b':tuple}), + ('dict', ['D', 'D_1'], {'a': typing.Dict[str, int], 'b': typing.Dict[int, bool]}), + ('union_1', [1, False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}), + ('union_2', ['str_1', False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}), + ('nested_list', ['L_int', 'L_float'], {'a': typing.List[typing.List[int]], + 'b': typing.List[typing.List[typing.List[float]]]}), + ('TypeVar_TT', ['L_f', [3.0, 4.0]], {'a': 'T', 'b': 'T'}), + ('TypeVar_TK', [1.0, 2], {'a': 'T', 'b': 'K'}), + ('TypeVar_ListT_T', ['L', 5], {'a': 'typing.List[T]', 'b': 'T'}), + ('TypeVar_ListT_DictKT', ['L', 'D'], {'a': 'typing.List[T]', 'b': 'typing.Dict[K, T]'}), + ('TypeVar_ListT_DictK_ListT', ['L', 'D_list'], {'a': 'typing.List[T]', + 'b': 'typing.Dict[K, typing.List[T]]'})] + +test_cases_default = [('int_defaults', [1], {'a': int}, {'b': 0}), ('float_defaults', [1.0], {'a': float}, {'b': 0.0}), + ('bool_defaults', [True], {'a': bool}, {'b': False}), + ('str_defaults', ['str_1'], {'a': str}, {'b': '0'}), + ('tuple_defaults', [(1, 2)], {'a': tuple}, {'b': (0, 0)})] + + +for name, val, annotation in test_cases: + run_generator = generator_test(name, val, annotation) + test_name = list(run_generator.keys())[0] + setattr(TestOverloadList, test_name, run_generator[test_name]) + + +for name, val, annotation, defaults in test_cases_default: + run_generator = generator_test(name, val, annotation, defaults) + test_name = list(run_generator.keys())[0] + setattr(TestOverloadList, test_name, run_generator[test_name]) if __name__ == "__main__": From 716402cbb5486d09827f43d7a237aca10655f8a1 Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Mon, 28 Sep 2020 14:07:54 +0300 Subject: [PATCH 09/12] fix comments and change the old tests --- numba_typing/overload_list.py | 12 +- numba_typing/test_overload_list.py | 288 ++++++++++++++++++----------- 2 files changed, 188 insertions(+), 112 deletions(-) diff --git a/numba_typing/overload_list.py b/numba_typing/overload_list.py index ec42b02f1..de29d2b6b 100644 --- a/numba_typing/overload_list.py +++ b/numba_typing/overload_list.py @@ -25,7 +25,7 @@ def wrapper(*args): result = choose_func_by_sig(sig_list, values_dict) if result is None: - raise TypeError(f'Unsupported types a={a}, b={b}') + raise TypeError(f'Unsupported types {args}') return result @@ -35,17 +35,19 @@ def wrapper(*args): def valid_signature(list_signature, values_dict, defaults_dict): - def check_defaults(sig_def): + def check_defaults(list_param, sig_def): for name, val in defaults_dict.items(): if sig_def.get(name) is None: raise AttributeError(f'{name} does not match the signature of the function passed to overload_list') - if not sig_def[name] == val: + if sig_def[name] != val: raise ValueError(f'The default arguments are not equal: {name}: {val} != {sig_def[name]}') + if type(sig_def[name]) != list_param[name]: + raise TypeError(f'The default value does not match the type: {list_param[name]}') for sig, _ in list_signature: for param in sig.parameters: - if len(param) != len(values_dict.items()): - check_defaults(sig.defaults) + if len(param) != len(values_dict): + check_defaults(param, sig.defaults) return True diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py index 702f80a3c..dda6676bb 100644 --- a/numba_typing/test_overload_list.py +++ b/numba_typing/test_overload_list.py @@ -4,142 +4,216 @@ import unittest import typing from numba import njit, core +import re T = typing.TypeVar('T') K = typing.TypeVar('K') +S = typing.TypeVar('S', int, float) +UserType = typing.NewType('UserType', int) -class TestOverloadList(unittest.TestCase): - maxDiff = None - - def test_myfunc_literal_type_default(self): - def foo(a, b=0): - ... - - @overload_list.overload_list(foo) - def foo_ovld_list(): - - def foo_int_literal(a: int, b: int = 0): - return ('literal', a, b) - - return (foo_int_literal,) - - @njit - def jit_func(a): - return foo(a, 2) - - self.assertEqual(jit_func(1), ('literal', 1, 2)) - - def test_myfunc_tuple_type_error(self): - def foo(a, b=(0, 0)): - ... - - @overload_list.overload_list(foo) - def foo_ovld_list(): - - def foo_tuple(a: typing.Tuple[int, int], b: tuple = (0, 0)): - return ('tuple_', a, b) - - return (foo_tuple,) - - @njit - def jit_func(a, b): - return foo(a, b) - - self.assertRaises(core.errors.TypingError, jit_func, (1, 2, 3), ('3', False)) - - -def generator_test(func_name, param, values_dict, defaults_dict={}): +def generator_test(param, values_dict, defaults_dict={}): def check_type(typ): if isinstance(typ, type): return typ.__name__ return typ - value_keys = ", ".join("{}".format(key) for key in values_dict.keys()) - defaults_keys = ", ".join("{}".format(key) for key in defaults_dict.keys()) - value_str = ", ".join("{}: {}".format(key, check_type(val)) for key, val in values_dict.items()) - defaults_str = ", ".join("{} = {}".format(key, val) if not isinstance( - val, str) else "{} = '{}'".format(key, val) for key, val in defaults_dict.items()) - defaults_str_type = ", ".join("{}: {} = {}".format(key, check_type(type(val)), val) if not isinstance(val, str) - else "{}: {} = '{}'".format(key, check_type(type(val)), val) - for key, val in defaults_dict.items()) - value_type = ", ".join("{}".format(val) for val in values_dict.values()) - defaults_type = ", ".join("{}".format(type(val)) for val in defaults_dict.values()) - param_qwe = ", ".join("{}".format(i) for i in param) + value_keys = ", ".join(f"{key}" if not key in defaults_dict.keys() + else f"{key} = {defaults_dict[key]}" for key in values_dict.keys()) + value_annotation = ", ".join(f"{key}: {check_type(val)}" if not key in defaults_dict.keys() + else f"{key}: {check_type(val)} = {defaults_dict[key]}" for key, val in values_dict.items()) + value_type = ", ".join(f"{val}" for val in values_dict.values()) + return_value_keys = ", ".join("{}".format(key) for key in values_dict.keys()) + param_func = ", ".join(f"{val}" for val in param) test = f""" -def test_myfunc_{func_name}_type_default(self): - def foo({value_keys},{defaults_str}): +def test_myfunc(): + def foo({value_keys}): ... @overload_list.overload_list(foo) def foo_ovld_list(): - def foo_{func_name}({value_str},{defaults_str_type}): - return ("{value_type}","{defaults_type}") + def foo({value_annotation}): + return ("{value_type}") - return (foo_{func_name},) + return (foo,) @njit - def jit_func({value_keys},{defaults_str}): - return foo({value_keys},{defaults_keys}) + def jit_func({value_keys}): + return foo({return_value_keys}) - self.assertEqual(jit_func({param_qwe}), ("{value_type}", "{defaults_type}")) + return (jit_func({param_func}), ("{value_type}")) """ loc = {} exec(test, globals(), loc) return loc -L = List([1, 2, 3]) -L_int = List([List([1, 2])]) -L_float = List([List([List([3.0, 4.0])])]) -L_f = List([1.0, 2.0]) -D = Dict.empty(key_type=types.unicode_type, value_type=types.int64) -D_1 = Dict.empty(key_type=types.int64, value_type=types.boolean) -D['qwe'] = 1 -D['qaz'] = 2 -D_1[1] = True -D_1[0] = False +list_numba = List([1, 2, 3]) +nested_list_numba = List([List([1, 2])]) +dict_numba = Dict.empty(key_type=types.unicode_type, value_type=types.int64) +dict_numba_1 = Dict.empty(key_type=types.int64, value_type=types.boolean) +dict_numba['qwe'] = 1 +dict_numba_1[1] = True list_type = types.ListType(types.int64) -D_list = Dict.empty(key_type=types.unicode_type, value_type=list_type) -D_list['qwe'] = List([3, 4, 5]) -str_1 = 'qwe' -str_2 = 'qaz' -test_cases = [('int', [1, 2], {'a': int, 'b': int}), ('float', [1.0, 2.0], {'a': float, 'b': float}), - ('bool', [True, True], {'a': bool, 'b': bool}), ('str', ['str_1', 'str_2'], {'a': str, 'b': str}), - ('list', [[1, 2], [3, 4]], {'a': typing.List[int], 'b':list}), - ('List_typed', [L, [3, 4]], {'a': typing.List[int], 'b':list}), - ('tuple', [(1, 2.0), ('3', False)], {'a': typing.Tuple[int, float], 'b':tuple}), - ('dict', ['D', 'D_1'], {'a': typing.Dict[str, int], 'b': typing.Dict[int, bool]}), - ('union_1', [1, False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}), - ('union_2', ['str_1', False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}), - ('nested_list', ['L_int', 'L_float'], {'a': typing.List[typing.List[int]], - 'b': typing.List[typing.List[typing.List[float]]]}), - ('TypeVar_TT', ['L_f', [3.0, 4.0]], {'a': 'T', 'b': 'T'}), - ('TypeVar_TK', [1.0, 2], {'a': 'T', 'b': 'K'}), - ('TypeVar_ListT_T', ['L', 5], {'a': 'typing.List[T]', 'b': 'T'}), - ('TypeVar_ListT_DictKT', ['L', 'D'], {'a': 'typing.List[T]', 'b': 'typing.Dict[K, T]'}), - ('TypeVar_ListT_DictK_ListT', ['L', 'D_list'], {'a': 'typing.List[T]', - 'b': 'typing.Dict[K, typing.List[T]]'})] - -test_cases_default = [('int_defaults', [1], {'a': int}, {'b': 0}), ('float_defaults', [1.0], {'a': float}, {'b': 0.0}), - ('bool_defaults', [True], {'a': bool}, {'b': False}), - ('str_defaults', ['str_1'], {'a': str}, {'b': '0'}), - ('tuple_defaults', [(1, 2)], {'a': tuple}, {'b': (0, 0)})] - - -for name, val, annotation in test_cases: - run_generator = generator_test(name, val, annotation) - test_name = list(run_generator.keys())[0] - setattr(TestOverloadList, test_name, run_generator[test_name]) - - -for name, val, annotation, defaults in test_cases_default: - run_generator = generator_test(name, val, annotation, defaults) - test_name = list(run_generator.keys())[0] - setattr(TestOverloadList, test_name, run_generator[test_name]) +list_in_dict_numba = Dict.empty(key_type=types.unicode_type, value_type=list_type) +list_in_dict_numba['qwe'] = List([3, 4, 5]) +str_variable = 'qwe' +str_variable_1 = 'qaz' +user_type = UserType(1) + + +def run_test(case): + run_generator = generator_test(*case) + received, expected = run_generator['test_myfunc']() + return (received, expected) + + +def run_test_with_error(case): + run_generator = generator_test(*case) + try: + run_generator['test_myfunc']() + except core.errors.TypingError as err: + res = re.search(r'TypeError', err.msg) + return res.group(0) + + +class TestOverload(unittest.TestCase): + maxDiff = None + + def test_standart_types(self): + test_cases = [([1], {'a': int}), ([1.0], {'a': float}), ([True], {'a': bool}), (['str_variable'], {'a': str})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_container_types(self): + test_cases = [([[1, 2]], {'a': list}), ([(1.0, 2.0)], {'a': tuple}), (['dict_numba'], {'a': dict})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_typing_types(self): + test_cases = [([[1.0, 2.0]], {'a': typing.List[float]}), (['list_numba'], {'a': typing.List[int]}), + ([(1, 2.0)], {'a': typing.Tuple[int, float]}), (['dict_numba_1'], {'a': typing.Dict[int, bool]}), + ([True, 'str_variable'], {'a': typing.Union[bool, str], 'b': typing.Union[bool, str]}), + ([1, False], {'a': typing.Any, 'b': typing.Any})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_nested_typing_types(self): + test_cases = [(['nested_list_numba'], {'a': typing.List[typing.List[int]]}), + ([((1.0,),)], {'a': typing.Tuple[typing.Tuple[float]]})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_typevar_types(self): + test_cases = [([1.0], {'a': 'T'}), ([False], {'a': 'T'}), (['list_numba', [1, 2]], {'a': 'T', 'b': 'T'}), + ([1, 2.0], {'a': 'T', 'b': 'K'}), ([1], {'a': 'S'}), ([1.0], {'a': 'S'}), + ([[True, True]], {'a': 'typing.List[T]'}), (['list_numba'], {'a': 'typing.List[T]'}), + ([('str_variable', 2)], {'a': 'typing.Tuple[T,K]'}), (['dict_numba'], {'a': 'typing.Dict[K, T]'}), + (['dict_numba_1'], {'a': 'typing.Dict[K, T]'}), + (['list_in_dict_numba'], {'a': 'typing.Dict[K, typing.List[T]]'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_only_default_types(self): + test_cases = [([], {'a': int}, {'a': 1}), ([], {'a': float}, {'a': 1.0}), ([], {'a': bool}, {'a': True}), + ([], {'a': str}, {'a': 'str_variable'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_overriding_default_types(self): + test_cases = [([5], {'a': int}, {'a': 1}), ([5.0], {'a': float}, {'a': 1.0}), ([False], {'a': bool}, {'a': True}), + (['str_variable_1'], {'a': str}, {'a': 'str_variable'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_two_types(self): + test_cases = [([5, 3.0], {'a': int, 'b': float}), ([5, 3.0], {'a': int, 'b': float}, {'b': 0.0}), + ([5], {'a': int, 'b': float}, {'b': 0.0}), ([], {'a': int, 'b': float}, {'a': 0, 'b': 0.0})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_three_types(self): + test_cases = [([5, 3.0, 'str_variable_1'], {'a': int, 'b': float, 'c': str}), + ([5, 3.0], {'a': int, 'b': float, 'c': str}, {'c': 'str_variable'}), + ([5], {'a': int, 'b': float, 'c': str}, {'b': 0.0, 'c': 'str_variable'}), + ([], {'a': int, 'b': float, 'c': str}, {'a': 0, 'b': 0.0, 'c': 'str_variable'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_type_error(self): + test_cases = [([1], {'a': float}), ([], {'a': float}, {'a': 1}), ([1], {'a': typing.Iterable[int]}), + ([(1, 2, 3), (1.0, 2.0)], {'a': typing.Tuple[int, int], + 'b':tuple}), ([1, 2.0], {'a': 'T', 'b': 'T'}), + ([1, True], {'a': 'T', 'b': 'S'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(run_test_with_error(case), 'TypeError') + + def test_attribute_error(self): + def foo(a=0): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo(a: int): + return (a,) + + return (foo,) + + @njit + def jit_func(): + return foo() + + try: + jit_func() + except core.errors.TypingError as err: + res = re.search(r'AttributeError', err.msg) + self.assertEqual(res.group(0), 'AttributeError') + + def test_value_error(self): + def foo(a=0): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo(a: int = 1): + return (a,) + + return (foo,) + + @njit + def jit_func(): + return foo() + + try: + jit_func() + except core.errors.TypingError as err: + res = re.search(r'ValueError', err.msg) + self.assertEqual(res.group(0), 'ValueError') if __name__ == "__main__": From 2096e94276735671ca10a3a257634fcad23acb81 Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Mon, 28 Sep 2020 14:16:13 +0300 Subject: [PATCH 10/12] fix PEP --- numba_typing/test_overload_list.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py index dda6676bb..64a0badbd 100644 --- a/numba_typing/test_overload_list.py +++ b/numba_typing/test_overload_list.py @@ -20,10 +20,11 @@ def check_type(typ): return typ.__name__ return typ - value_keys = ", ".join(f"{key}" if not key in defaults_dict.keys() + value_keys = ", ".join(f"{key}" if key not in defaults_dict.keys() else f"{key} = {defaults_dict[key]}" for key in values_dict.keys()) - value_annotation = ", ".join(f"{key}: {check_type(val)}" if not key in defaults_dict.keys() - else f"{key}: {check_type(val)} = {defaults_dict[key]}" for key, val in values_dict.items()) + value_annotation = ", ".join(f"{key}: {check_type(val)}" if key not in defaults_dict.keys() + else f"{key}: {check_type(val)} = {defaults_dict[key]}" + for key, val in values_dict.items()) value_type = ", ".join(f"{val}" for val in values_dict.values()) return_value_keys = ", ".join("{}".format(key) for key in values_dict.keys()) param_func = ", ".join(f"{val}" for val in param) @@ -119,8 +120,8 @@ def test_typevar_types(self): test_cases = [([1.0], {'a': 'T'}), ([False], {'a': 'T'}), (['list_numba', [1, 2]], {'a': 'T', 'b': 'T'}), ([1, 2.0], {'a': 'T', 'b': 'K'}), ([1], {'a': 'S'}), ([1.0], {'a': 'S'}), ([[True, True]], {'a': 'typing.List[T]'}), (['list_numba'], {'a': 'typing.List[T]'}), - ([('str_variable', 2)], {'a': 'typing.Tuple[T,K]'}), (['dict_numba'], {'a': 'typing.Dict[K, T]'}), - (['dict_numba_1'], {'a': 'typing.Dict[K, T]'}), + ([('str_variable', 2)], {'a': 'typing.Tuple[T,K]'}), + (['dict_numba_1'], {'a': 'typing.Dict[K, T]'}), (['dict_numba'], {'a': 'typing.Dict[K, T]'}), (['list_in_dict_numba'], {'a': 'typing.Dict[K, typing.List[T]]'})] for case in test_cases: From 5ec33ac486b75fd9e8cf630c4fd9fb882370972d Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Mon, 28 Sep 2020 14:18:42 +0300 Subject: [PATCH 11/12] fix PEP --- numba_typing/test_overload_list.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py index 64a0badbd..2ab26d126 100644 --- a/numba_typing/test_overload_list.py +++ b/numba_typing/test_overload_list.py @@ -137,8 +137,8 @@ def test_only_default_types(self): self.assertEqual(*run_test(case)) def test_overriding_default_types(self): - test_cases = [([5], {'a': int}, {'a': 1}), ([5.0], {'a': float}, {'a': 1.0}), ([False], {'a': bool}, {'a': True}), - (['str_variable_1'], {'a': str}, {'a': 'str_variable'})] + test_cases = [([5], {'a': int}, {'a': 1}), ([5.0], {'a': float}, {'a': 1.0}), + ([False], {'a': bool}, {'a': True}), (['str_variable_1'], {'a': str}, {'a': 'str_variable'})] for case in test_cases: with self.subTest(case=case): From 7da564b5053f692e7216ef7983bda6203c1f4a80 Mon Sep 17 00:00:00 2001 From: Perevezentsev Vladislav Date: Tue, 29 Sep 2020 17:36:18 +0300 Subject: [PATCH 12/12] add additional tests --- numba_typing/test_overload_list.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py index 2ab26d126..6566ce327 100644 --- a/numba_typing/test_overload_list.py +++ b/numba_typing/test_overload_list.py @@ -92,7 +92,8 @@ def test_standart_types(self): self.assertEqual(*run_test(case)) def test_container_types(self): - test_cases = [([[1, 2]], {'a': list}), ([(1.0, 2.0)], {'a': tuple}), (['dict_numba'], {'a': dict})] + test_cases = [([[1, 2]], {'a': list}), (['list_numba'], {'a': list}), ([(1.0, 2.0)], {'a': tuple}), + ([(1, 2.0)], {'a': tuple}), (['dict_numba'], {'a': dict})] for case in test_cases: with self.subTest(case=case): @@ -101,8 +102,11 @@ def test_container_types(self): def test_typing_types(self): test_cases = [([[1.0, 2.0]], {'a': typing.List[float]}), (['list_numba'], {'a': typing.List[int]}), ([(1, 2.0)], {'a': typing.Tuple[int, float]}), (['dict_numba_1'], {'a': typing.Dict[int, bool]}), + ([False], {'a': typing.Union[bool, str]}), (['str_variable'], {'a': typing.Union[bool, str]}), ([True, 'str_variable'], {'a': typing.Union[bool, str], 'b': typing.Union[bool, str]}), - ([1, False], {'a': typing.Any, 'b': typing.Any})] + (['str_variable', True], {'a': typing.Union[bool, str], 'b': typing.Union[bool, str]}), + ([1, False], {'a': typing.Any, 'b': typing.Any}), + ([1.0, 'str_variable'], {'a': typing.Any, 'b': typing.Any})] for case in test_cases: with self.subTest(case=case): @@ -117,7 +121,9 @@ def test_nested_typing_types(self): self.assertEqual(*run_test(case)) def test_typevar_types(self): - test_cases = [([1.0], {'a': 'T'}), ([False], {'a': 'T'}), (['list_numba', [1, 2]], {'a': 'T', 'b': 'T'}), + test_cases = [([1.0], {'a': 'T'}), ([False], {'a': 'T'}), ([1, 2], {'a': 'T', 'b': 'T'}), + ([1.0, 2.0], {'a': 'T', 'b': 'T'}), (['str_variable', 'str_variable'], {'a': 'T', 'b': 'T'}), + (['list_numba', [1, 2]], {'a': 'T', 'b': 'T'}), ([1, 2.0], {'a': 'T', 'b': 'K'}), ([1], {'a': 'S'}), ([1.0], {'a': 'S'}), ([[True, True]], {'a': 'typing.List[T]'}), (['list_numba'], {'a': 'typing.List[T]'}), ([('str_variable', 2)], {'a': 'typing.Tuple[T,K]'}), @@ -164,9 +170,10 @@ def test_three_types(self): def test_type_error(self): test_cases = [([1], {'a': float}), ([], {'a': float}, {'a': 1}), ([1], {'a': typing.Iterable[int]}), - ([(1, 2, 3), (1.0, 2.0)], {'a': typing.Tuple[int, int], - 'b':tuple}), ([1, 2.0], {'a': 'T', 'b': 'T'}), - ([1, True], {'a': 'T', 'b': 'S'})] + ([(1, 2, 3)], {'a': typing.Tuple[int, int]}), ([(1.0, 2)], {'a': typing.Tuple[int, int]}), + ([(1, 2.0)], {'a': typing.Tuple[int, int]}), ([(1.0, 2.0)], {'a': typing.Tuple[int, int]}), + ([1, 2.0], {'a': 'T', 'b': 'T'}), ([(1, 2), (1, 2.0)], {'a': 'T', 'b': 'T'}), + ([True], {'a': 'S'}), (['str_variable'], {'a': 'S'})] for case in test_cases: with self.subTest(case=case):