Skip to content

Commit e1a934b

Browse files
committed
Store GadtConstraint in CaseDef instead of peepholing
1 parent 0d990e8 commit e1a934b

File tree

16 files changed

+137
-103
lines changed

16 files changed

+137
-103
lines changed

compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1852,7 +1852,7 @@ class JSCodeGen()(using genCtx: Context) {
18521852
js.Block(genStatsAndExpr)
18531853

18541854
case GadtExpr(_, expr) =>
1855-
genExpr(expr)
1855+
genStatOrExpr(expr, isStat)
18561856

18571857
case Typed(expr, _) =>
18581858
expr match {

compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,24 @@ class TreeTypeMap(
125125
transformInlined(inlined)
126126
case GadtExpr(gadt, expr) =>
127127
cpy.GadtExpr(expr)(gadt, transform(expr))
128-
case cdef @ CaseDef(pat, guard, expr @ GadtExpr(gadt, rhs)) =>
129-
val patVars1 = patVars(pat)
130-
val tmap = withMappedSyms(patVars1 ::: gadt.symbols.diff(substFrom).diff(patVars1))
131-
val gadt1 = tmap.rebuild(gadt)
132-
inContext(ctx.withGadt(gadt1)) {
133-
val pat1 = tmap.transform(pat)
134-
val guard1 = tmap.transform(guard)
135-
val rhs1 = cpy.GadtExpr(expr)(gadt1, tmap.transform(rhs))
136-
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
137-
}
138128
case cdef @ CaseDef(pat, guard, rhs) =>
139-
val tmap = withMappedSyms(patVars(pat))
140-
val pat1 = tmap.transform(pat)
141-
val guard1 = tmap.transform(guard)
142-
val rhs1 = tmap.transform(rhs)
143-
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
129+
val patVars1 = patVars(pat)
130+
cdef.gadt match
131+
case EmptyGadtConstraint =>
132+
val tmap = withMappedSyms(patVars1)
133+
val pat1 = tmap.transform(pat)
134+
val guard1 = tmap.transform(guard)
135+
val rhs1 = tmap.transform(rhs)
136+
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
137+
case _ =>
138+
val tmap = withMappedSyms(patVars1 ::: cdef.gadt.symbols.diff(substFrom).diff(patVars1))
139+
val gadt1 = tmap.rebuild(cdef.gadt)
140+
inContext(ctx.withGadt(gadt1)) {
141+
val pat1 = tmap.transform(pat)
142+
val guard1 = tmap.transform(guard)
143+
val rhs1 = tmap.transform(rhs)
144+
cpy.CaseDef(cdef)(pat1, guard1, rhs1, gadt1)
145+
}
144146
case labeled @ Labeled(bind, expr) =>
145147
val tmap = withMappedSyms(bind.symbol :: Nil)
146148
val bind1 = tmap.transformSub(bind)

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ object Trees {
614614
}
615615

616616
/** case pat if guard => body */
617-
case class CaseDef[-T >: Untyped] private[ast] (pat: Tree[T], guard: Tree[T], body: Tree[T])(implicit @constructorOnly src: SourceFile)
617+
case class CaseDef[-T >: Untyped] private[ast] (pat: Tree[T], guard: Tree[T], body: Tree[T])(val gadt: GadtConstraint)(implicit @constructorOnly src: SourceFile)
618618
extends Tree[T] {
619619
type ThisTree[-T >: Untyped] = CaseDef[T]
620620
}
@@ -1233,9 +1233,9 @@ object Trees {
12331233
case tree: InlineMatch => finalize(tree, untpd.InlineMatch(selector, cases)(sourceFile(tree)))
12341234
case _ => finalize(tree, untpd.Match(selector, cases)(sourceFile(tree)))
12351235
}
1236-
def CaseDef(tree: Tree)(pat: Tree, guard: Tree, body: Tree)(using Context): CaseDef = tree match {
1236+
def CaseDef(tree: Tree)(pat: Tree, guard: Tree, body: Tree, gadt: GadtConstraint)(using Context): CaseDef = tree match {
12371237
case tree: CaseDef if (pat eq tree.pat) && (guard eq tree.guard) && (body eq tree.body) => tree
1238-
case _ => finalize(tree, untpd.CaseDef(pat, guard, body)(sourceFile(tree)))
1238+
case _ => finalize(tree, untpd.CaseDef(pat, guard, body, gadt)(sourceFile(tree)))
12391239
}
12401240
def Labeled(tree: Tree)(bind: Bind, expr: Tree)(using Context): Labeled = tree match {
12411241
case tree: Labeled if (bind eq tree.bind) && (expr eq tree.expr) => tree
@@ -1355,8 +1355,8 @@ object Trees {
13551355
If(tree: Tree)(cond, thenp, elsep)
13561356
def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =
13571357
Closure(tree: Tree)(env, meth, tpt)
1358-
def CaseDef(tree: CaseDef)(pat: Tree = tree.pat, guard: Tree = tree.guard, body: Tree = tree.body)(using Context): CaseDef =
1359-
CaseDef(tree: Tree)(pat, guard, body)
1358+
def CaseDef(tree: CaseDef)(pat: Tree = tree.pat, guard: Tree = tree.guard, body: Tree = tree.body, gadt: GadtConstraint = tree.gadt)(using Context): CaseDef =
1359+
CaseDef(tree: Tree)(pat, guard, body, gadt)
13601360
def Try(tree: Try)(expr: Tree = tree.expr, cases: List[CaseDef] = tree.cases, finalizer: Tree = tree.finalizer)(using Context): Try =
13611361
Try(tree: Tree)(expr, cases, finalizer)
13621362
def UnApply(tree: UnApply)(fun: Tree = tree.fun, implicits: List[Tree] = tree.implicits, patterns: List[Tree] = tree.patterns)(using Context): UnApply =
@@ -1442,10 +1442,8 @@ object Trees {
14421442
cpy.Match(tree)(transform(selector), transformSub(cases))
14431443
case GadtExpr(gadt, expr) =>
14441444
inContext(ctx.withGadt(gadt))(cpy.GadtExpr(tree)(gadt, transform(expr)))
1445-
case CaseDef(pat, guard, body @ GadtExpr(gadt, _)) =>
1446-
inContext(ctx.withGadt(gadt))(cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body)))
1447-
case CaseDef(pat, guard, body) =>
1448-
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
1445+
case cdef @ CaseDef(pat, guard, body) =>
1446+
inContext(ctx.withGadt(cdef.gadt))(cpy.CaseDef(cdef)(transform(pat), transform(guard), transform(body)))
14491447
case Labeled(bind, expr) =>
14501448
cpy.Labeled(tree)(transformSub(bind), transform(expr))
14511449
case Return(expr, from) =>
@@ -1582,10 +1580,8 @@ object Trees {
15821580
this(this(x, selector), cases)
15831581
case GadtExpr(gadt, expr) =>
15841582
inContext(ctx.withGadt(gadt))(this(x, expr))
1585-
case CaseDef(pat, guard, body @ GadtExpr(gadt, _)) =>
1586-
inContext(ctx.withGadt(gadt))(this(this(this(x, pat), guard), body))
1587-
case CaseDef(pat, guard, body) =>
1588-
this(this(this(x, pat), guard), body)
1583+
case cdef @ CaseDef(pat, guard, body) =>
1584+
inContext(ctx.withGadt(cdef.gadt))(this(this(this(x, pat), guard), body))
15891585
case Labeled(bind, expr) =>
15901586
this(this(x, bind), expr)
15911587
case Return(expr, from) =>

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
131131
Closure(meth, tss => rhsFn(tss.head).changeOwner(ctx.owner, meth))
132132
}
133133

134-
def CaseDef(pat: Tree, guard: Tree, body: Tree)(using Context): CaseDef =
135-
ta.assignType(untpd.CaseDef(pat, guard, body), pat, body)
134+
def CaseDef(pat: Tree, guard: Tree, body: Tree, gadt: GadtConstraint = EmptyGadtConstraint)(using Context): CaseDef =
135+
ta.assignType(untpd.CaseDef(pat, guard, body, gadt), pat, body)
136136

137137
def Match(selector: Tree, cases: List[CaseDef])(using Context): Match =
138138
ta.assignType(untpd.Match(selector, cases), selector, cases)
@@ -713,8 +713,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
713713
}
714714
}
715715

716-
override def CaseDef(tree: Tree)(pat: Tree, guard: Tree, body: Tree)(using Context): CaseDef = {
717-
val tree1 = untpdCpy.CaseDef(tree)(pat, guard, body)
716+
override def CaseDef(tree: Tree)(pat: Tree, guard: Tree, body: Tree, gadt: GadtConstraint)(using Context): CaseDef = {
717+
val tree1 = untpdCpy.CaseDef(tree)(pat, guard, body, gadt)
718718
tree match {
719719
case tree: CaseDef if body.tpe eq tree.body.tpe => tree1.withTypeUnchecked(tree.tpe)
720720
case _ => ta.assignType(tree1, pat, body)
@@ -770,8 +770,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
770770
If(tree: Tree)(cond, thenp, elsep)
771771
override def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =
772772
Closure(tree: Tree)(env, meth, tpt)
773-
override def CaseDef(tree: CaseDef)(pat: Tree = tree.pat, guard: Tree = tree.guard, body: Tree = tree.body)(using Context): CaseDef =
774-
CaseDef(tree: Tree)(pat, guard, body)
773+
override def CaseDef(tree: CaseDef)(pat: Tree = tree.pat, guard: Tree = tree.guard, body: Tree = tree.body, gadt: GadtConstraint = tree.gadt)(using Context): CaseDef =
774+
CaseDef(tree: Tree)(pat, guard, body, gadt)
775775
override def Try(tree: Try)(expr: Tree = tree.expr, cases: List[CaseDef] = tree.cases, finalizer: Tree = tree.finalizer)(using Context): Try =
776776
Try(tree: Tree)(expr, cases, finalizer)
777777
}
@@ -1305,7 +1305,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
13051305
else if (tree.tpe.widen isRef numericCls)
13061306
tree
13071307
else {
1308-
report.warning(i"conversion from ${tree.tpe.widen} to ${numericCls.typeRef} will always fail at runtime.")
1308+
report.warning(i"conversion from ${tree.tpe.widen} to ${numericCls.typeRef} will always fail at runtime.", tree.srcPos)
13091309
Throw(New(defn.ClassCastExceptionClass.typeRef, Nil)).withSpan(tree.span)
13101310
}
13111311
}

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
388388
def Closure(env: List[Tree], meth: Tree, tpt: Tree)(implicit src: SourceFile): Closure = new Closure(env, meth, tpt)
389389
def Match(selector: Tree, cases: List[CaseDef])(implicit src: SourceFile): Match = new Match(selector, cases)
390390
def InlineMatch(selector: Tree, cases: List[CaseDef])(implicit src: SourceFile): Match = new InlineMatch(selector, cases)
391-
def CaseDef(pat: Tree, guard: Tree, body: Tree)(implicit src: SourceFile): CaseDef = new CaseDef(pat, guard, body)
391+
def CaseDef(pat: Tree, guard: Tree, body: Tree, gadt: GadtConstraint = EmptyGadtConstraint)(implicit src: SourceFile): CaseDef = new CaseDef(pat, guard, body)(gadt)
392392
def Labeled(bind: Bind, expr: Tree)(implicit src: SourceFile): Labeled = new Labeled(bind, expr)
393393
def Return(expr: Tree, from: Tree)(implicit src: SourceFile): Return = new Return(expr, from)
394394
def WhileDo(cond: Tree, body: Tree)(implicit src: SourceFile): WhileDo = new WhileDo(cond, body)

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ final class ProperGadtConstraint private(
305305
&& mapping == that.mapping
306306
&& reverseMapping == that.reverseMapping
307307
&& wasConstrained == that.wasConstrained
308-
case _ => false
308+
case _ => mapping.isEmpty
309309

310310
override def toText(printer: Printer): Texts.Text = printer.toText(this)
311311
}

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -472,21 +472,6 @@ class TreePickler(pickler: TastyPickler) {
472472
writeByte(BLOCK)
473473
stats.foreach(preRegister)
474474
withLength { pickleTree(expr); stats.foreach(pickleTree) }
475-
case GadtExpr(gadt, expr) =>
476-
writeByte(GADTEXPR)
477-
withLength {
478-
for (symbols, nestingLevel) <- gadt.inputs do
479-
writeByte(CONSTRAINT)
480-
withLength {
481-
writeInt(nestingLevel)
482-
for sym <- symbols do
483-
val TypeBounds(lo, hi) = gadt.fullBounds(sym).nn
484-
pickleSymRef(sym)
485-
pickleType(lo)
486-
pickleType(hi)
487-
}
488-
pickleTree(expr)
489-
}
490475
case tree @ If(cond, thenp, elsep) =>
491476
writeByte(IF)
492477
withLength {
@@ -511,9 +496,23 @@ class TreePickler(pickler: TastyPickler) {
511496
else pickleTree(selector)
512497
tree.cases.foreach(pickleTree)
513498
}
514-
case CaseDef(pat, guard, rhs) =>
499+
case tree @ CaseDef(pat, guard, rhs) =>
515500
writeByte(CASEDEF)
516-
withLength { pickleTree(pat); pickleTree(rhs); pickleTreeUnlessEmpty(guard) }
501+
withLength {
502+
for (symbols, nestingLevel) <- tree.gadt.inputs do
503+
writeByte(CONSTRAINT)
504+
withLength {
505+
writeInt(nestingLevel)
506+
for sym <- symbols do
507+
val TypeBounds(lo, hi) = tree.gadt.fullBounds(sym).nn
508+
pickleSymRef(sym)
509+
pickleType(lo)
510+
pickleType(hi)
511+
}
512+
pickleTree(pat)
513+
pickleTree(rhs)
514+
pickleTreeUnlessEmpty(guard)
515+
}
517516
case Return(expr, from) =>
518517
writeByte(RETURN)
519518
withLength { pickleSymRef(from.symbol); pickleTreeUnlessEmpty(expr) }

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,27 +1237,6 @@ class TreeUnpickler(reader: TastyReader,
12371237
skipTree()
12381238
readStats(ctx.owner, end,
12391239
(stats, ctx) => Block(stats, exprReader.readTerm()(using ctx)))
1240-
case GADTEXPR =>
1241-
val gadt = EmptyGadtConstraint.fresh
1242-
while nextByte == CONSTRAINT do
1243-
readByte()
1244-
val end = readEnd()
1245-
val nestingLevel = readInt()
1246-
val constraints = until(end)((readSymRef(), readType(), readType()))
1247-
gadt.addToConstraint(constraints.map(_._1), nestingLevel)
1248-
for (sym, lo, hi) <- constraints do
1249-
if (sym.typeRef <:< lo)(using ctx.withGadt(gadt)) then
1250-
// add in reverse order so that unification runs in the right direction (keep sym)
1251-
// for a counter-example: say the symbol is c: b and the bound is b
1252-
// if we add c >: b it will unify to b: c not c: b
1253-
gadt.addBound(sym, hi, isUpper = true)
1254-
gadt.addBound(sym, lo, isUpper = false)
1255-
else
1256-
gadt.addBound(sym, lo, isUpper = false)
1257-
gadt.addBound(sym, hi, isUpper = true)
1258-
end while
1259-
val expr = inContext(ctx.withGadt(gadt))(readTerm())
1260-
GadtExpr(gadt, expr)
12611240
case INLINED =>
12621241
val exprReader = fork
12631242
skipTree()
@@ -1455,10 +1434,32 @@ class TreeUnpickler(reader: TastyReader,
14551434
val start = currentAddr
14561435
assert(readByte() == CASEDEF)
14571436
val end = readEnd()
1458-
val pat = readTerm()
1459-
val rhs = readTerm()
1460-
val guard = ifBefore(end)(readTerm(), EmptyTree)
1461-
setSpan(start, CaseDef(pat, guard, rhs))
1437+
val originalCtx = ctx
1438+
val gadt = if nextByte == CONSTRAINT then EmptyGadtConstraint.fresh else originalCtx.gadt
1439+
while nextByte == CONSTRAINT do
1440+
readByte()
1441+
val end = readEnd()
1442+
val nestingLevel = readInt()
1443+
val constraints = until(end)((readSymRef(), readType(), readType()))
1444+
gadt.addToConstraint(constraints.map(_._1), nestingLevel)
1445+
for (sym, lo, hi) <- constraints do
1446+
if (sym.typeRef <:< lo)(using ctx.withGadt(gadt)) then
1447+
// add in reverse order so that unification runs in the right direction (keep sym)
1448+
// for a counter-example: say the symbol is c: b and the bound is b
1449+
// if we add c >: b it will unify to b: c not c: b
1450+
gadt.addBound(sym, hi, isUpper = true)
1451+
gadt.addBound(sym, lo, isUpper = false)
1452+
else
1453+
gadt.addBound(sym, lo, isUpper = false)
1454+
gadt.addBound(sym, hi, isUpper = true)
1455+
end while
1456+
inContext(ctx.withGadt(gadt)) {
1457+
val pat = readTerm()
1458+
val rhs = readTerm()
1459+
val guard = ifBefore(end)(readTerm(), EmptyTree)
1460+
val gadt1 = if gadt eq originalCtx.gadt then EmptyGadtConstraint else gadt
1461+
setSpan(start, CaseDef(pat, guard, rhs, gadt1))
1462+
}
14621463
}
14631464

14641465
def readLater[T <: AnyRef](end: Addr, op: TreeReader => Context ?=> T)(using Context): Trees.Lazy[T] =

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,8 +494,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
494494
else toText(sel)
495495
selTxt ~ keywordStr(" match ") ~ blockText(cases)
496496
}
497-
case CaseDef(pat, guard, body) =>
498-
keywordStr("case ") ~ inPattern(toText(pat)) ~ optText(guard)(keywordStr(" if ") ~ _) ~ " => " ~ caseBlockText(body)
497+
case cdef @ CaseDef(pat, guard, body) =>
498+
keywordStr("case ") ~ inPattern(toText(pat))
499+
~ (" ~ " ~ toText(cdef.gadt)).provided(cdef.gadt != EmptyGadtConstraint)
500+
~ optText(guard)(keywordStr(" if ") ~ _)
501+
~ " => " ~ caseBlockText(body)
499502
case Labeled(bind, expr) =>
500503
changePrec(GlobalPrec) { toText(bind.name) ~ keywordStr("[") ~ toText(bind.symbol.info) ~ keywordStr("]: ") ~ toText(expr) }
501504
case Return(expr, from) =>

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,17 @@ object Erasure {
816816
}
817817
}
818818

819+
override def typedGadtExpr(tree: untpd.GadtExpr, pt: Type)(using Context): GadtExpr = tree match {
820+
case GadtExpr(gadt, expr @ If(cond, _, _)) =>
821+
// type the condition without installing the gadt constraints
822+
// so that TypeTestsCasts can correctly check type tests
823+
val cond1 = typed(cond, defn.BooleanType)
824+
val expr1 = cpy.If(expr.withType(expr.tpe))(cond = cond1)
825+
val gadt1 = cpy.GadtExpr(tree.withType(tree.tpe))(gadt, expr1)
826+
super.typedGadtExpr(gadt1, pt)
827+
case _ => super.typedGadtExpr(tree, pt)
828+
}
829+
819830
/** Besides normal typing, this method does uncurrying and collects parameters
820831
* to anonymous functions of arity > 22.
821832
*/

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ object PatternMatcher {
147147
sealed abstract class Plan { val id: Int = nxId; nxId += 1 }
148148

149149
case class TestPlan(test: Test, var scrutinee: Tree, span: Span,
150-
var onSuccess: Plan) extends Plan {
150+
var onSuccess: Plan)(var gadt: GadtConstraint) extends Plan {
151151
override def equals(that: Any): Boolean = that match {
152152
case that: TestPlan => this.scrutinee === that.scrutinee && this.test == that.test
153153
case _ => false
@@ -164,6 +164,8 @@ object PatternMatcher {
164164
object TestPlan {
165165
def apply(test: Test, sym: Symbol, span: Span, ons: Plan): TestPlan =
166166
TestPlan(test, ref(sym), span, ons)
167+
def apply(test: Test, scr: Tree, span: Span, ons: Plan): TestPlan =
168+
TestPlan(test, scr, span, ons)(EmptyGadtConstraint)
167169
}
168170

169171
/** The different kinds of tests */
@@ -453,7 +455,11 @@ object PatternMatcher {
453455
var onSuccess: Plan = ResultPlan(cdef.body)
454456
if (!cdef.guard.isEmpty)
455457
onSuccess = TestPlan(GuardTest, cdef.guard, cdef.guard.span, onSuccess)
456-
patternPlan(scrutinee, cdef.pat, onSuccess)
458+
patternPlan(scrutinee, cdef.pat, onSuccess) match
459+
case plan: TestPlan =>
460+
plan.gadt = cdef.gadt
461+
plan
462+
case plan => plan
457463
}
458464

459465
private def matchPlan(tree: Match): Plan =
@@ -930,7 +936,8 @@ object PatternMatcher {
930936
If(conditions, emit(plan.onSuccess), unitLiteral)
931937
}
932938
}
933-
emitWithMashedConditions(plan :: Nil)
939+
val tree = emitWithMashedConditions(plan :: Nil)
940+
if plan.gadt == EmptyGadtConstraint then tree else GadtExpr(plan.gadt, tree)
934941
case LetPlan(sym, body) =>
935942
val valDef = ValDef(sym, initializer(sym).ensureConforms(sym.info), inferred = true).withSpan(sym.span)
936943
seq(valDef :: Nil, emit(body))

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,11 @@ abstract class Recheck extends Phase, SymTransformer:
266266
TypeComparer.lub(casesTypes)
267267

268268
def recheckCase(tree: CaseDef, selType: Type, pt: Type)(using Context): Type =
269-
recheck(tree.pat, selType)
270-
recheck(tree.guard, defn.BooleanType)
271-
recheck(tree.body, pt)
269+
inContext(ctx.withGadt(if tree.gadt == EmptyGadtConstraint then ctx.gadt else tree.gadt)) {
270+
recheck(tree.pat, selType)
271+
recheck(tree.guard, defn.BooleanType)
272+
recheck(tree.body, pt)
273+
}
272274

273275
def recheckReturn(tree: Return)(using Context): Type =
274276
// Avoid local pattern defined symbols in returns from matchResult blocks

0 commit comments

Comments
 (0)