Skip to content

Commit 8439370

Browse files
committed
Allow to beta reduce curried function applications in quotes reflect
Previously, the curried functions with multiple applications were not able to be beta-reduced in any way, which was unexpected. Now we allow reducing any number of top-level function applications for a curried function. This was also made clearer in the documentation for the affected (Expr.betaReduce and Term.betaReduce) methods.
1 parent a6c40b1 commit 8439370

File tree

5 files changed

+183
-21
lines changed

5 files changed

+183
-21
lines changed

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

+16-11
Original file line numberDiff line numberDiff line change
@@ -396,17 +396,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
396396
end TermTypeTest
397397

398398
object Term extends TermModule:
399-
def betaReduce(tree: Term): Option[Term] =
400-
tree match
401-
case tpd.Block(Nil, expr) =>
402-
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
403-
case tpd.Inlined(_, Nil, expr) =>
404-
betaReduce(expr)
405-
case _ =>
406-
val tree1 = dotc.transform.BetaReduce(tree)
407-
if tree1 eq tree then None
408-
else Some(tree1.withSpan(tree.span))
409-
399+
def betaReduce(tree: Term): Option[Term] =
400+
val tree1 = new dotty.tools.dotc.ast.tpd.TreeMap {
401+
override def transform(tree: Tree)(using Context): Tree = tree match {
402+
case tpd.Block(Nil, _) | tpd.Inlined(_, Nil, _) =>
403+
super.transform(tree)
404+
case tpd.Apply(sel @ tpd.Select(expr, nme), args) =>
405+
val tree1 = cpy.Apply(tree)(cpy.Select(sel)(transform(expr), nme), args)
406+
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
407+
case tpd.Apply(ta @ tpd.TypeApply(sel @ tpd.Select(expr: Apply, nme), tpts), args) =>
408+
val tree1 = cpy.Apply(tree)(cpy.TypeApply(ta)(cpy.Select(sel)(transform(expr), nme), tpts), args)
409+
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
410+
case _ =>
411+
dotc.transform.BetaReduce(tree).withSpan(tree.span)
412+
}
413+
}.transform(tree)
414+
if tree1 == tree then None else Some(tree1)
410415
end Term
411416

412417
given TermMethods: TermMethods with

library/src/scala/quoted/Expr.scala

+38-5
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,45 @@ abstract class Expr[+T] private[scala] ()
1010
object Expr {
1111

1212
/** `e.betaReduce` returns an expression that is functionally equivalent to `e`,
13-
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
14-
* then it optimizes this the top most call by returning the result of beta-reducing the application.
15-
* Otherwise returns `expr`.
13+
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
14+
* then it optimizes the top most call by returning the result of beta-reducing the application.
15+
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
16+
* Otherwise returns `expr`.
1617
*
17-
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
18-
* Some bindings may be elided as an early optimization.
18+
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
19+
* Some bindings may be elided as an early optimization.
20+
*
21+
* Example:
22+
* ```scala sc:nocompile
23+
* ((a: Int, b: Int) => a + b).apply(x, y)
24+
* ```
25+
* will be reduced to
26+
* ```scala sc:nocompile
27+
* val a = x
28+
* val b = y
29+
* a + b
30+
* ```
31+
*
32+
* Generally:
33+
* ```scala sc:nocompile
34+
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
35+
* ```
36+
* will be reduced to
37+
* ```scala sc:nocompile
38+
* type X1 = Tx1
39+
* type Y1 = Ty1
40+
* ...
41+
* val x1 = myX1
42+
* val y1 = myY1
43+
* ...
44+
* type Xn = Txn
45+
* type Yn = Tyn
46+
* ...
47+
* val xn = myXn
48+
* val yn = myYn
49+
* ...
50+
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
51+
* ```
1952
*/
2053
def betaReduce[T](expr: Expr[T])(using Quotes): Expr[T] =
2154
import quotes.reflect.*

library/src/scala/quoted/Quotes.scala

+38-5
Original file line numberDiff line numberDiff line change
@@ -774,14 +774,47 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
774774
/** Methods of the module object `val Term` */
775775
trait TermModule { this: Term.type =>
776776

777-
/** Returns a term that is functionally equivalent to `t`,
777+
/** Returns a term that is functionally equivalent to `t`,
778778
* however if `t` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
779-
* then it optimizes this the top most call by returning the `Some`
780-
* with the result of beta-reducing the application.
779+
* then it optimizes the top most call by returning `Some`
780+
* with the result of beta-reducing the function application.
781+
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
781782
* Otherwise returns `None`.
782783
*
783-
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
784-
* Some bindings may be elided as an early optimization.
784+
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
785+
* Some bindings may be elided as an early optimization.
786+
*
787+
* Example:
788+
* ```scala sc:nocompile
789+
* ((a: Int, b: Int) => a + b).apply(x, y)
790+
* ```
791+
* will be reduced to
792+
* ```scala sc:nocompile
793+
* val a = x
794+
* val b = y
795+
* a + b
796+
* ```
797+
*
798+
* Generally:
799+
* ```scala sc:nocompile
800+
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
801+
* ```
802+
* will be reduced to
803+
* ```scala sc:nocompile
804+
* type X1 = Tx1
805+
* type Y1 = Ty1
806+
* ...
807+
* val x1 = myX1
808+
* val y1 = myY1
809+
* ...
810+
* type Xn = Txn
811+
* type Yn = Tyn
812+
* ...
813+
* val xn = myXn
814+
* val yn = myYn
815+
* ...
816+
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
817+
* ```
785818
*/
786819
def betaReduce(term: Term): Option[Term]
787820

tests/pos-macros/i17506/Macro_1.scala

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
class Foo
2+
class Bar
3+
class Baz
4+
5+
import scala.quoted._
6+
7+
def assertBetaReduction(using Quotes)(applied: Expr[Any], expected: String): quotes.reflect.Term =
8+
import quotes.reflect._
9+
val reducedMaybe = Term.betaReduce(applied.asTerm)
10+
assert(reducedMaybe.isDefined)
11+
val reduced = reducedMaybe.get
12+
assert(reduced.show == expected,s"obtained: ${reduced.show}, expected: ${expected}")
13+
reduced
14+
15+
inline def regularCurriedCtxFun2BetaReduceTest(inline f: Foo ?=> Bar ?=> Int): Unit =
16+
${regularCurriedCtxFun2BetaReduceTestImpl('f)}
17+
def regularCurriedCtxFun2BetaReduceTestImpl(f: Expr[Foo ?=> Bar ?=> Int])(using Quotes): Expr[Int] =
18+
val expected =
19+
"""|{
20+
| val contextual$3: Bar = new Bar()
21+
| val contextual$2: Foo = new Foo()
22+
| 123
23+
|}""".stripMargin
24+
val applied = '{$f(using new Foo())(using new Bar())}
25+
assertBetaReduction(applied, expected).asExprOf[Int]
26+
27+
inline def regularCurriedFun2BetaReduceTest(inline f: Foo => Bar => Int): Int =
28+
${regularCurriedFun2BetaReduceTestImpl('f)}
29+
def regularCurriedFun2BetaReduceTestImpl(f: Expr[Foo => Bar => Int])(using Quotes): Expr[Int] =
30+
val expected =
31+
"""|{
32+
| val b: Bar = new Bar()
33+
| val f: Foo = new Foo()
34+
| 123
35+
|}""".stripMargin
36+
val applied = '{$f(new Foo())(new Bar())}
37+
assertBetaReduction(applied, expected).asExprOf[Int]
38+
39+
inline def typeParamCurriedFun2BetaReduceTest(inline f: [A] => A => [B] => B => Unit): Unit =
40+
${typeParamCurriedFun2BetaReduceTestImpl('f)}
41+
def typeParamCurriedFun2BetaReduceTestImpl(f: Expr[[A] => (a: A) => [B] => (b: B) => Unit])(using Quotes): Expr[Unit] =
42+
val expected =
43+
"""|{
44+
| type Y = Bar
45+
| val y: Bar = new Bar()
46+
| type X = Foo
47+
| val x: Foo = new Foo()
48+
| typeParamFun2[Y, X](y, x)
49+
|}""".stripMargin
50+
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar())}
51+
assertBetaReduction(applied, expected).asExprOf[Unit]
52+
53+
inline def regularCurriedFun3BetaReduceTest(inline f: Foo => Bar => Baz => Int): Int =
54+
${regularCurriedFun3BetaReduceTestImpl('f)}
55+
def regularCurriedFun3BetaReduceTestImpl(f: Expr[Foo => Bar => Baz => Int])(using Quotes): Expr[Int] =
56+
val expected =
57+
"""|{
58+
| val i: Baz = new Baz()
59+
| val b: Bar = new Bar()
60+
| val f: Foo = new Foo()
61+
| 123
62+
|}""".stripMargin
63+
val applied = '{$f(new Foo())(new Bar())(new Baz())}
64+
assertBetaReduction(applied, expected).asExprOf[Int]
65+
66+
inline def typeParamCurriedFun3BetaReduceTest(inline f: [A] => A => [B] => B => [C] => C => Unit): Unit =
67+
${typeParamCurriedFun3BetaReduceTestImpl('f)}
68+
def typeParamCurriedFun3BetaReduceTestImpl(f: Expr[[A] => A => [B] => B => [C] => C => Unit])(using Quotes): Expr[Unit] =
69+
val expected =
70+
"""|{
71+
| type Z = Baz
72+
| val z: Baz = new Baz()
73+
| type Y = Bar
74+
| val y: Bar = new Bar()
75+
| type X = Foo
76+
| val x: Foo = new Foo()
77+
| typeParamFun3[Z, Y, X](z, y, x)
78+
|}""".stripMargin
79+
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar()).apply[Baz](new Baz())}
80+
assertBetaReduction(applied, expected).asExprOf[Unit]

tests/pos-macros/i17506/Test_2.scala

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@main def run() =
2+
def typeParamFun2[A, B](a: A, b: B): Unit = println(a.toString + " " + b.toString)
3+
def typeParamFun3[A, B, C](a: A, b: B, c: C): Unit = println(a.toString + " " + b.toString)
4+
5+
regularCurriedCtxFun2BetaReduceTest((f: Foo) ?=> (b: Bar) ?=> 123)
6+
regularCurriedCtxFun2BetaReduceTest(123)
7+
regularCurriedFun2BetaReduceTest(((f: Foo) => (b: Bar) => 123))
8+
typeParamCurriedFun2BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => typeParamFun2[Y, X](y, x))
9+
10+
regularCurriedFun3BetaReduceTest((f: Foo) => (b: Bar) => (i: Baz) => 123)
11+
typeParamCurriedFun3BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => [Z] => (z: Z) => typeParamFun3[Z, Y, X](z, y, x))

0 commit comments

Comments
 (0)