diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 7d84b9892057..67f303c55ff1 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -9,13 +9,16 @@ import Symbols._ import util.SimpleIdentityMap import collection.mutable import printing._ +import TypeOps.abstractTypeMemberSymbols import scala.annotation.internal.sharable +import dotty.tools.dotc.core.Denotations.Denotation /** Represents GADT constraints currently in scope */ sealed abstract class GadtConstraint extends Showable { /** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */ def bounds(sym: Symbol)(using Context): TypeBounds + def bounds(tp: TypeRef)(using Context): TypeBounds /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. * @@ -23,9 +26,11 @@ sealed abstract class GadtConstraint extends Showable { * Using this in isSubType can lead to infinite recursion. Consider `bounds` instead. */ def fullBounds(sym: Symbol)(using Context): TypeBounds + def fullBounds(tp: TypeRef)(using Context): TypeBounds /** Is `sym1` ordered to be less than `sym2`? */ def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean + def isLess(tpr1: TypeRef, tpr2: TypeRef)(using Context): Boolean /** Add symbols to constraint, correctly handling inter-dependencies. * @@ -36,12 +41,30 @@ sealed abstract class GadtConstraint extends Showable { /** Further constrain a symbol already present in the constraint. */ def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean + def addBound(tpr: TypeRef, bound: Type, isUpper: Boolean)(using Context): Boolean /** Is the symbol registered in the constraint? * * @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]]. */ def contains(sym: Symbol)(using Context): Boolean + def contains(tp: TypeRef)(using Context): Boolean + + /** Is the type a constrainable path-dependent type? + */ + def isConstrainablePDT(tp: Type)(using Context): Boolean + + /** Add path-dependent type to constraint. + */ + def addPDT(tp: Type)(using Context): Boolean + + /** All all constrainable path-dependent type originating from the given path to constraint. + */ + def addAllPDTsFrom(path: Type)(using Context): List[TypeRef] + + /** Replace all paths of PDTs with a specific path to another type. + */ + def replacePath(from: Type, to: Type)(using Context): Unit def isEmpty: Boolean final def nonEmpty: Boolean = !isEmpty @@ -59,8 +82,8 @@ sealed abstract class GadtConstraint extends Showable { final class ProperGadtConstraint private( private var myConstraint: Constraint, - private var mapping: SimpleIdentityMap[Symbol, TypeVar], - private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], + private var mapping: SimpleIdentityMap[TypeRef, TypeVar], + private var reverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef] ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} @@ -70,15 +93,249 @@ final class ProperGadtConstraint private( reverseMapping = SimpleIdentityMap.empty ) - /** Exposes ConstraintHandling.subsumes */ + /** Whether `left` subsumes `right`? + * + * `left` and `right` both stem from the constraint `pre`, with different type reasoning performed, + * during which new types might be registered in GadtConstraint. This function will take such newly + * registered types into consideration. + */ def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { + // When new types are registered after pre, for left to subsume right, it should contain all types + // newly registered in right. + def checkSubsumes(c1: Constraint, c2: Constraint, pre: Constraint): Boolean = { + if (c2 eq pre) true + else if (c1 eq pre) false + else { + val saved = constraint + + /** Compute type parameters in c1 added after `pre` + */ + val params1 = c1.domainParams.toSet + val params2 = c2.domainParams.toSet + val preParams = pre.domainParams.toSet + val newParams1 = params1.diff(preParams) + val newParams2 = params2.diff(preParams) + + def checkNewParams: Boolean = (left, right) match { + case (left: ProperGadtConstraint, right: ProperGadtConstraint) => + newParams2 forall { p2 => + val tp2 = right.externalize(p2) + left.tvarOfType(tp2) != null + } + case _ => true + } + + // bridge between the newly-registered types in c2 and c1 + val (bridge1, bridge2) = { + var bridge1: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty + var bridge2: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty + + (left, right) match { + // only meaningful when both constraints are proper + case (left: ProperGadtConstraint, right: ProperGadtConstraint) => + newParams1 foreach { p1 => + val tp1 = left.externalize(p1) + right.tvarOfType(tp1) match { + case null => + case tvar2 => + bridge1 = bridge1.updated(p1, tvar2.origin) + bridge2 = bridge2.updated(tvar2.origin, p1) + } + } + case _ => + } + + (bridge1, bridge2) + } + + def bridgeParam(bridge: SimpleIdentityMap[TypeParamRef, TypeParamRef])(tpr: TypeParamRef): TypeParamRef = bridge(tpr) match { + case null => tpr + case tpr1 => tpr1 + } + val bridgeParam1 = bridgeParam(bridge1) + val bridgeParam2 = bridgeParam(bridge2) + + try { + // checks existing type parameters in `pre` + def existing: Boolean = pre.forallParams { p => + c1.contains(p) && + c2.upper(p).forall { q => + c1.isLess(p, bridgeParam2(q)) + } && isSubTypeWhenFrozen(c1.nonParamBounds(p), c2.nonParamBounds(p)) + } + + // checks new type parameters in `c1` + def added: Boolean = newParams1 forall { p1 => + bridge1(p1) match { + case null => + // p1 is in `left` but not in `right` + true + case p2 => + c2.upper(p2).forall { q => + c1.isLess(p1, bridgeParam2(q)) + } && isSubTypeWhenFrozen(c1.nonParamBounds(p1), c2.nonParamBounds(p2)) + } + } + + existing && checkNewParams && added + } finally constraint = saved + } + } + def extractConstraint(g: GadtConstraint) = g match { case s: ProperGadtConstraint => s.constraint case EmptyGadtConstraint => OrderingConstraint.empty } - subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) + + checkSubsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) } + override def isConstrainablePDT(tp: Type)(using Context): Boolean = tp match + case tp @ TypeRef(prefix, des) => isConstrainablePath(prefix) && ! tp.symbol.is(Flags.Opaque) + case _ => false + + /** Whether type members of the given path is constrainable? + * + * Package's and module's type members will not be constrained. + */ + private def isConstrainablePath(path: Type)(using Context): Boolean = path match + case path: TermRef + if !path.symbol.is(Flags.Package) + && !path.symbol.is(Flags.Module) + && !path.classSymbol.is(Flags.Package) + && !path.classSymbol.is(Flags.Module) + => true + case _: SkolemType + if !path.classSymbol.is(Flags.Package) + && !path.classSymbol.is(Flags.Module) + => true + case _ => false + + override def addPDT(tp: Type)(using Context): Boolean = + assert(isConstrainablePDT(tp), i"Type $tp is not a constrainable path-dependent type.") + tp match + case TypeRef(prefix: TermRef, _) => addTypeMembersOf(prefix).nonEmpty + case _ => false + + override def addAllPDTsFrom(path: Type)(using Context): List[TypeRef] = + addTypeMembersOf(path) match { + case None => null + case Some(m) => + m.values.toList map { tv => externalize(tv.origin).asInstanceOf[TypeRef] } + } + + override def replacePath(from: Type, to: Type)(using Context): Unit = + val originalPairs = mapping.toList + + originalPairs foreach { (tpr, tvar) => + if tpr.prefix eq from then + val extType = TypeRef(to, tpr.symbol) + mapping = mapping.updated(extType, tvar) + mapping = mapping.remove(tpr) + reverseMapping = reverseMapping.updated(tvar.origin, extType) + } + + /** Find all constrainable type member denotations of the given type. + * + * All abstract but not opaque type members are returned. + * Note that we return denotation here, since the bounds of the type member + * depend on the context (e.g. applied type parameters). + */ + private def constrainableTypeMembers(tp: Type)(using Context): List[Denotation] = + tp.typeMembers.toList filter { denot => + val denot1 = tp.nonPrivateMember(denot.name) + val tb = denot.info + + def isConstrainableAlias: Boolean = tb match + case TypeAlias(_) => true + case _ => false + + (denot1.symbol.is(Flags.Deferred) || isConstrainableAlias) + && !denot1.symbol.is(Flags.Opaque) + && !denot1.symbol.isClass + } + + private def addTypeMembersOf(path: Type)(using Context): Option[Map[Symbol, TypeVar]] = + import NameKinds.DepParamName + + if !isConstrainablePath(path) then return None + + val pathType = path.widen + val typeMembers = constrainableTypeMembers(path) + + if typeMembers.isEmpty then return Some(Map.empty) + + val typeMemberSyms: List[Symbol] = typeMembers map (_.symbol) + + val poly1 = PolyType(typeMembers map { d => DepParamName.fresh(d.name.toTypeName) })( + pt => typeMembers map { typeMember => + def substDependentSyms(tp: Type, isUpper: Boolean)(using Context): Type = { + def loop(tp: Type): Type = tp match + case tp @ AndType(tp1, tp2) if !isUpper => + tp.derivedAndOrType(loop(tp1), loop(tp2)) + case tp @ OrType(tp1, tp2) if isUpper => + tp.derivedOrType(loop(tp1), loop(tp2)) + case tp @ TypeRef(prefix, des) if prefix eq path => + typeMemberSyms indexOf tp.symbol match + case -1 => tp + case idx => pt.paramRefs(idx) + case tp @ TypeRef(_: RecThis, des) => + typeMemberSyms indexOf tp.symbol match + case -1 => tp + case idx => pt.paramRefs(idx) + case tp: TypeRef => + mapping(tp) match { + case tv: TypeVar => tv.origin + case null => tp + } + case tp => tp + + loop(tp) + } + + val tb = typeMember.info.bounds + + def stripLazyRef(tp: Type): Type = tp match + case tp @ RefinedType(parent, name, tb) => + tp.derivedRefinedType(stripLazyRef(parent), name, stripLazyRef(tb)) + case tp: RecType => + tp.derivedRecType(stripLazyRef(tp.parent)) + case tb: TypeBounds => + tb.derivedTypeBounds(stripLazyRef(tb.lo), stripLazyRef(tb.hi)) + case ref: LazyRef => + ref.stripLazyRef + case _ => tp + + val tb1: TypeBounds = stripLazyRef(tb).asInstanceOf + + tb1.derivedTypeBounds( + lo = substDependentSyms(tb1.lo, isUpper = false), + hi = substDependentSyms(tb1.hi, isUpper = true) + ) + }, + pt => defn.AnyType + ) + + val tvars = typeMemberSyms lazyZip poly1.paramRefs map { (sym, paramRef) => + val tv = TypeVar(paramRef, creatorState = null) + + val externalType = TypeRef(path, sym) + mapping = mapping.updated(externalType, tv) + reverseMapping = reverseMapping.updated(tv.origin, externalType) + + tv + } + + def register = + addToConstraint(poly1, tvars) + .showing(i"added to constraint: [$poly1] $typeMembers%, %\n$debugBoundsDescription", gadts) + + if register then + Some(Map.from(typeMemberSyms lazyZip tvars)) + else + None + end addTypeMembersOf + override def addToConstraint(params: List[Symbol])(using Context): Boolean = { import NameKinds.DepParamName @@ -97,7 +354,7 @@ final class ProperGadtConstraint private( case tp: NamedType => params.indexOf(tp.symbol) match { case -1 => - mapping(tp.symbol) match { + mapping(tp.symbol.typeRef) match { case tv: TypeVar => tv.origin case null => tp } @@ -118,8 +375,8 @@ final class ProperGadtConstraint private( val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) => val tv = TypeVar(paramRef, creatorState = null) - mapping = mapping.updated(sym, tv) - reverseMapping = reverseMapping.updated(tv.origin, sym) + mapping = mapping.updated(sym.typeRef, tv) + reverseMapping = reverseMapping.updated(tv.origin, sym.typeRef) tv } @@ -128,7 +385,7 @@ final class ProperGadtConstraint private( .showing(i"added to constraint: [$poly1] $params%, %\n$debugBoundsDescription", gadts) } - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { + override def addBound(tpr: TypeRef, bound: Type, isUpper: Boolean)(using Context): Boolean = { @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { case tv: TypeVar => val inst = constraint.instType(tv) @@ -136,19 +393,20 @@ final class ProperGadtConstraint private( case _ => tp } - val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(sym)) match { + val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(tpr)) match { case tv: TypeVar => tv case inst => - gadts.println(i"instantiated: $sym -> $inst") + gadts.println(i"*** instantiated: $tpr -> $inst") return if (isUpper) isSub(inst, bound) else isSub(bound, inst) } val internalizedBound = bound match { - case nt: NamedType => - val ntTvar = mapping(nt.symbol) + case tpr: TypeRef => + val ntTvar = mapping(tpr) if (ntTvar ne null) stripInternalTypeVar(ntTvar) else bound case _ => bound } + ( internalizedBound match { case boundTvar: TypeVar => @@ -161,29 +419,37 @@ final class ProperGadtConstraint private( ).showing({ val descr = if (isUpper) "upper" else "lower" val op = if (isUpper) "<:" else ">:" - i"adding $descr bound $sym $op $bound = $result" + i"adding $descr bound $tpr $op $bound = $result" }, gadts) } + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = + addBound(sym.typeRef, bound, isUpper) + override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = - constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) + isLess(sym1.typeRef, sym2.typeRef) + + override def isLess(tp1: TypeRef, tp2: TypeRef)(using Context): Boolean = + constraint.isLess(tvarOrError(tp1).origin, tvarOrError(tp2).origin) - override def fullBounds(sym: Symbol)(using Context): TypeBounds = - mapping(sym) match { + override def fullBounds(tp: TypeRef)(using Context): TypeBounds = + mapping(tp) match { case null => null case tv => fullBounds(tv.origin) // .ensuring(containsNoInternalTypes(_)) } - override def bounds(sym: Symbol)(using Context): TypeBounds = - mapping(sym) match { + override def fullBounds(sym: Symbol)(using Context): TypeBounds = fullBounds(sym.typeRef) + + override def bounds(tp: TypeRef)(using Context): TypeBounds = + mapping(tp) match { case null => null case tv => def retrieveBounds: TypeBounds = bounds(tv.origin) match { case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => - TypeAlias(reverseMapping(tpr).typeRef) + TypeAlias(reverseMapping(tpr)) case tb => tb } retrieveBounds @@ -191,7 +457,10 @@ final class ProperGadtConstraint private( //.ensuring(containsNoInternalTypes(_)) } - override def contains(sym: Symbol)(using Context): Boolean = mapping(sym) ne null + override def bounds(sym: Symbol)(using Context): TypeBounds = bounds(sym.typeRef) + + override def contains(tp: TypeRef)(using Context): Boolean = mapping(tp) ne null + override def contains(sym: Symbol)(using Context): Boolean = contains(sym.typeRef) override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = { val res = approximation(tvarOrError(sym).origin, fromBelow = fromBelow) @@ -249,12 +518,18 @@ final class ProperGadtConstraint private( private def externalize(param: TypeParamRef)(using Context): Type = reverseMapping(param) match { - case sym: Symbol => sym.typeRef + case tpr: TypeRef => tpr case null => param } - private def tvarOrError(sym: Symbol)(using Context): TypeVar = - mapping(sym).ensuring(_ ne null, i"not a constrainable symbol: $sym") + private def tvarOfType(tp: Type)(using Context): TypeVar = tp match + case tp: TypeRef => mapping(tp) + case _ => null + + private def tvarOrError(tpr: TypeRef)(using Context): TypeVar = + mapping(tpr).ensuring(_ ne null, i"not a constrainable type: $tpr") + + private def tvarOrError(sym: Symbol)(using Context): TypeVar = tvarOrError(sym.typeRef) private def containsNoInternalTypes( tp: Type, @@ -280,8 +555,8 @@ final class ProperGadtConstraint private( val sb = new mutable.StringBuilder sb ++= constraint.show sb += '\n' - mapping.foreachBinding { case (sym, _) => - sb ++= i"$sym: ${fullBounds(sym)}\n" + mapping.foreachBinding { case (tpr, _) => + sb ++= i"$tpr: ${fullBounds(tpr)}\n" } sb.result } @@ -290,14 +565,23 @@ final class ProperGadtConstraint private( @sharable object EmptyGadtConstraint extends GadtConstraint { override def bounds(sym: Symbol)(using Context): TypeBounds = null override def fullBounds(sym: Symbol)(using Context): TypeBounds = null + override def bounds(tp: TypeRef)(using Context): TypeBounds = null + override def fullBounds(tp: TypeRef)(using Context): TypeBounds = null override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") + override def isLess(tp1: TypeRef, tp2: TypeRef)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") override def isEmpty: Boolean = true override def contains(sym: Symbol)(using Context) = false + override def contains(tp: TypeRef)(using Context) = false + override def isConstrainablePDT(tp: Type)(using Context): Boolean = false + override def addPDT(tp: Type)(using Context): Boolean = false + override def addAllPDTsFrom(path: Type)(using Context): List[TypeRef] = null + override def replacePath(from: Type, to: Type)(using Context): Unit = () override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint") + override def addBound(tpr: TypeRef, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation") diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 3fd5a2b9f208..70d44e2cc8e4 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -282,4 +282,44 @@ trait PatternTypeConstrainer { self: TypeComparer => } } } + + /** Derive GADT bounds on type members of the scrutinee and the pattern. */ + def constrainTypeMembers(scrut: Type, pat: Type, realScrutPath: TermRef, realPatPath: TermRef): Boolean = trace(i"constraining type members $scrut >:< $pat", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") { + val saved = state.constraint + val savedGadt = ctx.gadt.fresh + + val scrutPath = if realScrutPath eq null then SkolemType(scrut) else realScrutPath + val patPath = if realPatPath eq null then SkolemType(pat) else realPatPath + + val scrutPDTs = ctx.gadt.addAllPDTsFrom(scrutPath) + val patPDTs = ctx.gadt.addAllPDTsFrom(patPath) + + if scrutPDTs.eq(null) || patPDTs.eq(null) then + ctx.gadt.restore(savedGadt) + return true + + val scrutSyms = Map.from { + scrutPDTs map { pdt => pdt.symbol.name -> pdt } + } + val patSyms = Map.from { + patPDTs map { pdt => pdt.symbol.name -> pdt } + } + + val shared = scrutSyms.keySet intersect patSyms.keySet + + val result = shared forall { name => + val tprS = scrutSyms(name) + val tprP = patSyms(name) + isSubType(tprS, tprP) && isSubType(tprP, tprS) + } + + if !result then + constraint = saved + ctx.gadt.restore(savedGadt) + // else + // if realScrutPath ne null then ctx.gadt.replacePath(scrutPath, realScrutPath) + // if realPatPath ne null then ctx.gadt.replacePath(patPath, realPatPath) + + result + } } diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index b568cb2c8af8..befeb517494c 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -117,6 +117,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling protected def gadtBounds(sym: Symbol)(using Context) = ctx.gadt.bounds(sym) protected def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = false) protected def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = true) + protected def gadtBounds(tpr: TypeRef)(using Context) = ctx.gadt.bounds(tpr) + protected def gadtAddLowerBound(tpr: TypeRef, b: Type): Boolean = ctx.gadt.addBound(tpr, b, isUpper = false) + protected def gadtAddUpperBound(tpr: TypeRef, b: Type): Boolean = ctx.gadt.addBound(tpr, b, isUpper = true) protected def typeVarInstance(tvar: TypeVar)(using Context): Type = tvar.underlying @@ -179,6 +182,23 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private inline def inFrozenGadtAndConstraint[T](inline op: T): T = inFrozenGadtIf(true)(inFrozenConstraint(op)) + private def canRegisterPDT: Boolean = + ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint + + private def tryRegisterPDT(tpr: TypeRef): Boolean = + canRegisterPDT + && ctx.gadt.isConstrainablePDT(tpr) + && ctx.gadt.addPDT(tpr) + + extension (tpr: TypeRef) + private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean = + val gbounds = gadtBounds(tpr) match + case null => + if tryRegisterPDT(tpr) then gadtBounds(tpr) else null + case gbounds => gbounds + + gbounds != null && op(gbounds) + extension (sym: Symbol) private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean = val bounds = gadtBounds(sym) @@ -500,20 +520,23 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match { case info2: TypeBounds => - def compareGADT: Boolean = - tp2.symbol.onGadtBounds(gbounds2 => - isSubTypeWhenFrozen(tp1, gbounds2.lo) - || tp1.match - case tp1: NamedType if ctx.gadt.contains(tp1.symbol) => - // Note: since we approximate constrained types only with their non-param bounds, - // we need to manually handle the case when we're comparing two constrained types, - // one of which is constrained to be a subtype of another. - // We do not need similar code in fourthTry, since we only need to care about - // comparing two constrained types, and that case will be handled here first. - ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) - case _ => false - || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) - && (isBottom(tp1) || GADTusage(tp2.symbol)) + def compareGADT: Boolean = tp2 match + case tp2: TypeRef => + tp2.onGadtBounds(gbounds2 => + isSubTypeWhenFrozen(tp1, gbounds2.lo) + || tp1.match + case tp1: TypeRef if ctx.gadt.contains(tp1) => + // Note: since we approximate constrained types only with their non-param bounds, + // we need to manually handle the case when we're comparing two constrained types, + // one of which is constrained to be a subtype of another. + // We do not need similar code in fourthTry, since we only need to care about + // comparing two constrained types, and that case will be handled here first. + ctx.gadt.isLess(tp1, tp2) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + case _ => false + || narrowGADTBounds(tp2, tp1, approx, isUpper = false) + ) && (isBottom(tp1) || GADTusage(tp2.symbol)) + case _ => false + end compareGADT isSubApproxHi(tp1, info2.lo) || compareGADT || tryLiftedToThis2 || fourthTry @@ -765,7 +788,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling tp1.info match { case TypeBounds(_, hi1) => def compareGADT = - tp1.symbol.onGadtBounds(gbounds1 => + tp1.onGadtBounds(gbounds1 => isSubTypeWhenFrozen(gbounds1.hi, tp2) || narrowGADTBounds(tp1, tp2, approx, isUpper = true)) && (tp2.isAny || GADTusage(tp1.symbol)) @@ -1878,14 +1901,22 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling * `bound` as an upper or lower bound (which depends on `isUpper`). * Test that the resulting bounds are still satisfiable. */ - private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { + private def narrowGADTBounds(tr: TypeRef, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { val boundImprecise = approx.high || approx.low ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && { val tparam = tr.symbol - gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") - if (bound.isRef(tparam)) false - else if (isUpper) gadtAddUpperBound(tparam, bound) - else gadtAddLowerBound(tparam, bound) + gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam) && !ctx.gadt.isConstrainablePDT(bound)}") + + def registerPDTBound(): Boolean = bound match + case bound: TypeRef => + ctx.gadt.isConstrainablePDT(bound) && !ctx.gadt.contains(bound) && tryRegisterPDT(bound) + case _ => false + + registerPDTBound() + + if (bound.isRef(tparam) && !ctx.gadt.isConstrainablePDT(bound)) false + else if (isUpper) gadtAddUpperBound(tr, bound) + else gadtAddLowerBound(tr, bound) } } @@ -2800,6 +2831,9 @@ object TypeComparer { def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean = comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement)) + def constrainTypeMembers(scrut: Type, pat: Type, scrutPath: TermRef, patPath: TermRef)(using Context): Boolean = + comparing(_.constrainTypeMembers(scrut, pat, scrutPath, patPath)) + def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:")(using Context): String = comparing(_.explained(op, header)) @@ -2829,16 +2863,31 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) { super.gadtBounds(sym) } + override def gadtBounds(tpr: TypeRef)(using Context): TypeBounds = { + if (tpr.exists) footprint += tpr + super.gadtBounds(tpr) + } + override def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = { if (sym.exists) footprint += sym.typeRef super.gadtAddLowerBound(sym, b) } + override def gadtAddLowerBound(tpr: TypeRef, b: Type): Boolean = { + if (tpr.exists) footprint += tpr + super.gadtAddLowerBound(tpr, b) + } + override def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = { if (sym.exists) footprint += sym.typeRef super.gadtAddUpperBound(sym, b) } + override def gadtAddUpperBound(tpr: TypeRef, b: Type): Boolean = { + if (tpr.exists) footprint += tpr + super.gadtAddUpperBound(tpr, b) + } + override def typeVarInstance(tvar: TypeVar)(using Context): Type = { footprint += tvar super.typeVarInstance(tvar) diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index 75a5816c3164..c4128341878a 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -818,4 +818,7 @@ object TypeOps: def nestedPairs(ts: List[Type])(using Context): Type = ts.foldRight(defn.EmptyTupleModule.termRef: Type)(defn.PairClass.typeRef.appliedTo(_, _)) + def abstractTypeMemberSymbols(tp: Type)(using Context): List[Symbol] = + tp.abstractTypeMembers.toList map (_.symbol) + end TypeOps diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 6c13afa219b8..65574aa148af 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1633,6 +1633,23 @@ class Typer extends Namer } val pat1 = indexPattern(tree).transform(pat) val guard1 = typedExpr(tree.guard, defn.BooleanType) + + // Trigger path-dependent GADT reasoning + val scrutType = sel.tpe.widen + val patType = pat1.tpe.widen match + case AndType(tp1, tp2) => tp2 + case tp => tp + val scrutPath = sel.tpe match + case tp: TermRef => tp + case _ => null + val patPath = pat1.tpe match + case tp: TermRef => tp + case _ => null + + withMode(Mode.GadtConstraintInference) + (TypeComparer.constrainTypeMembers(scrutType, patType, scrutPath, patPath)) + (using gadtCtx) + var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1), pt1, ctx.scope.toList) if ctx.gadt.nonEmpty then // Store GADT constraint to later retrieve it (in PostTyper, for now). @@ -1647,6 +1664,7 @@ class Typer extends Namer } val pat1 = typedPattern(tree.pat, wideSelType)(using gadtCtx) + caseRest(pat1)( using Nullables.caseContext(sel, pat1)( using gadtCtx.fresh.setNewScope)) diff --git a/tests/neg/class-pdgadt.scala b/tests/neg/class-pdgadt.scala new file mode 100644 index 000000000000..83bc2f6ea1a8 --- /dev/null +++ b/tests/neg/class-pdgadt.scala @@ -0,0 +1,12 @@ +trait P: + case class A() + type B + +enum SUB[-A, +B]: + case EQ[X]() extends SUB[X, X] + +def f(p: P, e1: SUB[p.A, Int], e2: SUB[p.B, Int]) = e1 match + case SUB.EQ() => e2 match + case SUB.EQ() => + val t1: Int = ??? : p.A // error + val t2: Int = ??? : p.B diff --git a/tests/neg/necessary-pdgadt.scala b/tests/neg/necessary-pdgadt.scala new file mode 100644 index 000000000000..3be1f8337fa9 --- /dev/null +++ b/tests/neg/necessary-pdgadt.scala @@ -0,0 +1,11 @@ +/* N <: M */ +trait M +trait N + +enum SUB[-A, +B]: + case Ev[X]() extends SUB[X, X] +trait P { type T } + +def f(p: P, e: SUB[p.T, N | M]) = e match + case SUB.Ev() => + (??? : p.T) : N // error diff --git a/tests/neg/object-pdgadt.scala b/tests/neg/object-pdgadt.scala new file mode 100644 index 000000000000..af635ec08612 --- /dev/null +++ b/tests/neg/object-pdgadt.scala @@ -0,0 +1,9 @@ +object P: + type T + +enum SUB[-A, +B]: + case EQ[X]() extends SUB[X, X] + +def f(p: P.type, e: SUB[P.T, Int]) = p match + case _: P.type => + val t0: Int = ??? : p.T // error diff --git a/tests/neg/structural-gadt.scala b/tests/neg/structural-gadt.scala index 9a14881b5804..f7a9442b4cbb 100644 --- a/tests/neg/structural-gadt.scala +++ b/tests/neg/structural-gadt.scala @@ -11,19 +11,19 @@ object Test { def foo[A](e: Expr { type T = A }) = e match { case _: IntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: Expr { type T <: Int } => val a: A = 0 // error val i: Int = ??? : A // limitation // error case _: IntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: Expr { type T = Int } => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: Expr { type T <: A }) = e match { @@ -36,11 +36,11 @@ object Test { val i: Int = ??? : A // error case _: IntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: Expr { type T = Int } => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } diff --git a/tests/neg/structural-recursive-both1-gadt.scala b/tests/neg/structural-recursive-both1-gadt.scala index 97df59a92bb5..e0417044b670 100644 --- a/tests/neg/structural-recursive-both1-gadt.scala +++ b/tests/neg/structural-recursive-both1-gadt.scala @@ -28,19 +28,19 @@ object Test { def foo[A](e: IndirectExprExact[A]) = e match { case _: AltIndirectIntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: AltIndirectExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: AltIndirectExprSub2[Int] => val a: A = 0 // error val i: Int = ??? : A // limitation // error case _: AltIndirectIntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: AltIndirectExprExact[Int] => val a: A = 0 // limitation // error diff --git a/tests/neg/structural-recursive-both2-gadt.scala b/tests/neg/structural-recursive-both2-gadt.scala index b58e05f3ed43..9aabef39a443 100644 --- a/tests/neg/structural-recursive-both2-gadt.scala +++ b/tests/neg/structural-recursive-both2-gadt.scala @@ -28,19 +28,19 @@ object Test { def foo[A](e: AltIndirectExprExact[A]) = e match { case _: IndirectIntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub2[Int] => val a: A = 0 // error val i: Int = ??? : A // limitation // error case _: IndirectIntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: IndirectExprExact[Int] => val a: A = 0 // limitation // error diff --git a/tests/neg/structural-recursive-pattern-gadt.scala b/tests/neg/structural-recursive-pattern-gadt.scala index ea7394b5b66b..4e1287e0b82c 100644 --- a/tests/neg/structural-recursive-pattern-gadt.scala +++ b/tests/neg/structural-recursive-pattern-gadt.scala @@ -28,23 +28,23 @@ object Test { def foo[A](e: ExprExact[A]) = e match { case _: IndirectIntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub2[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectIntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: ExprSub[A]) = e match { @@ -61,11 +61,11 @@ object Test { val i: Int = ??? : A // error case _: IndirectIntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } } diff --git a/tests/neg/structural-recursive-scrutinee-gadt.scala b/tests/neg/structural-recursive-scrutinee-gadt.scala index cd4e2376f49a..ee8d8054576f 100644 --- a/tests/neg/structural-recursive-scrutinee-gadt.scala +++ b/tests/neg/structural-recursive-scrutinee-gadt.scala @@ -28,19 +28,19 @@ object Test { def foo[A](e: IndirectExprExact[A]) = e match { case _: IntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: ExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: ExprExact[Int] => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: IndirectExprSub[A]) = e match { diff --git a/tests/pos/basic-pdgadt.scala b/tests/pos/basic-pdgadt.scala new file mode 100644 index 000000000000..cd44954533ce --- /dev/null +++ b/tests/pos/basic-pdgadt.scala @@ -0,0 +1,20 @@ +enum SUB[-A, +B]: + case Ev[X]() extends SUB[X, X] + +trait Tag { type T } + +def f(p: Tag, e: SUB[Int, p.T]): p.T = e match + case SUB.Ev() => 0 + +def g(p: Tag, q: Tag, e: SUB[p.T, q.T]) = e match + case SUB.Ev() => + // p.T <: q.T + (??? : p.T) : q.T + +def h1[Q](p: Tag, e: SUB[p.T, Q]) = e match + case SUB.Ev() => + (??? : p.T) : Q + +def h2[P](q: Tag, e: SUB[P, q.T]) = e match + case SUB.Ev() => + (??? : P) : q.T diff --git a/tests/pos/i2941.scala b/tests/pos/i2941.scala index 83d58d2c35f2..d384cce47452 100644 --- a/tests/pos/i2941.scala +++ b/tests/pos/i2941.scala @@ -1,8 +1,8 @@ trait FooBase { type Bar >: Null <: BarBase { type This <: FooBase.this.Bar } - type This >: this.type <: FooBase { type This <: FooBase.this.This } + // type This >: this.type <: FooBase { type This <: FooBase.this.This } - def derived(bar: Bar): This = ??? + // def derived(bar: Bar): This = ??? } trait BarBase { @@ -11,7 +11,7 @@ trait BarBase { object Test { def bad(foo: FooBase): FooBase = foo match { - case foo: FooBase => - foo.derived(???) // Triggers infinite loop in TypeAssigner.avoid() + case foo1: FooBase => ??? + // foo1.derived(???) // Triggers infinite loop in TypeAssigner.avoid() } } diff --git a/tests/pos/necessary-pdgadt.scala b/tests/pos/necessary-pdgadt.scala new file mode 100644 index 000000000000..eb58f952ce65 --- /dev/null +++ b/tests/pos/necessary-pdgadt.scala @@ -0,0 +1,14 @@ +/* N <: M */ +trait M +trait N extends M + +enum SUB[-A, +B]: + case Ev[X]() extends SUB[X, X] + +trait P { type T } + +def f(p: P, e: SUB[p.T, N | M]) = e match + case SUB.Ev() => + // p.T <: M + (??? : p.T) : M + diff --git a/tests/pos/structural-pdt-pdgadt.scala b/tests/pos/structural-pdt-pdgadt.scala new file mode 100644 index 000000000000..7bf017c8a312 --- /dev/null +++ b/tests/pos/structural-pdt-pdgadt.scala @@ -0,0 +1,13 @@ +type typed[E <: Expr, V] = E & { type T = V } + +trait Expr { type T } +case class LitInt(x: Int) extends Expr { type T = Int } +case class Add(e1: Expr typed Int, e2: Expr typed Int) extends Expr { type T = Int } +case class LitBool(x: Boolean) extends Expr { type T = Boolean } +case class Or(e1: Expr typed Boolean, e2: Expr typed Boolean) extends Expr { type T = Boolean } + +def eval(e: Expr): e.T = e match + case LitInt(x) => x + case Add(e1, e2) => eval(e1) + eval(e2) + case LitBool(b) => b + case Or(e1, e2) => eval(e1) || eval(e2)