Skip to content

Commit aaf9ec7

Browse files
KuceraMartindwijnand
authored andcommitted
List(...) optimization to avoid intermediate array (closes #17035)
1 parent 4eae174 commit aaf9ec7

File tree

4 files changed

+82
-11
lines changed

4 files changed

+82
-11
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+11-9
Original file line numberDiff line numberDiff line change
@@ -517,14 +517,15 @@ class Definitions {
517517
methodNames.map(getWrapVarargsArrayModule.requiredMethod(_))
518518
})
519519

520-
@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
521-
def ListType: TypeRef = ListClass.typeRef
522-
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
523-
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
524-
def NilType: TermRef = NilModule.termRef
525-
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
526-
def ConsType: TypeRef = ConsClass.typeRef
527-
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")
520+
@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
521+
def ListType: TypeRef = ListClass.typeRef
522+
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
523+
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
524+
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
525+
def NilType: TermRef = NilModule.termRef
526+
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
527+
def ConsType: TypeRef = ConsClass.typeRef
528+
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")
528529

529530
@tu lazy val SingletonClass: ClassSymbol =
530531
// needed as a synthetic class because Scala 2.x refers to it in classfiles
@@ -543,7 +544,8 @@ class Definitions {
543544
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
544545
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
545546
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
546-
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
547+
@tu lazy val SeqModule : Symbol = requiredModule("scala.collection.immutable.Seq")
548+
@tu lazy val SeqModule_apply : Symbol = SeqModule.requiredMethod(nme.apply)
547549

548550

549551
@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")

compiler/src/dotty/tools/dotc/transform/ArrayApply.scala

+25-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,18 @@ class ArrayApply extends MiniPhase {
2222

2323
override def description: String = ArrayApply.description
2424

25+
private var transformListApplyLimit = 8
26+
27+
private def reducingTransformListApply[A](depth: Int)(body: => A): A = {
28+
val saved = transformListApplyLimit
29+
transformListApplyLimit -= depth
30+
try body
31+
finally transformListApplyLimit = saved
32+
}
33+
2534
override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
2635
if isArrayModuleApply(tree.symbol) then
27-
tree.args match {
36+
tree.args match
2837
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
2938
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
3039
seqLit
@@ -35,14 +44,28 @@ class ArrayApply extends MiniPhase {
3544

3645
case _ =>
3746
tree
38-
}
47+
48+
else if isListOrSeqModuleApply(tree.symbol) then
49+
tree.args match
50+
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
51+
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: tpd.JavaSeqLiteral)))) :: Nil
52+
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) &&
53+
rest.elems.lengthIs < transformListApplyLimit =>
54+
rest.elems.foldRight(tpd.ref(defn.NilModule)): (elem, acc) =>
55+
tpd.New(defn.ConsType, List(elem, acc))
56+
57+
case _ =>
58+
tree
3959

4060
else tree
4161

4262
private def isArrayModuleApply(sym: Symbol)(using Context): Boolean =
4363
sym.name == nme.apply
4464
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))
4565

66+
private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean =
67+
sym == defn.ListModule_apply || sym == defn.SeqModule_apply
68+
4669
/** Only optimize when classtag if it is one of
4770
* - `ClassTag.apply(classOf[XYZ])`
4871
* - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ``

compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala

+25
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,29 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
160160
}
161161
}
162162

163+
@Test def testListApplyAvoidsIntermediateArray = {
164+
val source =
165+
"""
166+
|class Foo {
167+
| def meth1: List[String] = List("1", "2", "3")
168+
| def meth2: List[String] =
169+
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
170+
|}
171+
""".stripMargin
172+
173+
checkBCode(source) { dir =>
174+
val clsIn = dir.lookupName("Foo.class", directory = false).input
175+
val clsNode = loadClassNode(clsIn)
176+
val meth1 = getMethod(clsNode, "meth1")
177+
val meth2 = getMethod(clsNode, "meth2")
178+
179+
val instructions1 = instructionsFromMethod(meth1)
180+
val instructions2 = instructionsFromMethod(meth2)
181+
182+
assert(instructions1 == instructions2,
183+
"the List.apply method " +
184+
diffInstructions(instructions1, instructions2))
185+
}
186+
}
187+
163188
}

tests/run/list-apply-eval.scala

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
object Test:
2+
3+
var counter = 0
4+
5+
def next =
6+
counter += 1
7+
counter.toString
8+
9+
def main(args: Array[String]): Unit =
10+
//List.apply is subject to an optimisation in cleanup
11+
//ensure that the arguments are evaluated in the currect order
12+
// Rewritten to:
13+
// val myList: List = new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), scala.collection.immutable.Nil)));
14+
val myList = List(next, next, next)
15+
assert(myList == List("1", "2", "3"), myList)
16+
17+
val mySeq = Seq(next, next, next)
18+
assert(mySeq == Seq("4", "5", "6"), mySeq)
19+
20+
val emptyList = List[Int]()
21+
assert(emptyList == Nil)

0 commit comments

Comments
 (0)