Skip to content

Commit 5c7a000

Browse files
committed
Add boundary break optimization tests
1 parent 211ddb7 commit 5c7a000

File tree

1 file changed

+166
-0
lines changed

1 file changed

+166
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package dotty.tools.backend.jvm
2+
3+
import scala.language.unsafeNulls
4+
5+
import org.junit.Assert._
6+
import org.junit.Test
7+
8+
import scala.tools.asm
9+
import asm._
10+
import asm.tree._
11+
12+
import scala.tools.asm.Opcodes
13+
import scala.jdk.CollectionConverters._
14+
import Opcodes._
15+
16+
class LabelBytecodeTests extends DottyBytecodeTest {
17+
import ASMConverters._
18+
19+
@Test def localLabelBreak = {
20+
testLabelBytecodeEquals(
21+
"""val local = boundary.Label[Long]()
22+
|try break(5L)(using local)
23+
|catch case ex: boundary.Break[Long] @unchecked =>
24+
| if ex.label eq local then ex.value
25+
| else throw ex
26+
""".stripMargin,
27+
"Long",
28+
Ldc(LDC, 5),
29+
Op(LRETURN)
30+
)
31+
}
32+
33+
@Test def simpleBoundaryBreak = {
34+
testLabelBytecodeEquals(
35+
"""boundary: l ?=>
36+
| break(2)(using l)
37+
""".stripMargin,
38+
"Int",
39+
Op(ICONST_2),
40+
Op(IRETURN)
41+
)
42+
43+
testLabelBytecodeEquals(
44+
"""boundary:
45+
| break(3)
46+
""".stripMargin,
47+
"Int",
48+
Op(ICONST_3),
49+
Op(IRETURN)
50+
)
51+
52+
testLabelBytecodeEquals(
53+
"""boundary:
54+
| break()
55+
""".stripMargin,
56+
"Unit",
57+
Op(RETURN)
58+
)
59+
}
60+
61+
@Test def labelExtraction = {
62+
// Test extra Inlined around the label
63+
testLabelBytecodeEquals(
64+
"""boundary:
65+
| break(2)(using summon[boundary.Label[Int]])
66+
""".stripMargin,
67+
"Int",
68+
Op(ICONST_2),
69+
Op(IRETURN)
70+
)
71+
72+
// Test extra Block around the label
73+
testLabelBytecodeEquals(
74+
"""boundary: l ?=>
75+
| break(2)(using { l })
76+
""".stripMargin,
77+
"Int",
78+
Op(ICONST_2),
79+
Op(IRETURN)
80+
)
81+
}
82+
83+
@Test def boundaryLocalBreak = {
84+
testLabelBytecodeExpect(
85+
"""val x: Boolean = true
86+
|boundary[Unit]:
87+
| var i = 0
88+
| while true do
89+
| i += 1
90+
| if i > 10 then break()
91+
""".stripMargin,
92+
"Unit",
93+
!throws(_)
94+
)
95+
}
96+
97+
@Test def boundaryNonLocalBreak = {
98+
testLabelBytecodeExpect(
99+
"""boundary[Unit]:
100+
| nonLocalBreak()
101+
""".stripMargin,
102+
"Unit",
103+
throws
104+
)
105+
106+
testLabelBytecodeExpect(
107+
"""boundary[Unit]:
108+
| def f() = break()
109+
| f()
110+
""".stripMargin,
111+
"Unit",
112+
throws
113+
)
114+
}
115+
116+
@Test def boundaryLocalAndNonLocalBreak = {
117+
testLabelBytecodeExpect(
118+
"""boundary[Unit]: l ?=>
119+
| break()
120+
| nonLocalBreak()
121+
""".stripMargin,
122+
"Unit",
123+
throws
124+
)
125+
}
126+
127+
private def throws(instructions: List[Instruction]): Boolean =
128+
instructions.exists {
129+
case Op(ATHROW) => true
130+
case _ => false
131+
}
132+
133+
private def testLabelBytecodeEquals(code: String, tpe: String, expected: Instruction*): Unit =
134+
checkLabelBytecodeInstructions(code, tpe) { instructions =>
135+
val expectedList = expected.toList
136+
assert(instructions == expectedList,
137+
"`test` was not properly generated\n" + diffInstructions(instructions, expectedList))
138+
}
139+
140+
private def testLabelBytecodeExpect(code: String, tpe: String, expected: List[Instruction] => Boolean): Unit =
141+
checkLabelBytecodeInstructions(code, tpe) { instructions =>
142+
assert(expected(instructions),
143+
"`test` was not properly generated\n" + instructions)
144+
}
145+
146+
private def checkLabelBytecodeInstructions(code: String, tpe: String)(checkOutput: List[Instruction] => Unit): Unit = {
147+
val source =
148+
s"""import scala.util.*
149+
|class Test:
150+
| def test: $tpe = {
151+
| ${code.lines().toList().asScala.mkString("", "\n ", "")}
152+
| }
153+
| def nonLocalBreak[T](value: T)(using boundary.Label[T]): Nothing = break(value)
154+
| def nonLocalBreak()(using boundary.Label[Unit]): Nothing = break(())
155+
""".stripMargin
156+
157+
checkBCode(source) { dir =>
158+
val clsIn = dir.lookupName("Test.class", directory = false).input
159+
val clsNode = loadClassNode(clsIn)
160+
val method = getMethod(clsNode, "test")
161+
162+
checkOutput(instructionsFromMethod(method))
163+
}
164+
}
165+
166+
}

0 commit comments

Comments
 (0)