Skip to content

Commit f3dda82

Browse files
committed
Allow to beta reduce curried function applications in quotes reflect
Previously, the curried functions with multiple applications were not able to be beta-reduced in any way, which was unexpected. Now we allow reducing any number of top-level function applications for a curried function. This was also made clearer in the documentation for the affected (Expr.betaReduce and Term.betaReduce) methods.
1 parent 97677cc commit f3dda82

File tree

5 files changed

+173
-21
lines changed

5 files changed

+173
-21
lines changed

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -373,17 +373,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
373373
end TermTypeTest
374374

375375
object Term extends TermModule:
376-
def betaReduce(tree: Term): Option[Term] =
377-
tree match
378-
case tpd.Block(Nil, expr) =>
379-
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
380-
case tpd.Inlined(_, Nil, expr) =>
381-
betaReduce(expr)
382-
case _ =>
383-
val tree1 = dotc.transform.BetaReduce(tree)
384-
if tree1 eq tree then None
385-
else Some(tree1.withSpan(tree.span))
386-
376+
def betaReduce(tree: Term): Option[Term] =
377+
val tree1 = new dotty.tools.dotc.ast.tpd.TreeMap {
378+
override def transform(tree: Tree)(using Context): Tree = tree match {
379+
case tpd.Block(Nil, _) | tpd.Inlined(_, Nil, _) =>
380+
super.transform(tree)
381+
case tpd.Apply(sel @ tpd.Select(expr, nme), args) =>
382+
val tree1 = cpy.Apply(tree)(cpy.Select(sel)(transform(expr), nme), args)
383+
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
384+
case tpd.Apply(ta @ tpd.TypeApply(sel @ tpd.Select(expr: Apply, nme), tpts), args) =>
385+
val tree1 = cpy.Apply(tree)(cpy.TypeApply(ta)(cpy.Select(sel)(transform(expr), nme), tpts), args)
386+
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
387+
case _ =>
388+
dotc.transform.BetaReduce(tree).withSpan(tree.span)
389+
}
390+
}.transform(tree)
391+
if tree1 == tree then None else Some(tree1)
387392
end Term
388393

389394
given TermMethods: TermMethods with

library/src/scala/quoted/Expr.scala

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,34 @@ abstract class Expr[+T] private[scala] ()
1010
object Expr {
1111

1212
/** `e.betaReduce` returns an expression that is functionally equivalent to `e`,
13-
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
14-
* then it optimizes this the top most call by returning the result of beta-reducing the application.
15-
* Otherwise returns `expr`.
13+
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
14+
* then it optimizes the top most call by returning the result of beta-reducing the application.
15+
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
16+
* Otherwise returns `expr`.
1617
*
17-
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
18-
* Some bindings may be elided as an early optimization.
18+
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
19+
* Some bindings may be elided as an early optimization.
20+
*
21+
* Example:
22+
* ```scala sc:nocompile
23+
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
24+
* ```
25+
* will be reduced to
26+
* ```scala sc:nocompile
27+
* type X1 = Tx1
28+
* type Y1 = Ty1
29+
* ...
30+
* val x1 = myX1
31+
* val y1 = myY1
32+
* ...
33+
* type Xn = Txn
34+
* type Yn = Tyn
35+
* ...
36+
* val xn = myXn
37+
* val yn = myYn
38+
* ...
39+
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
40+
* ```
1941
*/
2042
def betaReduce[T](expr: Expr[T])(using Quotes): Expr[T] =
2143
import quotes.reflect._

library/src/scala/quoted/Quotes.scala

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -751,14 +751,48 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
751751
/** Methods of the module object `val Term` */
752752
trait TermModule { this: Term.type =>
753753

754-
/** Returns a term that is functionally equivalent to `t`,
754+
/** Returns a term that is functionally equivalent to `t`,
755755
* however if `t` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
756-
* then it optimizes this the top most call by returning the `Some`
757-
* with the result of beta-reducing the application.
756+
* then it optimizes the top most call by returning the `Some`
757+
* with the result of beta-reducing the function application.
758+
* Similarly, all outermost curried function applications will be
759+
* beta-reduced, if possible.
758760
* Otherwise returns `None`.
759761
*
760-
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
761-
* Some bindings may be elided as an early optimization.
762+
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
763+
* Some bindings may be elided as an early optimization.
764+
*
765+
* Example:
766+
* ```scala scnocompile
767+
* ((a, ...) => f(a, ...))(x, ...)
768+
* ```
769+
* will be reduced to
770+
* ```scala sc.nocompile
771+
* val x = a
772+
* ...
773+
* f(x, ...)
774+
* ```
775+
*
776+
* In general:
777+
* ```scala sc:nocompile
778+
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
779+
* ```
780+
* will be reduced to
781+
* ```scala sc:nocompile
782+
* type X1 = Tx1
783+
* type Y1 = Ty1
784+
* ...
785+
* val x1 = myX1
786+
* val y1 = myY1
787+
* ...
788+
* type Xn = Txn
789+
* type Yn = Tyn
790+
* ...
791+
* val xn = myXn
792+
* val yn = myYn
793+
* ...
794+
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
795+
* ```
762796
*/
763797
def betaReduce(term: Term): Option[Term]
764798

tests/pos-macros/i17506/Macro_1.scala

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
class Foo
2+
class Bar
3+
class Baz
4+
5+
import scala.quoted._
6+
7+
def assertBetaReduction(using Quotes)(applied: Expr[Any], expected: String): quotes.reflect.Term =
8+
import quotes.reflect._
9+
val reducedMaybe = Term.betaReduce(applied.asTerm)
10+
assert(reducedMaybe.isDefined)
11+
val reduced = reducedMaybe.get
12+
assert(reduced.show == expected,s"obtained: ${reduced.show}, expected: ${expected}")
13+
reduced
14+
15+
inline def regularCurriedCtxFun2BetaReduceTest(inline f: Foo ?=> Bar ?=> Int): Unit =
16+
${regularCurriedCtxFun2BetaReduceTestImpl('f)}
17+
def regularCurriedCtxFun2BetaReduceTestImpl(f: Expr[Foo ?=> Bar ?=> Int])(using Quotes): Expr[Int] =
18+
val expected =
19+
"""|{
20+
| val contextual$3: Bar = new Bar()
21+
| val contextual$2: Foo = new Foo()
22+
| 123
23+
|}""".stripMargin
24+
val applied = '{$f(using new Foo())(using new Bar())}
25+
assertBetaReduction(applied, expected).asExprOf[Int]
26+
27+
inline def regularCurriedFun2BetaReduceTest(inline f: Foo => Bar => Int): Int =
28+
${regularCurriedFun2BetaReduceTestImpl('f)}
29+
def regularCurriedFun2BetaReduceTestImpl(f: Expr[Foo => Bar => Int])(using Quotes): Expr[Int] =
30+
val expected =
31+
"""|{
32+
| val b: Bar = new Bar()
33+
| val f: Foo = new Foo()
34+
| 123
35+
|}""".stripMargin
36+
val applied = '{$f(new Foo())(new Bar())}
37+
assertBetaReduction(applied, expected).asExprOf[Int]
38+
39+
inline def typeParamCurriedFun2BetaReduceTest(inline f: [A] => A => [B] => B => Unit): Unit =
40+
${typeParamCurriedFun2BetaReduceTestImpl('f)}
41+
def typeParamCurriedFun2BetaReduceTestImpl(f: Expr[[A] => (a: A) => [B] => (b: B) => Unit])(using Quotes): Expr[Unit] =
42+
val expected =
43+
"""|{
44+
| type Y = Bar
45+
| val y: Bar = new Bar()
46+
| type X = Foo
47+
| val x: Foo = new Foo()
48+
| typeParamFun2[Y, X](y, x)
49+
|}""".stripMargin
50+
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar())}
51+
assertBetaReduction(applied, expected).asExprOf[Unit]
52+
53+
inline def regularCurriedFun3BetaReduceTest(inline f: Foo => Bar => Baz => Int): Int =
54+
${regularCurriedFun3BetaReduceTestImpl('f)}
55+
def regularCurriedFun3BetaReduceTestImpl(f: Expr[Foo => Bar => Baz => Int])(using Quotes): Expr[Int] =
56+
val expected =
57+
"""|{
58+
| val i: Baz = new Baz()
59+
| val b: Bar = new Bar()
60+
| val f: Foo = new Foo()
61+
| 123
62+
|}""".stripMargin
63+
val applied = '{$f(new Foo())(new Bar())(new Baz())}
64+
assertBetaReduction(applied, expected).asExprOf[Int]
65+
66+
inline def typeParamCurriedFun3BetaReduceTest(inline f: [A] => A => [B] => B => [C] => C => Unit): Unit =
67+
${typeParamCurriedFun3BetaReduceTestImpl('f)}
68+
def typeParamCurriedFun3BetaReduceTestImpl(f: Expr[[A] => A => [B] => B => [C] => C => Unit])(using Quotes): Expr[Unit] =
69+
val expected =
70+
"""|{
71+
| type Z = Baz
72+
| val z: Baz = new Baz()
73+
| type Y = Bar
74+
| val y: Bar = new Bar()
75+
| type X = Foo
76+
| val x: Foo = new Foo()
77+
| typeParamFun3[Z, Y, X](z, y, x)
78+
|}""".stripMargin
79+
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar()).apply[Baz](new Baz())}
80+
assertBetaReduction(applied, expected).asExprOf[Unit]

tests/pos-macros/i17506/Test_2.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@main def run() =
2+
def typeParamFun2[A, B](a: A, b: B): Unit = println(a.toString + " " + b.toString)
3+
def typeParamFun3[A, B, C](a: A, b: B, c: C): Unit = println(a.toString + " " + b.toString)
4+
5+
regularCurriedCtxFun2BetaReduceTest((f: Foo) ?=> (b: Bar) ?=> 123)
6+
regularCurriedCtxFun2BetaReduceTest(123)
7+
regularCurriedFun2BetaReduceTest(((f: Foo) => (b: Bar) => 123))
8+
typeParamCurriedFun2BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => typeParamFun2[Y, X](y, x))
9+
10+
regularCurriedFun3BetaReduceTest((f: Foo) => (b: Bar) => (i: Baz) => 123)
11+
typeParamCurriedFun3BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => [Z] => (z: Z) => typeParamFun3[Z, Y, X](z, y, x))

0 commit comments

Comments
 (0)