@@ -21,7 +21,7 @@ import CheckRealizable._
21
21
import Variances .{Variance , setStructuralVariances , Invariant }
22
22
import typer .Nullables
23
23
import util .Stats ._
24
- import util .SimpleIdentitySet
24
+ import util .{ SimpleIdentityMap , SimpleIdentitySet }
25
25
import ast .tpd ._
26
26
import ast .TreeTypeMap
27
27
import printing .Texts ._
@@ -1741,7 +1741,7 @@ object Types {
1741
1741
t
1742
1742
case t if defn.isErasedFunctionType(t) =>
1743
1743
t
1744
- case t @ SAMType (_) =>
1744
+ case t @ SAMType (_, _ ) =>
1745
1745
t
1746
1746
case _ =>
1747
1747
NoType
@@ -5497,104 +5497,119 @@ object Types {
5497
5497
* A type is a SAM type if it is a reference to a class or trait, which
5498
5498
*
5499
5499
* - has a single abstract method with a method type (ExprType
5500
- * and PolyType not allowed!) whose result type is not an implicit function type
5501
- * and which is not marked inline.
5500
+ * and PolyType not allowed!) according to `possibleSamMethods`.
5502
5501
* - can be instantiated without arguments or with just () as argument.
5503
5502
*
5504
- * The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5505
- * type of the single abstract method.
5503
+ * The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5504
+ * type of the single abstract method and `samParent` is a subtype of the matched
5505
+ * SAM type which has been stripped of wildcards to turn it into a valid parent
5506
+ * type.
5506
5507
*/
5507
5508
object SAMType {
5508
- def zeroParamClass (tp : Type )(using Context ): Type = tp match {
5509
+ /** If possible, return a type which is both a subtype of `origTp` and a type
5510
+ * application of `samClass` where none of the type arguments are
5511
+ * wildcards (thus making it a valid parent type), otherwise return
5512
+ * NoType.
5513
+ *
5514
+ * A wildcard in the original type will be replaced by its upper or lower bound in a way
5515
+ * that maximizes the number of possible implementations of `samMeth`. For example,
5516
+ * java.util.function defines an interface equivalent to:
5517
+ *
5518
+ * trait Function[T, R]:
5519
+ * def apply(t: T): R
5520
+ *
5521
+ * and it usually appears with wildcards to compensate for the lack of
5522
+ * definition-site variance in Java:
5523
+ *
5524
+ * (x => x.toInt): Function[? >: String, ? <: Int]
5525
+ *
5526
+ * When typechecking this lambda, we need to approximate the wildcards to find
5527
+ * a valid parent type for our lambda to extend. We can see that in `apply`,
5528
+ * `T` only appears contravariantly and `R` only appears covariantly, so by
5529
+ * minimizing the first parameter and maximizing the second, we maximize the
5530
+ * number of valid implementations of `apply` which lets us implement the lambda
5531
+ * with a closure equivalent to:
5532
+ *
5533
+ * new Function[String, Int] { def apply(x: String): Int = x.toInt }
5534
+ *
5535
+ * If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5536
+ * we arbitrarily pick the upper-bound.
5537
+ */
5538
+ def samParent (origTp : Type , samClass : Symbol , samMeth : Symbol )(using Context ): Type =
5539
+ val tp = origTp.baseType(samClass)
5540
+ if ! (tp <:< origTp) then NoType
5541
+ else tp match
5542
+ case tp @ AppliedType (tycon, args) if tp.hasWildcardArg =>
5543
+ val accu = new TypeAccumulator [VarianceMap [Symbol ]]:
5544
+ def apply (vmap : VarianceMap [Symbol ], t : Type ): VarianceMap [Symbol ] = t match
5545
+ case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) =>
5546
+ vmap.recordLocalVariance(tp.symbol, variance)
5547
+ case _ =>
5548
+ foldOver(vmap, t)
5549
+ val vmap = accu(VarianceMap .empty, samMeth.info)
5550
+ val tparams = tycon.typeParamSymbols
5551
+ val args1 = args.zipWithConserve(tparams):
5552
+ case (arg @ TypeBounds (lo, hi), tparam) =>
5553
+ val v = vmap.computedVariance(tparam)
5554
+ if v.uncheckedNN < 0 then lo
5555
+ else hi
5556
+ case (arg, _) => arg
5557
+ tp.derivedAppliedType(tycon, args1)
5558
+ case _ =>
5559
+ tp
5560
+
5561
+ def samClass (tp : Type )(using Context ): Symbol = tp match
5509
5562
case tp : ClassInfo =>
5510
- def zeroParams (tp : Type ): Boolean = tp.stripPoly match {
5563
+ def zeroParams (tp : Type ): Boolean = tp.stripPoly match
5511
5564
case mt : MethodType => mt.paramInfos.isEmpty && ! mt.resultType.isInstanceOf [MethodType ]
5512
5565
case et : ExprType => true
5513
5566
case _ => false
5514
- }
5515
- // `ContextFunctionN` does not have constructors
5516
- val ctor = tp.cls.primaryConstructor
5517
- if (! ctor.exists || zeroParams(ctor.info)) tp
5518
- else NoType
5567
+ val cls = tp.cls
5568
+ val validCtor =
5569
+ val ctor = cls.primaryConstructor
5570
+ // `ContextFunctionN` does not have constructors
5571
+ ! ctor.exists || zeroParams(ctor.info)
5572
+ val isInstantiable = ! cls.isOneOf(FinalOrSealed ) && (tp.appliedRef <:< tp.selfType)
5573
+ if validCtor && isInstantiable then tp.cls
5574
+ else NoSymbol
5519
5575
case tp : AppliedType =>
5520
- zeroParamClass (tp.superType)
5576
+ samClass (tp.superType)
5521
5577
case tp : TypeRef =>
5522
- zeroParamClass (tp.underlying)
5578
+ samClass (tp.underlying)
5523
5579
case tp : RefinedType =>
5524
- zeroParamClass (tp.underlying)
5580
+ samClass (tp.underlying)
5525
5581
case tp : TypeBounds =>
5526
- zeroParamClass (tp.underlying)
5582
+ samClass (tp.underlying)
5527
5583
case tp : TypeVar =>
5528
- zeroParamClass (tp.underlying)
5584
+ samClass (tp.underlying)
5529
5585
case tp : AnnotatedType =>
5530
- zeroParamClass(tp.underlying)
5531
- case _ =>
5532
- NoType
5533
- }
5534
- def isInstantiatable (tp : Type )(using Context ): Boolean = zeroParamClass(tp) match {
5535
- case cinfo : ClassInfo if ! cinfo.cls.isOneOf(FinalOrSealed ) =>
5536
- val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5537
- tp <:< selfType
5586
+ samClass(tp.underlying)
5538
5587
case _ =>
5539
- false
5540
- }
5541
- def unapply (tp : Type )(using Context ): Option [MethodType ] =
5542
- if (isInstantiatable(tp)) {
5543
- val absMems = tp.possibleSamMethods
5544
- if (absMems.size == 1 )
5545
- absMems.head.info match {
5546
- case mt : MethodType if ! mt.isParamDependent &&
5547
- ! defn.isContextFunctionType(mt.resultType) =>
5548
- val cls = tp.classSymbol
5549
-
5550
- // Given a SAM type such as:
5551
- //
5552
- // import java.util.function.Function
5553
- // Function[? >: String, ? <: Int]
5554
- //
5555
- // the single abstract method will have type:
5556
- //
5557
- // (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5558
- //
5559
- // which is not implementable outside of the scope of Function.
5560
- //
5561
- // To avoid this kind of issue, we approximate references to
5562
- // parameters of the SAM type by their bounds, this way in the
5563
- // above example we get:
5564
- //
5565
- // (x: String): Int
5566
- val approxParams = new ApproximatingTypeMap {
5567
- def apply (tp : Type ): Type = tp match {
5568
- case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) && tp.symbol.owner == cls =>
5569
- tp.info match {
5570
- case info : AliasingBounds =>
5571
- mapOver(info.alias)
5572
- case TypeBounds (lo, hi) =>
5573
- range(atVariance(- variance)(apply(lo)), apply(hi))
5574
- case _ =>
5575
- range(defn.NothingType , defn.AnyType ) // should happen only in error cases
5576
- }
5577
- case _ =>
5578
- mapOver(tp)
5579
- }
5580
- }
5581
- val approx =
5582
- if ctx.owner.isContainedIn(cls) then mt
5583
- else approxParams(mt).asInstanceOf [MethodType ]
5584
- Some (approx)
5588
+ NoSymbol
5589
+
5590
+ def unapply (tp : Type )(using Context ): Option [(MethodType , Type )] =
5591
+ val cls = samClass(tp)
5592
+ if cls.exists then
5593
+ val absMems =
5594
+ if tp.isRef(defn.PartialFunctionClass ) then
5595
+ // To maintain compatibility with 2.x, we treat PartialFunction specially,
5596
+ // pretending it is a SAM type. In the future it would be better to merge
5597
+ // Function and PartialFunction, have Function1 contain a isDefinedAt method
5598
+ // def isDefinedAt(x: T) = true
5599
+ // and overwrite that method whenever the function body is a sequence of
5600
+ // case clauses.
5601
+ List (defn.PartialFunction_apply )
5602
+ else
5603
+ tp.possibleSamMethods.map(_.symbol)
5604
+ if absMems.lengthCompare(1 ) == 0 then
5605
+ val samMethSym = absMems.head
5606
+ val parent = samParent(tp, cls, samMethSym)
5607
+ samMethSym.asSeenFrom(parent).info match
5608
+ case mt : MethodType if ! mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5609
+ Some (mt, parent)
5585
5610
case _ =>
5586
5611
None
5587
- }
5588
- else if (tp isRef defn.PartialFunctionClass )
5589
- // To maintain compatibility with 2.x, we treat PartialFunction specially,
5590
- // pretending it is a SAM type. In the future it would be better to merge
5591
- // Function and PartialFunction, have Function1 contain a isDefinedAt method
5592
- // def isDefinedAt(x: T) = true
5593
- // and overwrite that method whenever the function body is a sequence of
5594
- // case clauses.
5595
- absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf [MethodType ])
5596
5612
else None
5597
- }
5598
5613
else None
5599
5614
}
5600
5615
@@ -6427,6 +6442,37 @@ object Types {
6427
6442
}
6428
6443
}
6429
6444
6445
+ object VarianceMap :
6446
+ /** An immutable map representing the variance of keys of type `K` */
6447
+ opaque type VarianceMap [K <: AnyRef ] <: AnyRef = SimpleIdentityMap [K , Integer ]
6448
+ def empty [K <: AnyRef ]: VarianceMap [K ] = SimpleIdentityMap .empty[K ]
6449
+ extension [K <: AnyRef ](vmap : VarianceMap [K ])
6450
+ /** The backing map used to implement this VarianceMap. */
6451
+ inline def underlying : SimpleIdentityMap [K , Integer ] = vmap
6452
+
6453
+ /** Return a new map taking into account that K appears in a
6454
+ * {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6455
+ */
6456
+ def recordLocalVariance (k : K , localVariance : Int ): VarianceMap [K ] =
6457
+ val previousVariance = vmap(k)
6458
+ if previousVariance == null then
6459
+ vmap.updated(k, localVariance)
6460
+ else if previousVariance == localVariance || previousVariance == 0 then
6461
+ vmap
6462
+ else
6463
+ vmap.updated(k, 0 )
6464
+
6465
+ /** Return the variance of `k`:
6466
+ * - A positive value means that `k` appears only covariantly.
6467
+ * - A negative value means that `k` appears only contravariantly.
6468
+ * - A zero value means that `k` appears both covariantly and
6469
+ * contravariantly, or appears invariantly.
6470
+ * - A null value means that `k` does not appear at all.
6471
+ */
6472
+ def computedVariance (k : K ): Integer | Null =
6473
+ vmap(k)
6474
+ export VarianceMap .VarianceMap
6475
+
6430
6476
// ----- Name Filters --------------------------------------------------
6431
6477
6432
6478
/** A name filter selects or discards a member name of a type `pre`.
0 commit comments