Skip to content

Allow omitting redundant Generic[T] in base classes #2811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 7, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None:
Note that this is performed *before* semantic analysis.
"""
removed = [] # type: List[int]
declared_tvars = [] # type: List[Tuple[str, TypeVarExpr]]
type_vars = [] # type: List[TypeVarDef]
for i, base_expr in enumerate(defn.base_type_exprs):
try:
Expand All @@ -691,12 +692,25 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None:
continue
tvars = self.analyze_typevar_declaration(base)
if tvars is not None:
if type_vars:
if declared_tvars:
self.fail('Duplicate Generic in bases', defn)
removed.append(i)
for j, (name, tvar_expr) in enumerate(tvars):
type_vars.append(TypeVarDef(name, j + 1, tvar_expr.values,
tvar_expr.upper_bound, tvar_expr.variance))
declared_tvars.extend(tvars)

all_tvars = self.get_all_bases_tvars(defn, removed)
if declared_tvars:
if len(self.remove_dups(declared_tvars)) < len(declared_tvars):
self.fail("Duplicate type variables in Generic[...]", defn)
declared_tvars = self.remove_dups(declared_tvars)
if not set(all_tvars).issubset(set(declared_tvars)):
self.fail("If Generic[...] is present it should list all type variables", defn)
# In case of error, Generic tvars will go first
declared_tvars = self.remove_dups(declared_tvars + all_tvars)
else:
declared_tvars = all_tvars
for j, (name, tvar_expr) in enumerate(declared_tvars):
type_vars.append(TypeVarDef(name, j + 1, tvar_expr.values,
tvar_expr.upper_bound, tvar_expr.variance))
if type_vars:
defn.type_vars = type_vars
if defn.info:
Expand Down Expand Up @@ -733,6 +747,45 @@ def analyze_unbound_tvar(self, t: Type) -> Tuple[str, TypeVarExpr]:
return unbound.name, sym.node
return None

def get_all_bases_tvars(self, defn: ClassDef,
removed: List[int]) -> List[Tuple[str, TypeVarExpr]]:
tvars = [] # type: List[Tuple[str, TypeVarExpr]]
for i, base_expr in enumerate(defn.base_type_exprs):
if i not in removed:
try:
base = expr_to_unanalyzed_type(base_expr)
except TypeTranslationError:
# This error will be caught later.
continue
tvars.extend(self.get_tvars(base))
return self.remove_dups(tvars)

def get_tvars(self, tp: Type) -> List[Tuple[str, TypeVarExpr]]:
tvars = [] # type: List[Tuple[str, TypeVarExpr]]
if isinstance(tp, UnboundType):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unanalyzed types can also contain TypeList instances and these should be handled here as well. For example, class A(B[Callable[[T], S]]): ... would have T inside a TypeList.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I forgot about TypeList, fixed this in a new commit.

tp_args = tp.args
elif isinstance(tp, TypeList):
tp_args = tp.items
else:
return tvars
for arg in tp_args:
tvar = self.analyze_unbound_tvar(arg)
if tvar:
tvars.append(tvar)
else:
tvars.extend(self.get_tvars(arg))
return self.remove_dups(tvars)

def remove_dups(self, tvars: List[T]) -> List[T]:
# Get unique elements in order of appearance
all_tvars = set(tvars)
new_tvars = [] # type: List[T]
for t in tvars:
if t in all_tvars:
new_tvars.append(t)
all_tvars.remove(t)
return new_tvars

def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
# special case for NamedTuple
for base_expr in defn.base_type_exprs:
Expand Down
125 changes: 120 additions & 5 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ class C(Generic[T, S]):
def __init__(self, x: T, y: S) -> None:
...

class D(C[int, T], Generic[T]): ...
class D(C[int, T]): ...

D[str](1, 'a')
D[str](1, 1) # E: Argument 2 to "D" has incompatible type "int"; expected "str"
Expand Down Expand Up @@ -705,9 +705,9 @@ class Node(Generic[T]):

TupledNode = Node[Tuple[T, T]]

class D(Generic[T], TupledNode[T]):
class D(TupledNode[T]):
...
class L(Generic[T], List[TupledNode[T]]):
class L(List[TupledNode[T]]):
...

def f_bad(x: T) -> D[T]:
Expand Down Expand Up @@ -738,9 +738,10 @@ TupledNode = Node[Tuple[T, T]]
UNode = Union[int, Node[T]]

class C(TupledNode): ... # Same as TupledNode[Any]
class D(TupledNode[T]): ... # E: Invalid type "__main__.T"
class D(TupledNode[T]): ...
class E(Generic[T], UNode[T]): ... # E: Invalid base class

reveal_type(D((1, 1))) # E: Revealed type is '__main__.D[builtins.int*]'
[builtins fixtures/list.pyi]

[case testGenericTypeAliasesUnion]
Expand Down Expand Up @@ -977,6 +978,120 @@ reveal_type(Bad) # E: Revealed type is 'Any'
[out]


-- Simplified declaration of generics
-- ----------------------------------

[case testSimplifiedGenericSimple]
from typing import TypeVar, Generic
T = TypeVar('T')
S = TypeVar('S')
class B(Generic[T]):
def b(self) -> T: ...

class C(Generic[T]):
def c(self) -> T: ...

class D(B[T], C[S]): ...

reveal_type(D[str, int]().b()) # E: Revealed type is 'builtins.str*'
reveal_type(D[str, int]().c()) # E: Revealed type is 'builtins.int*'
[builtins fixtures/list.pyi]
[out]

[case testSimplifiedGenericCallable]
from typing import TypeVar, Generic, Callable
T = TypeVar('T')
S = TypeVar('S')
class B(Generic[T]):
def b(self) -> T: ...

class D(B[Callable[[T], S]]): ...

reveal_type(D[str, int]().b()) # E: Revealed type is 'def (builtins.str*) -> builtins.int*'
[builtins fixtures/list.pyi]
[out]

[case testSimplifiedGenericComplex]
from typing import TypeVar, Generic, Tuple
T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')

class A(Generic[T, S]):
pass

class B(Generic[T, S]):
def m(self) -> Tuple[T, S]:
pass

class C(A[S, B[T, int]], B[U, A[int, T]]):
pass

c = C[object, int, str]()
reveal_type(c.m()) # E: Revealed type is 'Tuple[builtins.str*, __main__.A*[builtins.int, builtins.int*]]'
[builtins fixtures/tuple.pyi]
[out]


[case testSimplifiedGenericOrder]
from typing import TypeVar, Generic
T = TypeVar('T')
S = TypeVar('S')

class B(Generic[T]):
def b(self) -> T: ...

class C(Generic[T]):
def c(self) -> T: ...

class D(B[T], C[S], Generic[S, T]): ...

reveal_type(D[str, int]().b()) # E: Revealed type is 'builtins.int*'
reveal_type(D[str, int]().c()) # E: Revealed type is 'builtins.str*'
[builtins fixtures/list.pyi]
[out]

[case testSimplifiedGenericDuplicate]
from typing import TypeVar, Generic
T = TypeVar('T')

class A(Generic[T, T]): # E: Duplicate type variables in Generic[...]
pass

a = A[int]()
[builtins fixtures/list.pyi]
[out]

[case testSimplifiedGenericNotAll]
from typing import TypeVar, Generic
T = TypeVar('T')
S = TypeVar('S')

class A(Generic[T]):
pass
class B(Generic[T]):
pass

class C(A[T], B[S], Generic[T]): # E: If Generic[...] is present it should list all type variables
pass

c = C[int, str]()
[builtins fixtures/list.pyi]
[out]

[case testSimplifiedGenericInvalid]
from typing import TypeVar, Generic
T = TypeVar('T')

class A(Generic[T]):
pass

class B(A[S]): # E: Name 'S' is not defined
pass
[builtins fixtures/list.pyi]
[out]


-- Multiple assignment with lists
-- ------------------------------

Expand Down Expand Up @@ -1025,7 +1140,7 @@ class A: pass
from typing import TypeVar, Generic, Iterable
T = TypeVar('T')
class object: pass
class list(Iterable[T], Generic[T]):
class list(Iterable[T]):
def __setitem__(self, x: int, v: T) -> None: pass
class int: pass
class type: pass
Expand Down