@@ -1782,7 +1782,7 @@ def visit_try_stmt(self, s: TryStmt) -> Type:
1782
1782
for i in range (len (s .handlers )):
1783
1783
self .binder .push_frame ()
1784
1784
if s .types [i ]:
1785
- t = self .exception_type (s .types [i ])
1785
+ t = self .visit_except_handler_test (s .types [i ])
1786
1786
if s .vars [i ]:
1787
1787
# To support local variables, we make this a definition line,
1788
1788
# causing assignment to set the variable's type.
@@ -1822,38 +1822,30 @@ def visit_try_stmt(self, s: TryStmt) -> Type:
1822
1822
if s .finally_body :
1823
1823
self .accept (s .finally_body )
1824
1824
1825
- def exception_type (self , n : Node ) -> Type :
1826
- if isinstance (n , TupleExpr ):
1827
- t = None # type: Type
1828
- for item in n .items :
1829
- tt = self .exception_type (item )
1830
- if t :
1831
- t = join_types (t , tt )
1832
- else :
1833
- t = tt
1834
- return t
1835
- else :
1836
- # A single exception type; should evaluate to a type object type.
1837
- type = self .accept (n )
1838
- return self .check_exception_type (type , n )
1839
- self .fail ('Unsupported exception' , n )
1840
- return AnyType ()
1825
+ def visit_except_handler_test (self , n : Node ) -> Type :
1826
+ """Type check an exception handler test clause."""
1827
+ type = self .accept (n )
1828
+ if isinstance (type , AnyType ):
1829
+ return type
1841
1830
1842
- def check_exception_type (self , type : Type , context : Context ) -> Type :
1843
- if isinstance (type , FunctionLike ):
1844
- item = type .items ()[0 ]
1845
- ret = item .ret_type
1846
- if (is_subtype (ret , self .named_type ('builtins.BaseException' ))
1831
+ all_types = [] # type: List[Type]
1832
+ test_types = type .items if isinstance (type , TupleType ) else [type ]
1833
+
1834
+ for ttype in test_types :
1835
+ if not isinstance (ttype , FunctionLike ):
1836
+ self .fail (messages .INVALID_EXCEPTION_TYPE , n )
1837
+ return AnyType ()
1838
+
1839
+ item = ttype .items ()[0 ]
1840
+ ret_type = item .ret_type
1841
+ if not (is_subtype (ret_type , self .named_type ('builtins.BaseException' ))
1847
1842
and item .is_type_obj ()):
1848
- return ret
1849
- else :
1850
- self .fail (messages .INVALID_EXCEPTION_TYPE , context )
1843
+ self .fail (messages .INVALID_EXCEPTION_TYPE , n )
1851
1844
return AnyType ()
1852
- elif isinstance (type , AnyType ):
1853
- return AnyType ()
1854
- else :
1855
- self .fail (messages .INVALID_EXCEPTION_TYPE , context )
1856
- return AnyType ()
1845
+
1846
+ all_types .append (ret_type )
1847
+
1848
+ return UnionType .make_simplified_union (all_types )
1857
1849
1858
1850
def visit_for_stmt (self , s : ForStmt ) -> Type :
1859
1851
"""Type check a for statement."""
0 commit comments