Skip to content

Commit 0d77e13

Browse files
committed
Implement GadtExpr
1 parent 36c66e9 commit 0d77e13

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+431
-201
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ object desugar {
336336
// Propagate down the expected type to the leafs of the expression
337337
case Block(stats, expr) =>
338338
cpy.Block(tree)(stats, adaptToExpectedTpt(expr))
339+
case GadtExpr(gadt, expr) =>
340+
cpy.GadtExpr(tree)(gadt, adaptToExpectedTpt(expr))
339341
case If(cond, thenp, elsep) =>
340342
cpy.If(tree)(cond, adaptToExpectedTpt(thenp), adaptToExpectedTpt(elsep))
341343
case untpd.Parens(expr) =>
@@ -1630,6 +1632,7 @@ object desugar {
16301632
case Tuple(trees) => (pats corresponds trees)(isIrrefutable)
16311633
case Parens(rhs1) => matchesTuple(pats, rhs1)
16321634
case Block(_, rhs1) => matchesTuple(pats, rhs1)
1635+
case GadtExpr(_, rhs1) => matchesTuple(pats, rhs1)
16331636
case If(_, thenp, elsep) => matchesTuple(pats, thenp) && matchesTuple(pats, elsep)
16341637
case Match(_, cases) => cases forall (matchesTuple(pats, _))
16351638
case CaseDef(_, _, rhs1) => matchesTuple(pats, rhs1)

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
309309
case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p)
310310
case Match(_, cases) => cases forall (c => forallResults(c.body, p))
311311
case Block(_, expr) => forallResults(expr, p)
312+
case GadtExpr(_, expr) => forallResults(expr, p)
312313
case _ => p(tree)
313314
}
314315
}
@@ -1039,6 +1040,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
10391040
case Typed(expr, _) => unapply(expr)
10401041
case Inlined(_, Nil, expr) => unapply(expr)
10411042
case Block(Nil, expr) => unapply(expr)
1043+
case GadtExpr(_, expr) => unapply(expr)
10421044
case _ =>
10431045
tree.tpe.widenTermRefExpr.normalized match
10441046
case ConstantType(Constant(x)) => Some(x)

compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,20 @@ class TreeTypeMap(
123123
cpy.Block(blk)(stats1, expr1)
124124
case inlined: Inlined =>
125125
transformInlined(inlined)
126+
case GadtExpr(gadt, expr) =>
127+
val tmap = withMappedSyms(gadt.symbols.diff(substFrom))
128+
val gadt1 = tmap.rebuild(gadt)
129+
inContext(ctx.withGadt(gadt1))(cpy.GadtExpr(expr)(gadt1, tmap.transform(expr)))
130+
case cdef @ CaseDef(pat, guard, expr @ GadtExpr(gadt, rhs)) =>
131+
val patVars1 = patVars(pat)
132+
val tmap = withMappedSyms(patVars1 ::: gadt.symbols.diff(substFrom).diff(patVars1))
133+
val gadt1 = tmap.rebuild(gadt)
134+
inContext(ctx.withGadt(gadt1)) {
135+
val pat1 = tmap.transform(pat)
136+
val guard1 = tmap.transform(guard)
137+
val rhs1 = cpy.GadtExpr(expr)(gadt1, tmap.transform(rhs))
138+
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
139+
}
126140
case cdef @ CaseDef(pat, guard, rhs) =>
127141
val tmap = withMappedSyms(patVars(pat))
128142
val pat1 = tmap.transform(pat)
@@ -146,6 +160,29 @@ class TreeTypeMap(
146160
}
147161
}
148162

163+
private def rebuild(gadt: GadtConstraint)(using Context): GadtConstraint =
164+
val constraints = for sym <- gadt.symbols yield
165+
val TypeBounds(lo, hi) = gadt.fullBounds(sym).nn
166+
(sym, lo, hi)
167+
val constraints1 = constraints.mapConserve { triple =>
168+
val (sym, lo, hi) = triple
169+
val sym1 = mapOwner(sym)
170+
val lo1 = mapType(lo)
171+
val hi1 = mapType(hi)
172+
if (sym eq sym1) && (lo eq lo1) && (hi eq hi1)
173+
then triple
174+
else (sym1, lo1, hi1)
175+
}
176+
if constraints eq constraints1 then
177+
gadt
178+
else
179+
val gadt = EmptyGadtConstraint.fresh
180+
for (sym, lo, hi) <- constraints1 do
181+
gadt.addToConstraint(sym)
182+
gadt.addBound(sym, lo, false)
183+
gadt.addBound(sym, hi, true)
184+
gadt
185+
149186
override def transformStats(trees: List[tpd.Tree], exprOwner: Symbol)(using Context): List[Tree] =
150187
transformDefs(trees)._2
151188

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,12 @@ object Trees {
569569
override def isTerm: Boolean = !isType // this will classify empty trees as terms, which is necessary
570570
}
571571

572+
case class GadtExpr[-T >: Untyped] private[ast] (gadt: GadtConstraint, expr: Tree[T])(implicit @constructorOnly src: SourceFile)
573+
extends ProxyTree[T] {
574+
type ThisTree[-T >: Untyped] <: GadtExpr[T]
575+
def forwardTo: Tree[T] = expr
576+
}
577+
572578
/** if cond then thenp else elsep */
573579
case class If[-T >: Untyped] private[ast] (cond: Tree[T], thenp: Tree[T], elsep: Tree[T])(implicit @constructorOnly src: SourceFile)
574580
extends TermTree[T] {
@@ -1071,6 +1077,7 @@ object Trees {
10711077
type NamedArg = Trees.NamedArg[T]
10721078
type Assign = Trees.Assign[T]
10731079
type Block = Trees.Block[T]
1080+
type GadtExpr = Trees.GadtExpr[T]
10741081
type If = Trees.If[T]
10751082
type InlineIf = Trees.InlineIf[T]
10761083
type Closure = Trees.Closure[T]
@@ -1209,6 +1216,9 @@ object Trees {
12091216
case tree: Block if (stats eq tree.stats) && (expr eq tree.expr) => tree
12101217
case _ => finalize(tree, untpd.Block(stats, expr)(sourceFile(tree)))
12111218
}
1219+
def GadtExpr(tree: Tree)(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr = tree match
1220+
case tree: GadtExpr if (gadt eq tree.gadt) && (expr eq tree.expr) => tree
1221+
case _ => finalize(tree, untpd.GadtExpr(gadt, expr)(sourceFile(tree)))
12121222
def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = tree match {
12131223
case tree: If if (cond eq tree.cond) && (thenp eq tree.thenp) && (elsep eq tree.elsep) => tree
12141224
case tree: InlineIf => finalize(tree, untpd.InlineIf(cond, thenp, elsep)(sourceFile(tree)))
@@ -1430,6 +1440,10 @@ object Trees {
14301440
cpy.Closure(tree)(transform(env), transform(meth), transform(tpt))
14311441
case Match(selector, cases) =>
14321442
cpy.Match(tree)(transform(selector), transformSub(cases))
1443+
case GadtExpr(gadt, expr) =>
1444+
inContext(ctx.withGadt(gadt))(cpy.GadtExpr(tree)(gadt, transform(expr)))
1445+
case CaseDef(pat, guard, body @ GadtExpr(gadt, _)) =>
1446+
inContext(ctx.withGadt(gadt))(cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body)))
14331447
case CaseDef(pat, guard, body) =>
14341448
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
14351449
case Labeled(bind, expr) =>
@@ -1566,6 +1580,10 @@ object Trees {
15661580
this(this(this(x, env), meth), tpt)
15671581
case Match(selector, cases) =>
15681582
this(this(x, selector), cases)
1583+
case GadtExpr(gadt, expr) =>
1584+
inContext(ctx.withGadt(gadt))(this(x, expr))
1585+
case CaseDef(pat, guard, body @ GadtExpr(gadt, _)) =>
1586+
inContext(ctx.withGadt(gadt))(this(this(this(x, pat), guard), body))
15691587
case CaseDef(pat, guard, body) =>
15701588
this(this(this(x, pat), guard), body)
15711589
case Labeled(bind, expr) =>

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
9292
Block(stats, expr)
9393
}
9494

95+
def GadtExpr(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr =
96+
ta.assignType(untpd.GadtExpr(gadt, expr), gadt, expr)
97+
9598
def If(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If =
9699
ta.assignType(untpd.If(cond, thenp, elsep), thenp, elsep)
97100

@@ -673,6 +676,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
673676
}
674677
}
675678

679+
override def GadtExpr(tree: Tree)(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr =
680+
val tree1 = untpdCpy.GadtExpr(tree)(gadt, expr)
681+
tree match
682+
case tree: GadtExpr if expr.tpe eq tree.expr.tpe => tree1.withTypeUnchecked(tree.tpe)
683+
case _ => ta.assignType(tree1, gadt, expr)
684+
676685
override def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = {
677686
val tree1 = untpdCpy.If(tree)(cond, thenp, elsep)
678687
tree match {

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
382382
def NamedArg(name: Name, arg: Tree)(implicit src: SourceFile): NamedArg = new NamedArg(name, arg)
383383
def Assign(lhs: Tree, rhs: Tree)(implicit src: SourceFile): Assign = new Assign(lhs, rhs)
384384
def Block(stats: List[Tree], expr: Tree)(implicit src: SourceFile): Block = new Block(stats, expr)
385+
def GadtExpr(gadt: GadtConstraint, expr: Tree)(implicit src: SourceFile): GadtExpr = new GadtExpr(gadt, expr)
385386
def If(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new If(cond, thenp, elsep)
386387
def InlineIf(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new InlineIf(cond, thenp, elsep)
387388
def Closure(env: List[Tree], meth: Tree, tpt: Tree)(implicit src: SourceFile): Closure = new Closure(env, meth, tpt)

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,9 @@ object Contexts {
539539
case None => fresh.dropProperty(key)
540540
}
541541

542+
final def withGadt(gadt: GadtConstraint): Context =
543+
if this.gadt eq gadt then this else fresh.setGadt(gadt)
544+
542545
def typer: Typer = this.typeAssigner match {
543546
case typer: Typer => typer
544547
case _ => new Typer

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ sealed abstract class GadtConstraint extends Showable {
3131
*
3232
* @see [[ConstraintHandling.addToConstraint]]
3333
*/
34-
def addToConstraint(syms: List[Symbol])(using Context): Boolean
34+
def addToConstraint(syms: List[Symbol], nestingLevel: Int)(using Context): Boolean
35+
def addToConstraint(syms: List[Symbol])(using Context): Boolean = addToConstraint(syms, ctx.nestingLevel)
3536
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)
3637

3738
/** Further constrain a symbol already present in the constraint. */
@@ -49,14 +50,17 @@ sealed abstract class GadtConstraint extends Showable {
4950
/** See [[ConstraintHandling.approximation]] */
5051
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type
5152

53+
def remove(sym: Symbol)(using Context): Unit
54+
5255
def symbols: List[Symbol]
56+
def inputs: List[(List[Symbol], Int)]
5357

5458
def fresh: GadtConstraint
5559

5660
/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
5761
def restore(other: GadtConstraint): Unit
5862

59-
def debugBoundsDescription(using Context): String
63+
def eql(that: GadtConstraint): Boolean
6064
}
6165

6266
final class ProperGadtConstraint private(
@@ -88,7 +92,7 @@ final class ProperGadtConstraint private(
8892
// the case where they're valid, so no approximating is needed.
8993
rawBound
9094

91-
override def addToConstraint(params: List[Symbol])(using Context): Boolean = {
95+
override def addToConstraint(params: List[Symbol], nestingLevel: Int)(using Context): Boolean = {
9296
import NameKinds.DepParamName
9397

9498
val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })(
@@ -126,15 +130,15 @@ final class ProperGadtConstraint private(
126130
)
127131

128132
val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) =>
129-
val tv = TypeVar(paramRef, creatorState = null)
133+
val tv = TypeVar(paramRef, creatorState = null, nestingLevel)
130134
mapping = mapping.updated(sym, tv)
131135
reverseMapping = reverseMapping.updated(tv.origin, sym)
132136
tv
133137
}
134138

135139
// The replaced symbols are picked up here.
136140
addToConstraint(poly1, tvars)
137-
.showing(i"added to constraint: [$poly1] $params%, %\n$debugBoundsDescription", gadts)
141+
.showing(i"added to constraint: [$poly1] $params%, % gadt = $this", gadts)
138142
}
139143

140144
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = {
@@ -219,8 +223,22 @@ final class ProperGadtConstraint private(
219223
res
220224
}
221225

226+
override def remove(sym: Symbol)(using Context): Unit =
227+
mapping(sym) match
228+
case tv: TypeVar =>
229+
mapping = mapping.remove(sym)
230+
reverseMapping = reverseMapping.remove(tv.origin)
231+
constraint = constraint.replace(tv.origin, sym.typeRef)
232+
case null =>
233+
222234
override def symbols: List[Symbol] = mapping.keys
223235

236+
override def inputs: List[(List[Symbol], Int)] =
237+
constraint.domainLambdas.flatMap { tl =>
238+
val syms = tl.paramRefs.flatMap(reverseMapping(_).toOption)
239+
syms.headOption.map(sym1 => (syms, mapping(sym1).nn.initNestingLevel))
240+
}
241+
224242
override def fresh: GadtConstraint = new ProperGadtConstraint(
225243
myConstraint,
226244
mapping,
@@ -291,17 +309,15 @@ final class ProperGadtConstraint private(
291309

292310
override def constr = gadtsConstr
293311

294-
override def toText(printer: Printer): Texts.Text = constraint.toText(printer)
312+
override def eql(that: GadtConstraint): Boolean = (this eq that) || that.match
313+
case that: ProperGadtConstraint =>
314+
myConstraint == that.myConstraint
315+
&& mapping == that.mapping
316+
&& reverseMapping == that.reverseMapping
317+
&& wasConstrained == that.wasConstrained
318+
case _ => false
295319

296-
override def debugBoundsDescription(using Context): String = {
297-
val sb = new mutable.StringBuilder
298-
sb ++= constraint.show
299-
sb += '\n'
300-
mapping.foreachBinding { case (sym, _) =>
301-
sb ++= i"$sym: ${fullBounds(sym)}\n"
302-
}
303-
sb.result
304-
}
320+
override def toText(printer: Printer): Texts.Text = printer.toText(this)
305321
}
306322

307323
@sharable object EmptyGadtConstraint extends GadtConstraint {
@@ -314,18 +330,21 @@ final class ProperGadtConstraint private(
314330

315331
override def contains(sym: Symbol)(using Context) = false
316332

317-
override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
333+
override def addToConstraint(params: List[Symbol], nestingLevel: Int)(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
318334
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")
319335

320336
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
321337

338+
override def remove(sym: Symbol)(using Context): Unit = ()
339+
322340
override def symbols: List[Symbol] = Nil
341+
override def inputs: List[(List[Symbol], Int)] = Nil
323342

324343
override def fresh = new ProperGadtConstraint
325344
override def restore(other: GadtConstraint): Unit =
326345
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")
327346

328-
override def debugBoundsDescription(using Context): String = "EmptyGadtConstraint"
347+
override def eql(that: GadtConstraint): Boolean = (this eq that) || that == EmptyGadtConstraint
329348

330-
override def toText(printer: Printer): Texts.Text = "EmptyGadtConstraint"
349+
override def toText(printer: Printer): Texts.Text = printer.toText(this)
331350
}

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
261261
val assumeInvariantRefinement =
262262
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
263263

264-
trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
264+
trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res gadt = ${ctx.gadt}") {
265265
(tp, pt) match {
266266
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
267267
val saved = state.nn.constraint

0 commit comments

Comments
 (0)