@@ -1296,6 +1296,9 @@ object Trees {
1296
1296
*/
1297
1297
protected def inlineContext (call : Tree )(using Context ): Context = ctx
1298
1298
1299
+ /** The context to use when mapping or accumulating over a tree */
1300
+ def localCtx (tree : Tree )(using Context ): Context
1301
+
1299
1302
abstract class TreeMap (val cpy : TreeCopier = inst.cpy) { self =>
1300
1303
def transform (tree : Tree )(using Context ): Tree = {
1301
1304
inContext(
@@ -1304,9 +1307,6 @@ object Trees {
1304
1307
else ctx
1305
1308
){
1306
1309
Stats .record(s " TreeMap.transform/ $getClass" )
1307
- def localCtx =
1308
- if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx
1309
-
1310
1310
if (skipTransform(tree)) tree
1311
1311
else tree match {
1312
1312
case Ident (name) =>
@@ -1362,11 +1362,11 @@ object Trees {
1362
1362
case AppliedTypeTree (tpt, args) =>
1363
1363
cpy.AppliedTypeTree (tree)(transform(tpt), transform(args))
1364
1364
case LambdaTypeTree (tparams, body) =>
1365
- inContext(localCtx) {
1365
+ inContext(localCtx(tree) ) {
1366
1366
cpy.LambdaTypeTree (tree)(transformSub(tparams), transform(body))
1367
1367
}
1368
1368
case TermLambdaTypeTree (params, body) =>
1369
- inContext(localCtx) {
1369
+ inContext(localCtx(tree) ) {
1370
1370
cpy.TermLambdaTypeTree (tree)(transformSub(params), transform(body))
1371
1371
}
1372
1372
case MatchTypeTree (bound, selector, cases) =>
@@ -1384,17 +1384,17 @@ object Trees {
1384
1384
case EmptyValDef =>
1385
1385
tree
1386
1386
case tree @ ValDef (name, tpt, _) =>
1387
- inContext(localCtx) {
1387
+ inContext(localCtx(tree) ) {
1388
1388
val tpt1 = transform(tpt)
1389
1389
val rhs1 = transform(tree.rhs)
1390
1390
cpy.ValDef (tree)(name, tpt1, rhs1)
1391
1391
}
1392
1392
case tree @ DefDef (name, paramss, tpt, _) =>
1393
- inContext(localCtx) {
1393
+ inContext(localCtx(tree) ) {
1394
1394
cpy.DefDef (tree)(name, transformParamss(paramss), transform(tpt), transform(tree.rhs))
1395
1395
}
1396
1396
case tree @ TypeDef (name, rhs) =>
1397
- inContext(localCtx) {
1397
+ inContext(localCtx(tree) ) {
1398
1398
cpy.TypeDef (tree)(name, transform(rhs))
1399
1399
}
1400
1400
case tree @ Template (constr, parents, self, _) if tree.derived.isEmpty =>
@@ -1404,7 +1404,10 @@ object Trees {
1404
1404
case Export (expr, selectors) =>
1405
1405
cpy.Export (tree)(transform(expr), selectors)
1406
1406
case PackageDef (pid, stats) =>
1407
- cpy.PackageDef (tree)(transformSub(pid), transformStats(stats, pid.symbol.moduleClass)(using localCtx))
1407
+ val pid1 = transformSub(pid)
1408
+ inContext(localCtx(tree)) {
1409
+ cpy.PackageDef (tree)(pid1, transformStats(stats, ctx.owner))
1410
+ }
1408
1411
case Annotated (arg, annot) =>
1409
1412
cpy.Annotated (tree)(transform(arg), transform(annot))
1410
1413
case Thicket (trees) =>
@@ -1450,8 +1453,6 @@ object Trees {
1450
1453
foldOver(x, tree)(using ctx.withSource(tree.source))
1451
1454
else {
1452
1455
Stats .record(s " TreeAccumulator.foldOver/ $getClass" )
1453
- def localCtx =
1454
- if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx
1455
1456
tree match {
1456
1457
case Ident (name) =>
1457
1458
x
@@ -1506,11 +1507,11 @@ object Trees {
1506
1507
case AppliedTypeTree (tpt, args) =>
1507
1508
this (this (x, tpt), args)
1508
1509
case LambdaTypeTree (tparams, body) =>
1509
- inContext(localCtx) {
1510
+ inContext(localCtx(tree) ) {
1510
1511
this (this (x, tparams), body)
1511
1512
}
1512
1513
case TermLambdaTypeTree (params, body) =>
1513
- inContext(localCtx) {
1514
+ inContext(localCtx(tree) ) {
1514
1515
this (this (x, params), body)
1515
1516
}
1516
1517
case MatchTypeTree (bound, selector, cases) =>
@@ -1526,15 +1527,15 @@ object Trees {
1526
1527
case UnApply (fun, implicits, patterns) =>
1527
1528
this (this (this (x, fun), implicits), patterns)
1528
1529
case tree @ ValDef (_, tpt, _) =>
1529
- inContext(localCtx) {
1530
+ inContext(localCtx(tree) ) {
1530
1531
this (this (x, tpt), tree.rhs)
1531
1532
}
1532
1533
case tree @ DefDef (_, paramss, tpt, _) =>
1533
- inContext(localCtx) {
1534
+ inContext(localCtx(tree) ) {
1534
1535
this (this (paramss.foldLeft(x)(apply), tpt), tree.rhs)
1535
1536
}
1536
1537
case TypeDef (_, rhs) =>
1537
- inContext(localCtx) {
1538
+ inContext(localCtx(tree) ) {
1538
1539
this (x, rhs)
1539
1540
}
1540
1541
case tree @ Template (constr, parents, self, _) if tree.derived.isEmpty =>
@@ -1544,7 +1545,7 @@ object Trees {
1544
1545
case Export (expr, _) =>
1545
1546
this (x, expr)
1546
1547
case PackageDef (pid, stats) =>
1547
- this (this (x, pid), stats)(using localCtx)
1548
+ this (this (x, pid), stats)(using localCtx(tree) )
1548
1549
case Annotated (arg, annot) =>
1549
1550
this (this (x, arg), annot)
1550
1551
case Thicket (ts) =>
0 commit comments