Skip to content

Commit a85d0ea

Browse files
committed
Check flags for method, val and bind symbols
1 parent b9b80b6 commit a85d0ea

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import dotty.tools.dotc.ast.untpd
99
import dotty.tools.dotc.core.Annotations
1010
import dotty.tools.dotc.core.Contexts._
1111
import dotty.tools.dotc.core.Decorators._
12-
import dotty.tools.dotc.core.Flags._
1312
import dotty.tools.dotc.core.NameKinds
1413
import dotty.tools.dotc.core.NameOps._
1514
import dotty.tools.dotc.core.StdNames._
@@ -276,12 +275,13 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
276275

277276
object DefDef extends DefDefModule:
278277
def apply(symbol: Symbol, rhsFn: List[List[Tree]] => Option[Term]): DefDef =
279-
assert(symbol.isTerm, s"expected a term symbol but received $symbol")
278+
xCheckMacroAssert(symbol.isTerm, s"expected a term symbol but received $symbol")
279+
xCheckMacroAssert(symbol.flags.is(Flags.Method), "expected a symbol with `Method` flag set")
280280
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss =>
281-
xCheckMacroedOwners(xCheckMacroValidExpr(rhsFn(prefss)), symbol).getOrElse(tpd.EmptyTree)
281+
xCheckedMacroOwners(xCheckMacroValidExpr(rhsFn(prefss)), symbol).getOrElse(tpd.EmptyTree)
282282
))
283283
def copy(original: Tree)(name: String, paramss: List[ParamClause], tpt: TypeTree, rhs: Option[Term]): DefDef =
284-
tpd.cpy.DefDef(original)(name.toTermName, paramss, tpt, xCheckMacroedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
284+
tpd.cpy.DefDef(original)(name.toTermName, paramss, tpt, xCheckedMacroOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
285285
def unapply(ddef: DefDef): (String, List[ParamClause], TypeTree, Option[Term]) =
286286
(ddef.name.toString, ddef.paramss, ddef.tpt, optional(ddef.rhs))
287287
end DefDef
@@ -307,9 +307,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
307307

308308
object ValDef extends ValDefModule:
309309
def apply(symbol: Symbol, rhs: Option[Term]): ValDef =
310-
withDefaultPos(tpd.ValDef(symbol.asTerm, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree)))
310+
xCheckMacroAssert(!symbol.flags.is(Flags.Method), "expected a symbol without `Method` flag set")
311+
withDefaultPos(tpd.ValDef(symbol.asTerm, xCheckedMacroOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree)))
311312
def copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term]): ValDef =
312-
tpd.cpy.ValDef(original)(name.toTermName, tpt, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), original.symbol).getOrElse(tpd.EmptyTree))
313+
tpd.cpy.ValDef(original)(name.toTermName, tpt, xCheckedMacroOwners(xCheckMacroValidExpr(rhs), original.symbol).getOrElse(tpd.EmptyTree))
313314
def unapply(vdef: ValDef): (String, TypeTree, Option[Term]) =
314315
(vdef.name.toString, vdef.tpt, optional(vdef.rhs))
315316

@@ -399,7 +400,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
399400
def etaExpand(owner: Symbol): Term = self.tpe.widen match {
400401
case mtpe: Types.MethodType if !mtpe.isParamDependent =>
401402
val closureResType = mtpe.resType match {
402-
case t: Types.MethodType => t.toFunctionType(isJava = self.symbol.is(JavaDefined))
403+
case t: Types.MethodType => t.toFunctionType(isJava = self.symbol.is(dotc.core.Flags.JavaDefined))
403404
case t => t
404405
}
405406
val closureTpe = Types.MethodType(mtpe.paramNames, mtpe.paramInfos, closureResType)
@@ -812,7 +813,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
812813
object Lambda extends LambdaModule:
813814
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block =
814815
val meth = dotc.core.Symbols.newAnonFun(owner, tpe)
815-
withDefaultPos(tpd.Closure(meth, tss => xCheckMacroedOwners(xCheckMacroValidExpr(rhsFn(meth, tss.head.map(withDefaultPos))), meth)))
816+
withDefaultPos(tpd.Closure(meth, tss => xCheckedMacroOwners(xCheckMacroValidExpr(rhsFn(meth, tss.head.map(withDefaultPos))), meth)))
816817

817818
def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
818819
case Block((ddef @ DefDef(_, tpd.ValDefs(params) :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
@@ -1483,6 +1484,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
14831484

14841485
object Bind extends BindModule:
14851486
def apply(sym: Symbol, pattern: Tree): Bind =
1487+
xCheckMacroAssert(sym.flags.is(Flags.Case), "expected a symbol with `Case` flag set")
14861488
withDefaultPos(tpd.Bind(sym, pattern))
14871489
def copy(original: Tree)(name: String, pattern: Tree): Bind =
14881490
withDefaultPos(tpd.cpy.Bind(original)(name.toTermName, pattern))
@@ -2514,13 +2516,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
25142516
def newMethod(owner: Symbol, name: String, tpe: TypeRepr): Symbol =
25152517
newMethod(owner, name, tpe, Flags.EmptyFlags, noSymbol)
25162518
def newMethod(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
2519+
checkValidFlags(flags.toTermFlags, Flags.validMethodFlags)
25172520
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | dotc.core.Flags.Method, tpe, privateWithin)
25182521
def newVal(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
2522+
checkValidFlags(flags.toTermFlags, Flags.validValFlags)
25192523
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags, tpe, privateWithin)
25202524
def newBind(owner: Symbol, name: String, flags: Flags, tpe: TypeRepr): Symbol =
2521-
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | Case, tpe)
2525+
checkValidFlags(flags.toTermFlags, Flags.validBindFlags)
2526+
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | dotc.core.Flags.Case, tpe)
25222527
def noSymbol: Symbol = dotc.core.Symbols.NoSymbol
25232528

2529+
private inline def checkValidFlags(inline flags: Flags, inline valid: Flags): Unit =
2530+
xCheckMacroAssert(
2531+
flags <= valid,
2532+
s"Received invalid flags. Expected flags ${flags.show} to only contain a subset of ${valid.show}."
2533+
)
2534+
25242535
def freshName(prefix: String): String =
25252536
NameKinds.MacroNames.fresh(prefix.toTermName).toString
25262537
end Symbol
@@ -2593,7 +2604,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
25932604
self.isTerm && !self.is(dotc.core.Flags.Method) && !self.is(dotc.core.Flags.Case/*, FIXME add this check and fix sourcecode butNot = Enum | Module*/)
25942605
def isDefDef: Boolean = self.is(dotc.core.Flags.Method)
25952606
def isBind: Boolean =
2596-
self.is(dotc.core.Flags.Case, butNot = Enum | Module) && !self.isClass
2607+
self.is(dotc.core.Flags.Case, butNot = dotc.core.Flags.Enum | dotc.core.Flags.Module) && !self.isClass
25972608
def isNoSymbol: Boolean = self == Symbol.noSymbol
25982609
def exists: Boolean = self != Symbol.noSymbol
25992610

@@ -2829,6 +2840,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
28292840
def Synthetic: Flags = dotc.core.Flags.Synthetic
28302841
def Trait: Flags = dotc.core.Flags.Trait
28312842
def Transparent: Flags = dotc.core.Flags.Transparent
2843+
2844+
private[QuotesImpl] def validMethodFlags: Flags = Private | Protected | Override | Deferred | Final | Method | Implicit | Given | Local | JavaStatic | AbsOverride // Synthetic | ExtensionMethod | Exported | Erased | Infix | Invisible
2845+
private[QuotesImpl] def validValFlags: Flags = Private | Protected | Override | Deferred | Final | Param | Implicit | Lazy | Mutable | Local | ParamAccessor | Module | Package | Case | CaseAccessor | Given | Enum | JavaStatic | AbsOverride // Synthetic | Erased | Invisible
2846+
private[QuotesImpl] def validBindFlags: Flags = Case // | Implicit | Given | Erased
28322847
end Flags
28332848

28342849
given FlagsMethods: FlagsMethods with
@@ -2949,7 +2964,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
29492964
/** Checks that all definitions in this tree have the expected owner.
29502965
* Nested definitions are ignored and assumed to be correct by construction.
29512966
*/
2952-
private def xCheckMacroedOwners(tree: Option[Tree], owner: Symbol): tree.type =
2967+
private def xCheckedMacroOwners(tree: Option[Tree], owner: Symbol): tree.type =
29532968
if xCheckMacro then
29542969
tree match
29552970
case Some(tree) =>
@@ -2960,7 +2975,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
29602975
/** Checks that all definitions in this tree have the expected owner.
29612976
* Nested definitions are ignored and assumed to be correct by construction.
29622977
*/
2963-
private def xCheckMacroedOwners(tree: Tree, owner: Symbol): tree.type =
2978+
private def xCheckedMacroOwners(tree: Tree, owner: Symbol): tree.type =
29642979
if xCheckMacro then
29652980
xCheckMacroOwners(tree, owner)
29662981
tree
@@ -3031,6 +3046,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
30313046
"Reference to a method must be eta-expanded before it is used as an expression: " + term.show)
30323047
term
30333048

3049+
private inline def xCheckMacroAssert(inline cond: Boolean, inline msg: String): Unit =
3050+
assert(!xCheckMacro || cond, msg)
3051+
30343052
object Printer extends PrinterModule:
30353053

30363054
lazy val TreeCode: Printer[Tree] = new Printer[Tree]:

library/src/scala/quoted/Quotes.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3747,7 +3747,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
37473747
* @param parent The owner of the method
37483748
* @param name The name of the method
37493749
* @param tpe The type of the method (MethodType, PolyType, ByNameType)
3750-
* @param flags extra flags to with which the symbol should be constructed
3750+
* @param flags extra flags to with which the symbol should be constructed. `Method` flag will be added. Can be `Private | Protected | Override | Deferred | Final | Method | Implicit | Given | Local | JavaStatic`
37513751
* @param privateWithin the symbol within which this new method symbol should be private. May be noSymbol.
37523752
*/
37533753
def newMethod(parent: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol
@@ -3763,7 +3763,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
37633763
* @param parent The owner of the val/var/lazy val
37643764
* @param name The name of the val/var/lazy val
37653765
* @param tpe The type of the val/var/lazy val
3766-
* @param flags extra flags to with which the symbol should be constructed
3766+
* @param flags extra flags to with which the symbol should be constructed. Can be `Private | Protected | Override | Deferred | Final | Param | Implicit | Lazy | Mutable | Local | ParamAccessor | Module | Package | Case | CaseAccessor | Given | Enum | JavaStatic`
37673767
* @param privateWithin the symbol within which this new method symbol should be private. May be noSymbol.
37683768
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
37693769
* direct or indirect children of the reflection context's owner.
@@ -3778,7 +3778,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
37783778
*
37793779
* @param parent The owner of the binding
37803780
* @param name The name of the binding
3781-
* @param flags extra flags to with which the symbol should be constructed
3781+
* @param flags extra flags to with which the symbol should be constructed. `Case` flag will be added. Can be `Case`
37823782
* @param tpe The type of the binding
37833783
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
37843784
* direct or indirect children of the reflection context's owner.

0 commit comments

Comments
 (0)