Skip to content

Commit 66f8992

Browse files
authored
Backport "Properly handle SAM types with wildcards" to LTS (#19112)
Backports #18201 to the LTS branch. PR submitted by the release tooling.
2 parents 8744018 + ea58b66 commit 66f8992

11 files changed

+215
-140
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,7 @@ class Definitions {
744744
@tu lazy val StringContextModule_processEscapes: Symbol = StringContextModule.requiredMethod(nme.processEscapes)
745745

746746
@tu lazy val PartialFunctionClass: ClassSymbol = requiredClass("scala.PartialFunction")
747+
@tu lazy val PartialFunction_apply: Symbol = PartialFunctionClass.requiredMethod(nme.apply)
747748
@tu lazy val PartialFunction_isDefinedAt: Symbol = PartialFunctionClass.requiredMethod(nme.isDefinedAt)
748749
@tu lazy val PartialFunction_applyOrElse: Symbol = PartialFunctionClass.requiredMethod(nme.applyOrElse)
749750

compiler/src/dotty/tools/dotc/core/Types.scala

+128-82
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import CheckRealizable._
2121
import Variances.{Variance, setStructuralVariances, Invariant}
2222
import typer.Nullables
2323
import util.Stats._
24-
import util.SimpleIdentitySet
24+
import util.{SimpleIdentityMap, SimpleIdentitySet}
2525
import ast.tpd._
2626
import ast.TreeTypeMap
2727
import printing.Texts._
@@ -1741,7 +1741,7 @@ object Types {
17411741
t
17421742
case t if defn.isErasedFunctionType(t) =>
17431743
t
1744-
case t @ SAMType(_) =>
1744+
case t @ SAMType(_, _) =>
17451745
t
17461746
case _ =>
17471747
NoType
@@ -5497,104 +5497,119 @@ object Types {
54975497
* A type is a SAM type if it is a reference to a class or trait, which
54985498
*
54995499
* - has a single abstract method with a method type (ExprType
5500-
* and PolyType not allowed!) whose result type is not an implicit function type
5501-
* and which is not marked inline.
5500+
* and PolyType not allowed!) according to `possibleSamMethods`.
55025501
* - can be instantiated without arguments or with just () as argument.
55035502
*
5504-
* The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5505-
* type of the single abstract method.
5503+
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5504+
* type of the single abstract method and `samParent` is a subtype of the matched
5505+
* SAM type which has been stripped of wildcards to turn it into a valid parent
5506+
* type.
55065507
*/
55075508
object SAMType {
5508-
def zeroParamClass(tp: Type)(using Context): Type = tp match {
5509+
/** If possible, return a type which is both a subtype of `origTp` and a type
5510+
* application of `samClass` where none of the type arguments are
5511+
* wildcards (thus making it a valid parent type), otherwise return
5512+
* NoType.
5513+
*
5514+
* A wildcard in the original type will be replaced by its upper or lower bound in a way
5515+
* that maximizes the number of possible implementations of `samMeth`. For example,
5516+
* java.util.function defines an interface equivalent to:
5517+
*
5518+
* trait Function[T, R]:
5519+
* def apply(t: T): R
5520+
*
5521+
* and it usually appears with wildcards to compensate for the lack of
5522+
* definition-site variance in Java:
5523+
*
5524+
* (x => x.toInt): Function[? >: String, ? <: Int]
5525+
*
5526+
* When typechecking this lambda, we need to approximate the wildcards to find
5527+
* a valid parent type for our lambda to extend. We can see that in `apply`,
5528+
* `T` only appears contravariantly and `R` only appears covariantly, so by
5529+
* minimizing the first parameter and maximizing the second, we maximize the
5530+
* number of valid implementations of `apply` which lets us implement the lambda
5531+
* with a closure equivalent to:
5532+
*
5533+
* new Function[String, Int] { def apply(x: String): Int = x.toInt }
5534+
*
5535+
* If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5536+
* we arbitrarily pick the upper-bound.
5537+
*/
5538+
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
5539+
val tp = origTp.baseType(samClass)
5540+
if !(tp <:< origTp) then NoType
5541+
else tp match
5542+
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
5543+
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
5544+
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
5545+
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
5546+
vmap.recordLocalVariance(tp.symbol, variance)
5547+
case _ =>
5548+
foldOver(vmap, t)
5549+
val vmap = accu(VarianceMap.empty, samMeth.info)
5550+
val tparams = tycon.typeParamSymbols
5551+
val args1 = args.zipWithConserve(tparams):
5552+
case (arg @ TypeBounds(lo, hi), tparam) =>
5553+
val v = vmap.computedVariance(tparam)
5554+
if v.uncheckedNN < 0 then lo
5555+
else hi
5556+
case (arg, _) => arg
5557+
tp.derivedAppliedType(tycon, args1)
5558+
case _ =>
5559+
tp
5560+
5561+
def samClass(tp: Type)(using Context): Symbol = tp match
55095562
case tp: ClassInfo =>
5510-
def zeroParams(tp: Type): Boolean = tp.stripPoly match {
5563+
def zeroParams(tp: Type): Boolean = tp.stripPoly match
55115564
case mt: MethodType => mt.paramInfos.isEmpty && !mt.resultType.isInstanceOf[MethodType]
55125565
case et: ExprType => true
55135566
case _ => false
5514-
}
5515-
// `ContextFunctionN` does not have constructors
5516-
val ctor = tp.cls.primaryConstructor
5517-
if (!ctor.exists || zeroParams(ctor.info)) tp
5518-
else NoType
5567+
val cls = tp.cls
5568+
val validCtor =
5569+
val ctor = cls.primaryConstructor
5570+
// `ContextFunctionN` does not have constructors
5571+
!ctor.exists || zeroParams(ctor.info)
5572+
val isInstantiable = !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
5573+
if validCtor && isInstantiable then tp.cls
5574+
else NoSymbol
55195575
case tp: AppliedType =>
5520-
zeroParamClass(tp.superType)
5576+
samClass(tp.superType)
55215577
case tp: TypeRef =>
5522-
zeroParamClass(tp.underlying)
5578+
samClass(tp.underlying)
55235579
case tp: RefinedType =>
5524-
zeroParamClass(tp.underlying)
5580+
samClass(tp.underlying)
55255581
case tp: TypeBounds =>
5526-
zeroParamClass(tp.underlying)
5582+
samClass(tp.underlying)
55275583
case tp: TypeVar =>
5528-
zeroParamClass(tp.underlying)
5584+
samClass(tp.underlying)
55295585
case tp: AnnotatedType =>
5530-
zeroParamClass(tp.underlying)
5531-
case _ =>
5532-
NoType
5533-
}
5534-
def isInstantiatable(tp: Type)(using Context): Boolean = zeroParamClass(tp) match {
5535-
case cinfo: ClassInfo if !cinfo.cls.isOneOf(FinalOrSealed) =>
5536-
val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5537-
tp <:< selfType
5586+
samClass(tp.underlying)
55385587
case _ =>
5539-
false
5540-
}
5541-
def unapply(tp: Type)(using Context): Option[MethodType] =
5542-
if (isInstantiatable(tp)) {
5543-
val absMems = tp.possibleSamMethods
5544-
if (absMems.size == 1)
5545-
absMems.head.info match {
5546-
case mt: MethodType if !mt.isParamDependent &&
5547-
!defn.isContextFunctionType(mt.resultType) =>
5548-
val cls = tp.classSymbol
5549-
5550-
// Given a SAM type such as:
5551-
//
5552-
// import java.util.function.Function
5553-
// Function[? >: String, ? <: Int]
5554-
//
5555-
// the single abstract method will have type:
5556-
//
5557-
// (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5558-
//
5559-
// which is not implementable outside of the scope of Function.
5560-
//
5561-
// To avoid this kind of issue, we approximate references to
5562-
// parameters of the SAM type by their bounds, this way in the
5563-
// above example we get:
5564-
//
5565-
// (x: String): Int
5566-
val approxParams = new ApproximatingTypeMap {
5567-
def apply(tp: Type): Type = tp match {
5568-
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) && tp.symbol.owner == cls =>
5569-
tp.info match {
5570-
case info: AliasingBounds =>
5571-
mapOver(info.alias)
5572-
case TypeBounds(lo, hi) =>
5573-
range(atVariance(-variance)(apply(lo)), apply(hi))
5574-
case _ =>
5575-
range(defn.NothingType, defn.AnyType) // should happen only in error cases
5576-
}
5577-
case _ =>
5578-
mapOver(tp)
5579-
}
5580-
}
5581-
val approx =
5582-
if ctx.owner.isContainedIn(cls) then mt
5583-
else approxParams(mt).asInstanceOf[MethodType]
5584-
Some(approx)
5588+
NoSymbol
5589+
5590+
def unapply(tp: Type)(using Context): Option[(MethodType, Type)] =
5591+
val cls = samClass(tp)
5592+
if cls.exists then
5593+
val absMems =
5594+
if tp.isRef(defn.PartialFunctionClass) then
5595+
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5596+
// pretending it is a SAM type. In the future it would be better to merge
5597+
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5598+
// def isDefinedAt(x: T) = true
5599+
// and overwrite that method whenever the function body is a sequence of
5600+
// case clauses.
5601+
List(defn.PartialFunction_apply)
5602+
else
5603+
tp.possibleSamMethods.map(_.symbol)
5604+
if absMems.lengthCompare(1) == 0 then
5605+
val samMethSym = absMems.head
5606+
val parent = samParent(tp, cls, samMethSym)
5607+
samMethSym.asSeenFrom(parent).info match
5608+
case mt: MethodType if !mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5609+
Some(mt, parent)
55855610
case _ =>
55865611
None
5587-
}
5588-
else if (tp isRef defn.PartialFunctionClass)
5589-
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5590-
// pretending it is a SAM type. In the future it would be better to merge
5591-
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5592-
// def isDefinedAt(x: T) = true
5593-
// and overwrite that method whenever the function body is a sequence of
5594-
// case clauses.
5595-
absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf[MethodType])
55965612
else None
5597-
}
55985613
else None
55995614
}
56005615

@@ -6427,6 +6442,37 @@ object Types {
64276442
}
64286443
}
64296444

6445+
object VarianceMap:
6446+
/** An immutable map representing the variance of keys of type `K` */
6447+
opaque type VarianceMap[K <: AnyRef] <: AnyRef = SimpleIdentityMap[K, Integer]
6448+
def empty[K <: AnyRef]: VarianceMap[K] = SimpleIdentityMap.empty[K]
6449+
extension [K <: AnyRef](vmap: VarianceMap[K])
6450+
/** The backing map used to implement this VarianceMap. */
6451+
inline def underlying: SimpleIdentityMap[K, Integer] = vmap
6452+
6453+
/** Return a new map taking into account that K appears in a
6454+
* {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6455+
*/
6456+
def recordLocalVariance(k: K, localVariance: Int): VarianceMap[K] =
6457+
val previousVariance = vmap(k)
6458+
if previousVariance == null then
6459+
vmap.updated(k, localVariance)
6460+
else if previousVariance == localVariance || previousVariance == 0 then
6461+
vmap
6462+
else
6463+
vmap.updated(k, 0)
6464+
6465+
/** Return the variance of `k`:
6466+
* - A positive value means that `k` appears only covariantly.
6467+
* - A negative value means that `k` appears only contravariantly.
6468+
* - A zero value means that `k` appears both covariantly and
6469+
* contravariantly, or appears invariantly.
6470+
* - A null value means that `k` does not appear at all.
6471+
*/
6472+
def computedVariance(k: K): Integer | Null =
6473+
vmap(k)
6474+
export VarianceMap.VarianceMap
6475+
64306476
// ----- Name Filters --------------------------------------------------
64316477

64326478
/** A name filter selects or discards a member name of a type `pre`.

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

+2-11
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ class ExpandSAMs extends MiniPhase:
5050
tree // it's a plain function
5151
case tpe if defn.isContextFunctionType(tpe) =>
5252
tree
53-
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
53+
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
5454
val tpe1 = checkRefinements(tpe, fn)
5555
toPartialFunction(tree, tpe1)
56-
case tpe @ SAMType(_) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
56+
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
5757
checkRefinements(tpe, fn)
5858
tree
5959
case tpe =>
@@ -66,13 +66,6 @@ class ExpandSAMs extends MiniPhase:
6666
tree
6767
}
6868

69-
private def checkNoContextFunction(tpt: Tree)(using Context): Unit =
70-
if defn.isContextFunctionType(tpt.tpe) then
71-
report.error(
72-
em"""Implementation restriction: cannot convert this expression to
73-
|partial function with context function result type $tpt""",
74-
tpt.srcPos)
75-
7669
/** A partial function literal:
7770
*
7871
* ```
@@ -115,8 +108,6 @@ class ExpandSAMs extends MiniPhase:
115108
private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree = {
116109
val closureDef(anon @ DefDef(_, List(List(param)), _, _)) = tree: @unchecked
117110

118-
checkNoContextFunction(anon.tpt)
119-
120111
// The right hand side from which to construct the partial function. This is always a Match.
121112
// If the original rhs is already a Match (possibly in braces), return that.
122113
// Otherwise construct a match `x match case _ => rhs` where `x` is the parameter of the closure.

compiler/src/dotty/tools/dotc/typer/Applications.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ trait Applications extends Compatibility {
696696

697697
def SAMargOK =
698698
defn.isFunctionNType(argtpe1) && formal.match
699-
case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined))
699+
case SAMType(samMeth, samParent) => argtpe <:< samMeth.toFunctionType(isJava = samParent.classSymbol.is(JavaDefined))
700700
case _ => false
701701

702702
isCompatible(argtpe, formal)
@@ -2080,7 +2080,7 @@ trait Applications extends Compatibility {
20802080
* new java.io.ObjectOutputStream(f)
20812081
*/
20822082
pt match {
2083-
case SAMType(mtp) =>
2083+
case SAMType(mtp, _) =>
20842084
narrowByTypes(alts, mtp.paramInfos, mtp.resultType)
20852085
case _ =>
20862086
// pick any alternatives that are not methods since these might be convertible

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+9-16
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ object Inferencing {
411411
val vs = variances(tp)
412412
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
413413
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
414-
vs foreachBinding { (tvar, v) =>
414+
vs.underlying foreachBinding { (tvar, v) =>
415415
if !tvar.isInstantiated then
416416
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
417417
// then it is safe to instantiate if it doesn't occur in any of the GADT bounds.
@@ -444,8 +444,6 @@ object Inferencing {
444444
res
445445
}
446446

447-
type VarianceMap = SimpleIdentityMap[TypeVar, Integer]
448-
449447
/** All occurrences of type vars in `tp` that satisfy predicate
450448
* `include` mapped to their variances (-1/0/1) in both `tp` and
451449
* `pt.finalResultType`, where
@@ -469,23 +467,18 @@ object Inferencing {
469467
*
470468
* we want to instantiate U to x.type right away. No need to wait further.
471469
*/
472-
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap = {
470+
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
473471
Stats.record("variances")
474472
val constraint = ctx.typerState.constraint
475473

476-
object accu extends TypeAccumulator[VarianceMap] {
474+
object accu extends TypeAccumulator[VarianceMap[TypeVar]]:
477475
def setVariance(v: Int) = variance = v
478-
def apply(vmap: VarianceMap, t: Type): VarianceMap = t match {
476+
def apply(vmap: VarianceMap[TypeVar], t: Type): VarianceMap[TypeVar] = t match
479477
case t: TypeVar
480478
if !t.isInstantiated && accCtx.typerState.constraint.contains(t) =>
481-
val v = vmap(t)
482-
if (v == null) vmap.updated(t, variance)
483-
else if (v == variance || v == 0) vmap
484-
else vmap.updated(t, 0)
479+
vmap.recordLocalVariance(t, variance)
485480
case _ =>
486481
foldOver(vmap, t)
487-
}
488-
}
489482

490483
/** Include in `vmap` type variables occurring in the constraints of type variables
491484
* already in `vmap`. Specifically:
@@ -497,10 +490,10 @@ object Inferencing {
497490
* bounds as non-variant.
498491
* Do this in a fixpoint iteration until `vmap` stabilizes.
499492
*/
500-
def propagate(vmap: VarianceMap): VarianceMap = {
493+
def propagate(vmap: VarianceMap[TypeVar]): VarianceMap[TypeVar] = {
501494
var vmap1 = vmap
502495
def traverse(tp: Type) = { vmap1 = accu(vmap1, tp) }
503-
vmap.foreachBinding { (tvar, v) =>
496+
vmap.underlying.foreachBinding { (tvar, v) =>
504497
val param = tvar.origin
505498
constraint.entry(param) match
506499
case TypeBounds(lo, hi) =>
@@ -516,7 +509,7 @@ object Inferencing {
516509
if (vmap1 eq vmap) vmap else propagate(vmap1)
517510
}
518511

519-
propagate(accu(accu(SimpleIdentityMap.empty, tp), pt.finalResultType))
512+
propagate(accu(accu(VarianceMap.empty, tp), pt.finalResultType))
520513
}
521514

522515
/** Run the transformation after dealiasing but return the original type if it was a no-op. */
@@ -642,7 +635,7 @@ trait Inferencing { this: Typer =>
642635
if !tvar.isInstantiated then
643636
// isInstantiated needs to be checked again, since previous interpolations could already have
644637
// instantiated `tvar` through unification.
645-
val v = vs(tvar)
638+
val v = vs.computedVariance(tvar)
646639
if v == null then buf += ((tvar, 0))
647640
else if v.intValue != 0 then buf += ((tvar, v.intValue))
648641
else comparing(cmp =>

0 commit comments

Comments
 (0)