Skip to content

Commit 0b1c9e3

Browse files
committed
Fix fragile transformation of fromProduct when using @unroll
UnrollDefinitions assumed that the body of `fromProduct` had a specific shape which is no longer the case with the dependent case class support introduced in the previous commit. This caused compiler crashes for tests/run/unroll-multiple.scala and tests/run/unroll-caseclass-integration This commit fixes this by directly generating the correct fromProduct in SyntheticMembers. This should also prevent crashes in situations where code is injected into existing trees like the code coverage support or external compiler plugins.
1 parent 4aa59eb commit 0b1c9e3

File tree

4 files changed

+62
-60
lines changed

4 files changed

+62
-60
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ class Definitions {
618618
@tu lazy val Int_== : Symbol = IntClass.requiredMethod(nme.EQ, List(IntType))
619619
@tu lazy val Int_>= : Symbol = IntClass.requiredMethod(nme.GE, List(IntType))
620620
@tu lazy val Int_<= : Symbol = IntClass.requiredMethod(nme.LE, List(IntType))
621+
@tu lazy val Int_> : Symbol = IntClass.requiredMethod(nme.GT, List(IntType))
621622
@tu lazy val LongType: TypeRef = valueTypeRef("scala.Long", java.lang.Long.TYPE, LongEnc, nme.specializedTypeNames.Long)
622623
def LongClass(using Context): ClassSymbol = LongType.symbol.asClass
623624
@tu lazy val Long_+ : Symbol = LongClass.requiredMethod(nme.PLUS, List(LongType))

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ object StdNames {
425425
val array_length : N = "array_length"
426426
val array_update : N = "array_update"
427427
val arraycopy: N = "arraycopy"
428+
val arity: N = "arity"
428429
val as: N = "as"
429430
val asTerm: N = "asTerm"
430431
val asModule: N = "asModule"

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

+46-7
Original file line numberDiff line numberDiff line change
@@ -523,21 +523,47 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
523523
* ```
524524
* type MirroredMonoType = C[?]
525525
* ```
526+
*
527+
* However, if the last parameter is annotated `@unroll` then we generate:
528+
*
529+
* def fromProduct(x$0: Product): MirroredMonoType =
530+
* val arity = x$0.productArity
531+
* val a$1 = x$0.productElement(0).asInstanceOf[U]
532+
* val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem]
533+
* val c$1 = (
534+
* if arity > 2 then
535+
* x$0.productElement(2)
536+
* else
537+
* <default getter for the third parameter of C>
538+
* ).asInstanceOf[a$1.Elem]
539+
* new C[U](a$1, b$1, c$1*)
526540
*/
527541
def fromProductBody(caseClass: Symbol, productParam: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
528542
val classRef = optInfo match
529543
case Some(info) => TypeRef(info.pre, caseClass)
530544
case _ => caseClass.typeRef
531-
val (newPrefix, constrMeth) =
545+
val (newPrefix, constrMeth, constrSyms) =
532546
val constr = TermRef(classRef, caseClass.primaryConstructor)
547+
val symss = caseClass.primaryConstructor.paramSymss
533548
(constr.info: @unchecked) match
534549
case tl: PolyType =>
535550
val tvars = constrained(tl)
536551
val targs = for tvar <- tvars yield
537552
tvar.instantiate(fromBelow = false)
538-
(AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType])
553+
(AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType], symss(1))
539554
case mt: MethodType =>
540-
(classRef, mt)
555+
(classRef, mt, symss.head)
556+
557+
// Index of the first parameter marked `@unroll` or -1
558+
val unrolledFrom =
559+
constrSyms.indexWhere(_.hasAnnotation(defn.UnrollAnnot))
560+
561+
// `val arity = x$0.productArity`
562+
val arityDef: Option[ValDef] =
563+
if unrolledFrom != -1 then
564+
Some(SyntheticValDef(nme.arity, productParam.select(defn.Product_productArity)))
565+
else None
566+
val arityRefTree = arityDef.map(vd => ref(vd.symbol))
541567

542568
// Create symbols for the vals corresponding to each parameter
543569
// If there are dependent parameters, the infos won't be correct yet.
@@ -550,16 +576,29 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
550576
bindingSyms.foreach: bindingSym =>
551577
bindingSym.info = bindingSym.info.substParams(constrMeth, bindingRefs)
552578

579+
def defaultGetterAtIndex(idx: Int): Tree =
580+
val defaultGetterPrefix = caseClass.primaryConstructor.name.toTermName
581+
ref(caseClass.companionModule).select(NameKinds.DefaultGetterName(defaultGetterPrefix, idx))
582+
553583
val bindingDefs = bindingSyms.zipWithIndex.map: (bindingSym, idx) =>
554-
ValDef(bindingSym,
555-
productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
556-
.ensureConforms(bindingSym.info))
584+
val selection = productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
585+
val rhs = (
586+
if unrolledFrom != -1 && idx >= unrolledFrom then
587+
If(arityRefTree.get.select(defn.Int_>).appliedTo(Literal(Constant(idx))),
588+
thenp =
589+
selection,
590+
elsep =
591+
defaultGetterAtIndex(idx))
592+
else
593+
selection
594+
).ensureConforms(bindingSym.info)
595+
ValDef(bindingSym, rhs)
557596

558597
val newArgs = bindingRefs.lazyZip(constrMeth.paramInfos).map: (bindingRef, paramInfo) =>
559598
val refTree = ref(bindingRef)
560599
if paramInfo.isRepeatedParam then ctx.typer.seqToRepeated(refTree) else refTree
561600
Block(
562-
bindingDefs,
601+
arityDef.map(_ :: bindingDefs).getOrElse(bindingDefs),
563602
New(newPrefix, newArgs)
564603
)
565604
end fromProductBody

compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala

+14-53
Original file line numberDiff line numberDiff line change
@@ -228,36 +228,6 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
228228
forwarderDef
229229
}
230230

231-
private def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
232-
cpy.DefDef(defdef)(
233-
name = defdef.name,
234-
paramss = defdef.paramss,
235-
tpt = defdef.tpt,
236-
rhs = Match(
237-
ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")),
238-
startParamIndices.map { paramIndex =>
239-
val Apply(select, args) = defdef.rhs: @unchecked
240-
CaseDef(
241-
Literal(Constant(paramIndex)),
242-
EmptyTree,
243-
Apply(
244-
select,
245-
args.take(paramIndex) ++
246-
Range(paramIndex, paramCount).map(n =>
247-
ref(defdef.symbol.owner.companionModule)
248-
.select(DefaultGetterName(defdef.symbol.owner.primaryConstructor.name.toTermName, n))
249-
)
250-
)
251-
)
252-
} :+ CaseDef(
253-
Underscore(defn.IntType),
254-
EmptyTree,
255-
defdef.rhs
256-
)
257-
)
258-
).setDefTree
259-
}
260-
261231
private enum Gen:
262232
case Substitute(origin: Symbol, newDef: DefDef)
263233
case Forwarders(origin: Symbol, forwarders: List[DefDef])
@@ -277,38 +247,29 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
277247
val isCaseApply =
278248
defdef.name == nme.apply && defdef.symbol.owner.companionClass.is(CaseClass)
279249

280-
val isCaseFromProduct = defdef.name == nme.fromProduct && defdef.symbol.owner.companionClass.is(CaseClass)
281-
282250
val annotated =
283251
if (isCaseCopy) defdef.symbol.owner.primaryConstructor
284252
else if (isCaseApply) defdef.symbol.owner.companionClass.primaryConstructor
285-
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
286253
else defdef.symbol
287254

288255
compute(annotated) match {
289256
case Nil => None
290257
case (paramClauseIndex, annotationIndices) :: Nil =>
291258
val paramCount = annotated.paramSymss(paramClauseIndex).size
292-
if isCaseFromProduct then
293-
Some(Gen.Substitute(
294-
origin = defdef.symbol,
295-
newDef = generateFromProduct(annotationIndices, paramCount, defdef)
296-
))
297-
else
298-
val generatedDefs =
299-
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
300-
indices.foldLeft(List.empty[DefDef]):
301-
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
302-
generateSingleForwarder(
303-
defdef,
304-
paramIndex,
305-
paramCount,
306-
nextParamIndex,
307-
paramClauseIndex,
308-
isCaseApply
309-
) :: defdefs
310-
case _ => unreachable("sliding with at least 2 elements")
311-
Some(Gen.Forwarders(origin = defdef.symbol, forwarders = generatedDefs))
259+
val generatedDefs =
260+
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
261+
indices.foldLeft(List.empty[DefDef]):
262+
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
263+
generateSingleForwarder(
264+
defdef,
265+
paramIndex,
266+
paramCount,
267+
nextParamIndex,
268+
paramClauseIndex,
269+
isCaseApply
270+
) :: defdefs
271+
case _ => unreachable("sliding with at least 2 elements")
272+
Some(Gen.Forwarders(origin = defdef.symbol, forwarders = generatedDefs))
312273

313274
case multiple =>
314275
report.error("Cannot have multiple parameter lists containing `@unroll` annotation", defdef.srcPos)

0 commit comments

Comments
 (0)