Skip to content

Adding full support of path-dependent GADT reasoning #13475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
876910f
Change to use TypeRef to represent type parameters
Linyxus Aug 22, 2021
cd9d872
Add function to retrieve all type member symbols
Linyxus Aug 22, 2021
8ceed62
Add GADT logic for handling general typerefs
Linyxus Aug 23, 2021
e9150dc
Implement GADT type inference for path-dependent types
Linyxus Aug 23, 2021
77a879f
Support path-dependent GADT reasoning for upper bounds
Linyxus Aug 23, 2021
f060c9b
Fix type comparison triggered by onGadtBounds
Linyxus Aug 24, 2021
0f52067
Add footprint tracking for TypeRefs
Linyxus Aug 31, 2021
c1b4b8c
Adapt addLess for TypeRefs
Linyxus Sep 1, 2021
de5c91f
Register bound if it is a PDT
Linyxus Sep 1, 2021
0855528
Add basic testcase for path-dependent GADT
Linyxus Sep 2, 2021
5b76853
Update subsumption check considering just-in-time registration
Linyxus Sep 2, 2021
ee3febe
Add pos and neg tests for necessaryEither
Linyxus Sep 2, 2021
21eabc1
Support recording scrutinee path in GadtConstraint
Linyxus Sep 2, 2021
e26ed48
Support type member reasoning
Linyxus Sep 2, 2021
5a6d471
Stage work state
Linyxus Sep 3, 2021
f88b90b
Clear trace
Linyxus Sep 4, 2021
7b0d5f7
Support structural type member reasoning
Linyxus Sep 4, 2021
ebfcb0e
Remove useless code handling unbound patterns
Linyxus Sep 4, 2021
d90657d
Merge branch 'feature/refactor-pdgadt-typeref' into feature/refactor-…
Linyxus Sep 4, 2021
358e1e3
Fix merge errors
Linyxus Sep 4, 2021
72aa934
Fix NoType bounds
Linyxus Sep 4, 2021
c249890
Retrieve type members denotations from singleton type for correct bounds
Linyxus Sep 4, 2021
6d8e66e
Fix path-dependent type interdependency handling
Linyxus Sep 5, 2021
cd368fa
Refactor type member reasoning
Linyxus Sep 5, 2021
ff2922c
Refactor path-dependent structural GADT reasoning
Linyxus Sep 5, 2021
c90bc54
Add pos test for structural pdgadt
Linyxus Sep 5, 2021
4665bd2
Remove test i2941
Linyxus Sep 5, 2021
0ed63f6
Add negative tests
Linyxus Sep 6, 2021
78d0374
Fix issues
Linyxus Sep 6, 2021
c357d42
Remove unused scrutinee state
Linyxus Sep 6, 2021
e1e7d66
Tweak docstring
Linyxus Sep 6, 2021
18930cc
Revert "Remove test i2941"
Linyxus Sep 28, 2021
d317862
Strip lazyref in GADT bounds
Linyxus Oct 3, 2021
93e70e0
Directly use scrutinee and pattern paths when available
Linyxus Oct 3, 2021
9519f8a
Renaming for readability
Linyxus Oct 3, 2021
ed06974
Tweak the place for triggering path-dependent GADT reasoning
Linyxus Oct 3, 2021
0b40c94
Update tests
Linyxus Oct 7, 2021
1d2a7fa
Update neg test
Linyxus Oct 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
334 changes: 309 additions & 25 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -282,4 +282,44 @@ trait PatternTypeConstrainer { self: TypeComparer =>
}
}
}

/** Derive GADT bounds on type members of the scrutinee and the pattern. */
def constrainTypeMembers(scrut: Type, pat: Type, realScrutPath: TermRef, realPatPath: TermRef): Boolean = trace(i"constraining type members $scrut >:< $pat", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
val saved = state.constraint
val savedGadt = ctx.gadt.fresh

val scrutPath = if realScrutPath eq null then SkolemType(scrut) else realScrutPath
val patPath = if realPatPath eq null then SkolemType(pat) else realPatPath

val scrutPDTs = ctx.gadt.addAllPDTsFrom(scrutPath)
val patPDTs = ctx.gadt.addAllPDTsFrom(patPath)

if scrutPDTs.eq(null) || patPDTs.eq(null) then
ctx.gadt.restore(savedGadt)
return true

val scrutSyms = Map.from {
scrutPDTs map { pdt => pdt.symbol.name -> pdt }
}
val patSyms = Map.from {
patPDTs map { pdt => pdt.symbol.name -> pdt }
}

val shared = scrutSyms.keySet intersect patSyms.keySet

val result = shared forall { name =>
val tprS = scrutSyms(name)
val tprP = patSyms(name)
isSubType(tprS, tprP) && isSubType(tprP, tprS)
}

if !result then
constraint = saved
ctx.gadt.restore(savedGadt)
// else
// if realScrutPath ne null then ctx.gadt.replacePath(scrutPath, realScrutPath)
// if realPatPath ne null then ctx.gadt.replacePath(patPath, realPatPath)

result
}
}
89 changes: 69 additions & 20 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
protected def gadtBounds(sym: Symbol)(using Context) = ctx.gadt.bounds(sym)
protected def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = false)
protected def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = true)
protected def gadtBounds(tpr: TypeRef)(using Context) = ctx.gadt.bounds(tpr)
protected def gadtAddLowerBound(tpr: TypeRef, b: Type): Boolean = ctx.gadt.addBound(tpr, b, isUpper = false)
protected def gadtAddUpperBound(tpr: TypeRef, b: Type): Boolean = ctx.gadt.addBound(tpr, b, isUpper = true)

protected def typeVarInstance(tvar: TypeVar)(using Context): Type = tvar.underlying

Expand Down Expand Up @@ -179,6 +182,23 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
private inline def inFrozenGadtAndConstraint[T](inline op: T): T =
inFrozenGadtIf(true)(inFrozenConstraint(op))

private def canRegisterPDT: Boolean =
ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint

private def tryRegisterPDT(tpr: TypeRef): Boolean =
canRegisterPDT
&& ctx.gadt.isConstrainablePDT(tpr)
&& ctx.gadt.addPDT(tpr)

extension (tpr: TypeRef)
private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean =
val gbounds = gadtBounds(tpr) match
case null =>
if tryRegisterPDT(tpr) then gadtBounds(tpr) else null
case gbounds => gbounds

gbounds != null && op(gbounds)

extension (sym: Symbol)
private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean =
val bounds = gadtBounds(sym)
Expand Down Expand Up @@ -500,20 +520,23 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match {
case info2: TypeBounds =>
def compareGADT: Boolean =
tp2.symbol.onGadtBounds(gbounds2 =>
isSubTypeWhenFrozen(tp1, gbounds2.lo)
|| tp1.match
case tp1: NamedType if ctx.gadt.contains(tp1.symbol) =>
// Note: since we approximate constrained types only with their non-param bounds,
// we need to manually handle the case when we're comparing two constrained types,
// one of which is constrained to be a subtype of another.
// We do not need similar code in fourthTry, since we only need to care about
// comparing two constrained types, and that case will be handled here first.
ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol)
case _ => false
|| narrowGADTBounds(tp2, tp1, approx, isUpper = false))
&& (isBottom(tp1) || GADTusage(tp2.symbol))
def compareGADT: Boolean = tp2 match
case tp2: TypeRef =>
tp2.onGadtBounds(gbounds2 =>
isSubTypeWhenFrozen(tp1, gbounds2.lo)
|| tp1.match
case tp1: TypeRef if ctx.gadt.contains(tp1) =>
// Note: since we approximate constrained types only with their non-param bounds,
// we need to manually handle the case when we're comparing two constrained types,
// one of which is constrained to be a subtype of another.
// We do not need similar code in fourthTry, since we only need to care about
// comparing two constrained types, and that case will be handled here first.
ctx.gadt.isLess(tp1, tp2) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol)
case _ => false
|| narrowGADTBounds(tp2, tp1, approx, isUpper = false)
) && (isBottom(tp1) || GADTusage(tp2.symbol))
case _ => false
end compareGADT

isSubApproxHi(tp1, info2.lo) || compareGADT || tryLiftedToThis2 || fourthTry

Expand Down Expand Up @@ -765,7 +788,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
tp1.info match {
case TypeBounds(_, hi1) =>
def compareGADT =
tp1.symbol.onGadtBounds(gbounds1 =>
tp1.onGadtBounds(gbounds1 =>
isSubTypeWhenFrozen(gbounds1.hi, tp2)
|| narrowGADTBounds(tp1, tp2, approx, isUpper = true))
&& (tp2.isAny || GADTusage(tp1.symbol))
Expand Down Expand Up @@ -1878,14 +1901,22 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
* `bound` as an upper or lower bound (which depends on `isUpper`).
* Test that the resulting bounds are still satisfiable.
*/
private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = {
private def narrowGADTBounds(tr: TypeRef, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = {
val boundImprecise = approx.high || approx.low
ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && {
val tparam = tr.symbol
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
if (bound.isRef(tparam)) false
else if (isUpper) gadtAddUpperBound(tparam, bound)
else gadtAddLowerBound(tparam, bound)
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam) && !ctx.gadt.isConstrainablePDT(bound)}")

def registerPDTBound(): Boolean = bound match
case bound: TypeRef =>
ctx.gadt.isConstrainablePDT(bound) && !ctx.gadt.contains(bound) && tryRegisterPDT(bound)
case _ => false

registerPDTBound()

if (bound.isRef(tparam) && !ctx.gadt.isConstrainablePDT(bound)) false
else if (isUpper) gadtAddUpperBound(tr, bound)
else gadtAddLowerBound(tr, bound)
}
}

Expand Down Expand Up @@ -2800,6 +2831,9 @@ object TypeComparer {
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean =
comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))

def constrainTypeMembers(scrut: Type, pat: Type, scrutPath: TermRef, patPath: TermRef)(using Context): Boolean =
comparing(_.constrainTypeMembers(scrut, pat, scrutPath, patPath))

def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:")(using Context): String =
comparing(_.explained(op, header))

Expand Down Expand Up @@ -2829,16 +2863,31 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
super.gadtBounds(sym)
}

override def gadtBounds(tpr: TypeRef)(using Context): TypeBounds = {
if (tpr.exists) footprint += tpr
super.gadtBounds(tpr)
}

override def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = {
if (sym.exists) footprint += sym.typeRef
super.gadtAddLowerBound(sym, b)
}

override def gadtAddLowerBound(tpr: TypeRef, b: Type): Boolean = {
if (tpr.exists) footprint += tpr
super.gadtAddLowerBound(tpr, b)
}

override def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = {
if (sym.exists) footprint += sym.typeRef
super.gadtAddUpperBound(sym, b)
}

override def gadtAddUpperBound(tpr: TypeRef, b: Type): Boolean = {
if (tpr.exists) footprint += tpr
super.gadtAddUpperBound(tpr, b)
}

override def typeVarInstance(tvar: TypeVar)(using Context): Type = {
footprint += tvar
super.typeVarInstance(tvar)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -818,4 +818,7 @@ object TypeOps:
def nestedPairs(ts: List[Type])(using Context): Type =
ts.foldRight(defn.EmptyTupleModule.termRef: Type)(defn.PairClass.typeRef.appliedTo(_, _))

def abstractTypeMemberSymbols(tp: Type)(using Context): List[Symbol] =
tp.abstractTypeMembers.toList map (_.symbol)

end TypeOps
18 changes: 18 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1633,6 +1633,23 @@ class Typer extends Namer
}
val pat1 = indexPattern(tree).transform(pat)
val guard1 = typedExpr(tree.guard, defn.BooleanType)

// Trigger path-dependent GADT reasoning
val scrutType = sel.tpe.widen
val patType = pat1.tpe.widen match
case AndType(tp1, tp2) => tp2
case tp => tp
val scrutPath = sel.tpe match
case tp: TermRef => tp
case _ => null
val patPath = pat1.tpe match
case tp: TermRef => tp
case _ => null

withMode(Mode.GadtConstraintInference)
(TypeComparer.constrainTypeMembers(scrutType, patType, scrutPath, patPath))
(using gadtCtx)

var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1), pt1, ctx.scope.toList)
if ctx.gadt.nonEmpty then
// Store GADT constraint to later retrieve it (in PostTyper, for now).
Expand All @@ -1647,6 +1664,7 @@ class Typer extends Namer
}

val pat1 = typedPattern(tree.pat, wideSelType)(using gadtCtx)

caseRest(pat1)(
using Nullables.caseContext(sel, pat1)(
using gadtCtx.fresh.setNewScope))
Expand Down
12 changes: 12 additions & 0 deletions tests/neg/class-pdgadt.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
trait P:
case class A()
type B

enum SUB[-A, +B]:
case EQ[X]() extends SUB[X, X]

def f(p: P, e1: SUB[p.A, Int], e2: SUB[p.B, Int]) = e1 match
case SUB.EQ() => e2 match
case SUB.EQ() =>
val t1: Int = ??? : p.A // error
val t2: Int = ??? : p.B
11 changes: 11 additions & 0 deletions tests/neg/necessary-pdgadt.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/* N <: M */
trait M
trait N

enum SUB[-A, +B]:
case Ev[X]() extends SUB[X, X]
trait P { type T }

def f(p: P, e: SUB[p.T, N | M]) = e match
case SUB.Ev() =>
(??? : p.T) : N // error
9 changes: 9 additions & 0 deletions tests/neg/object-pdgadt.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
object P:
type T

enum SUB[-A, +B]:
case EQ[X]() extends SUB[X, X]

def f(p: P.type, e: SUB[P.T, Int]) = p match
case _: P.type =>
val t0: Int = ??? : p.T // error
14 changes: 7 additions & 7 deletions tests/neg/structural-gadt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ object Test {
def foo[A](e: Expr { type T = A }) = e match {
case _: IntLit =>
val a: A = 0 // error
val i: Int = ??? : A // limitation // error
val i: Int = ??? : A

case _: Expr { type T <: Int } =>
val a: A = 0 // error
val i: Int = ??? : A // limitation // error

case _: IntExpr =>
val a: A = 0 // limitation // error
val i: Int = ??? : A // limitation // error
val a: A = 0
val i: Int = ??? : A

case _: Expr { type T = Int } =>
val a: A = 0 // limitation // error
val i: Int = ??? : A // limitation // error
val a: A = 0
val i: Int = ??? : A
}

def bar[A](e: Expr { type T <: A }) = e match {
Expand All @@ -36,11 +36,11 @@ object Test {
val i: Int = ??? : A // error

case _: IntExpr =>
val a: A = 0 // limitation // error
val a: A = 0
val i: Int = ??? : A // error

case _: Expr { type T = Int } =>
val a: A = 0 // limitation // error
val a: A = 0
val i: Int = ??? : A // error
}

Expand Down
8 changes: 4 additions & 4 deletions tests/neg/structural-recursive-both1-gadt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ object Test {
def foo[A](e: IndirectExprExact[A]) = e match {
case _: AltIndirectIntLit =>
val a: A = 0 // error
val i: Int = ??? : A // limitation // error
val i: Int = ??? : A

case _: AltIndirectExprSub[Int] =>
val a: A = 0 // error
val i: Int = ??? : A // limitation // error
val i: Int = ??? : A

case _: AltIndirectExprSub2[Int] =>
val a: A = 0 // error
val i: Int = ??? : A // limitation // error

case _: AltIndirectIntExpr =>
val a: A = 0 // limitation // error
val i: Int = ??? : A // limitation // error
val a: A = 0
val i: Int = ??? : A

case _: AltIndirectExprExact[Int] =>
val a: A = 0 // limitation // error
Expand Down
8 changes: 4 additions & 4 deletions tests/neg/structural-recursive-both2-gadt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ object Test {
def foo[A](e: AltIndirectExprExact[A]) = e match {
case _: IndirectIntLit =>
val a: A = 0 // error
val i: Int = ??? : A // limitation // error
val i: Int = ??? : A

case _: IndirectExprSub[Int] =>
val a: A = 0 // error
val i: Int = ??? : A // limitation // error
val i: Int = ??? : A

case _: IndirectExprSub2[Int] =>
val a: A = 0 // error
val i: Int = ??? : A // limitation // error

case _: IndirectIntExpr =>
val a: A = 0 // limitation // error
val i: Int = ??? : A // limitation // error
val a: A = 0
val i: Int = ??? : A

case _: IndirectExprExact[Int] =>
val a: A = 0 // limitation // error
Expand Down
Loading