Skip to content

Allow returning inferred None from functions #5663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, ClassDef, Block, SymbolTable,
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, ClassDef, Block, SymbolNode,
ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, LITERAL_TYPE, REVEAL_TYPE
)
from mypy.literals import literal
Expand Down Expand Up @@ -329,11 +329,48 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
self.check_protocol_issubclass(e)
if isinstance(ret_type, UninhabitedType) and not ret_type.ambiguous:
self.chk.binder.unreachable()
if not allow_none_return and isinstance(ret_type, NoneTyp):
if not allow_none_return and self.always_returns_none(e.callee):
self.chk.msg.does_not_return_value(callee_type, e)
return AnyType(TypeOfAny.from_error)
return ret_type

def always_returns_none(self, node: Expression) -> bool:
"""Check if `node` refers to something explicitly annotated as only returning None."""
if isinstance(node, RefExpr):
if self.defn_returns_none(node.node):
return True
if isinstance(node, MemberExpr) and node.node is None: # instance or class attribute
typ = self.chk.type_map.get(node.expr)
if isinstance(typ, Instance):
info = typ.type
elif (isinstance(typ, CallableType) and typ.is_type_obj() and
isinstance(typ.ret_type, Instance)):
info = typ.ret_type.type
else:
return False
sym = info.get(node.name)
if sym and self.defn_returns_none(sym.node):
return True
return False

def defn_returns_none(self, defn: Optional[SymbolNode]) -> bool:
"""Check if `defn` can _only_ return None."""
if isinstance(defn, FuncDef):
return (isinstance(defn.type, CallableType) and
isinstance(defn.type.ret_type, NoneTyp))
if isinstance(defn, OverloadedFuncDef):
return all(isinstance(item.type, CallableType) and
isinstance(item.type.ret_type, NoneTyp) for item in defn.items)
if isinstance(defn, Var):
if (not defn.is_inferred and isinstance(defn.type, CallableType) and
isinstance(defn.type.ret_type, NoneTyp)):
return True
if isinstance(defn.type, Instance):
sym = defn.type.type.get('__call__')
if sym and self.defn_returns_none(sym.node):
return True
return False

def check_runtime_protocol_test(self, e: CallExpr) -> None:
for expr in mypy.checker.flatten(e.args[1]):
tp = self.chk.type_map[expr]
Expand Down Expand Up @@ -3171,6 +3208,7 @@ def visit_yield_from_expr(self, e: YieldFromExpr, allow_none_return: bool = Fals
if isinstance(actual_item_type, AnyType):
expr_type = AnyType(TypeOfAny.from_another_any, source_any=actual_item_type)
else:
# Treat `Iterator[X]` as a shorthand for `Generator[X, None, Any]`.
expr_type = NoneTyp()

if not allow_none_return and isinstance(expr_type, NoneTyp):
Expand Down
10 changes: 5 additions & 5 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2017,7 +2017,7 @@ def analyze_name_lvalue(self,
not self.type)
if (add_global or nested_global) and lval.name not in self.globals:
# Define new global name.
v = self.make_name_lvalue_var(lval, GDEF)
v = self.make_name_lvalue_var(lval, GDEF, not explicit_type)
self.globals[lval.name] = SymbolTableNode(GDEF, v)
elif isinstance(lval.node, Var) and lval.is_new_def:
if lval.kind == GDEF:
Expand All @@ -2029,7 +2029,7 @@ def analyze_name_lvalue(self,
lval.name not in self.global_decls[-1] and
lval.name not in self.nonlocal_decls[-1]):
# Define new local name.
v = self.make_name_lvalue_var(lval, LDEF)
v = self.make_name_lvalue_var(lval, LDEF, not explicit_type)
self.add_local(v, lval)
if lval.name == '_':
# Special case for assignment to local named '_': always infer 'Any'.
Expand All @@ -2038,16 +2038,16 @@ def analyze_name_lvalue(self,
elif not self.is_func_scope() and (self.type and
lval.name not in self.type.names):
# Define a new attribute within class body.
v = self.make_name_lvalue_var(lval, MDEF)
v.is_inferred = not explicit_type
v = self.make_name_lvalue_var(lval, MDEF, not explicit_type)
self.type.names[lval.name] = SymbolTableNode(MDEF, v)
else:
self.make_name_lvalue_point_to_existing_def(lval, explicit_type, final_cb)

def make_name_lvalue_var(self, lvalue: NameExpr, kind: int) -> Var:
def make_name_lvalue_var(self, lvalue: NameExpr, kind: int, inferred: bool) -> Var:
"""Return a Var node for an lvalue that is a name expression."""
v = Var(lvalue.name)
v.set_line(lvalue)
v.is_inferred = inferred
if kind == MDEF:
assert self.type is not None
v.info = self.type
Expand Down
6 changes: 4 additions & 2 deletions test-data/unit/check-bound.test
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class C(Generic[T]):
return self.t
c1 = None # type: C[None]
c1.get()
d = c1.get() # E: "get" of "C" does not return a value
d = c1.get()
reveal_type(d) # E: Revealed type is 'None'


[case testBoundAny]
Expand All @@ -82,7 +83,8 @@ def f(g: Callable[[], T]) -> T:
return g()
def h() -> None: pass
f(h)
a = f(h) # E: "f" does not return a value
a = f(h)
reveal_type(a) # E: Revealed type is 'None'


[case testBoundInheritance]
Expand Down
24 changes: 24 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,27 @@ class A:
reveal_type(cls()) # E: Revealed type is 'T`-1'
return cls()
[builtins fixtures/classmethod.pyi]

[case testNoComplainFieldNone]
# flags: --python-version 3.6
# flags: --no-strict-optional
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Foo:
bar: Optional[int] = field(default=None)
[builtins fixtures/list.pyi]
[out]

[case testNoComplainFieldNoneStrict]
# flags: --python-version 3.6
# flags: --strict-optional
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Foo:
bar: Optional[int] = field(default=None)
[builtins fixtures/list.pyi]
[out]
1 change: 1 addition & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,7 @@ a, o = None, None # type: (A, object)
a = f() # E: "f" does not return a value
o = a() # E: Function does not return a value
o = A().g(a) # E: "g" of "A" does not return a value
o = A.g(a, a) # E: "g" of "A" does not return a value
A().g(f()) # E: "f" does not return a value
x: A = f() # E: "f" does not return a value
f()
Expand Down
61 changes: 59 additions & 2 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2246,8 +2246,7 @@ def h() -> Dict[Union[str, int], str]:

def i() -> List[Union[int, float]]:
x: List[int] = [1]
return x # E: Incompatible return value type (got "List[int]", expected "List[Union[int, float]]") \
# N: Perhaps you need a type annotation for "x"? Suggestion: "List[Union[int, float]]"
return x # E: Incompatible return value type (got "List[int]", expected "List[Union[int, float]]")

[builtins fixtures/dict.pyi]

Expand Down Expand Up @@ -2319,3 +2318,61 @@ main:6: error: b: builtins.int
main:6: error: c: builtins.int
main:9: error: Revealed local types are:
main:9: error: a: builtins.float

[case testNoComplainOverloadNone]
# flags: --no-strict-optional
from typing import overload, Optional
@overload
def bar(x: None) -> None:
...
@overload
def bar(x: int) -> str:
...
def bar(x: Optional[int]) -> Optional[str]:
if x is None:
return None
return "number"

reveal_type(bar(None)) # E: Revealed type is 'None'
[builtins fixtures/isinstance.pyi]
[out]

[case testNoComplainOverloadNoneStrict]
# flags: --strict-optional
from typing import overload, Optional
@overload
def bar(x: None) -> None:
...
@overload
def bar(x: int) -> str:
...
def bar(x: Optional[int]) -> Optional[str]:
if x is None:
return None
return "number"

reveal_type(bar(None)) # E: Revealed type is 'None'
[builtins fixtures/isinstance.pyi]
[out]

[case testNoComplainInferredNone]
# flags: --no-strict-optional
from typing import TypeVar, Optional
T = TypeVar('T')
def X(val: T) -> T: ...
x_in = None
def Y(x: Optional[str] = X(x_in)): ...

xx: Optional[int] = X(x_in)
[out]

[case testNoComplainInferredNoneStrict]
# flags: --strict-optional
from typing import TypeVar, Optional
T = TypeVar('T')
def X(val: T) -> T: ...
x_in = None
def Y(x: Optional[str] = X(x_in)): ...

xx: Optional[int] = X(x_in)
[out]
3 changes: 2 additions & 1 deletion test-data/unit/check-inference-context.test
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,7 @@ reveal_type(c.f(None)) # E: Revealed type is 'builtins.list[builtins.int]'
[builtins fixtures/list.pyi]

[case testGenericMethodCalledInGenericContext]
# flags: --strict-optional
from typing import TypeVar, Generic

_KT = TypeVar('_KT')
Expand All @@ -897,7 +898,7 @@ class M(Generic[_KT, _VT]):
def get(self, k: _KT, default: _T) -> _T: ...

def f(d: M[_KT, _VT], k: _KT) -> _VT:
return d.get(k, None) # E: "get" of "M" does not return a value
return d.get(k, None) # E: Incompatible return value type (got "None", expected "_VT")

[case testGenericMethodCalledInGenericContext2]
from typing import TypeVar, Generic, Union
Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/check-optional.test
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ reveal_type(z2) # E: Revealed type is 'Union[builtins.int, builtins.str, None]'

[case testLambdaReturningNone]
f = lambda: None
x = f() # E: Function does not return a value
x = f()
reveal_type(x) # E: Revealed type is 'None'

[case testNoneArgumentType]
def f(x: None) -> None: pass
Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,8 @@ def g(x: U, y: V) -> None:
# N: Possible overload variants: \
# N: def [T <: str] f(x: T) -> T \
# N: def [T <: str] f(x: List[T]) -> None
a = f([x]) # E: "f" does not return a value
a = f([x])
reveal_type(a) # E: Revealed type is 'None'
f([y]) # E: Value of type variable "T" of "f" cannot be "V"
f([x, y]) # E: Value of type variable "T" of "f" cannot be "object"
[builtins fixtures/list.pyi]
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -1238,8 +1238,8 @@ class P2(Protocol):
T = TypeVar('T')
def f(x: Callable[[T, T], None]) -> T: pass
def g(x: P, y: P2) -> None: pass
x = f(g) # E: "f" does not return a value

x = f(g)
reveal_type(x) # E: Revealed type is 'None'
[case testMeetProtocolWithNormal]
from typing import Protocol, Callable, TypeVar

Expand Down