Skip to content

Commit 049308f

Browse files
committed
Fix and extend refining annotation comparison
1 parent 10d23a6 commit 049308f

File tree

5 files changed

+133
-61
lines changed

5 files changed

+133
-61
lines changed

compiler/src/dotty/tools/dotc/ast/Trees.scala

-25
Original file line numberDiff line numberDiff line change
@@ -189,31 +189,6 @@ object Trees {
189189

190190
override def toText(printer: Printer): Text = printer.toText(this)
191191

192-
def sameTree(that: Tree[?]): Boolean = {
193-
def isSame(x: Any, y: Any): Boolean =
194-
x.asInstanceOf[AnyRef].eq(y.asInstanceOf[AnyRef]) || {
195-
x match {
196-
case x: Tree[?] =>
197-
y match {
198-
case y: Tree[?] => x.sameTree(y)
199-
case _ => false
200-
}
201-
case x: List[?] =>
202-
y match {
203-
case y: List[?] => x.corresponds(y)(isSame)
204-
case _ => false
205-
}
206-
case _ =>
207-
false
208-
}
209-
}
210-
this.getClass == that.getClass && {
211-
val it1 = this.productIterator
212-
val it2 = that.productIterator
213-
it1.corresponds(it2)(isSame)
214-
}
215-
}
216-
217192
override def hashCode(): Int = System.identityHashCode(this)
218193
override def equals(that: Any): Boolean = this eq that.asInstanceOf[AnyRef]
219194
}

compiler/src/dotty/tools/dotc/ast/tpd.scala

+54
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,60 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
10201020
else
10211021
applyOverloaded(tree, nme.EQ, that :: Nil, Nil, defn.BooleanType)
10221022

1023+
def sameTree(that: Tree, thisParamSyms: List[Symbol] = Nil, thatParamRefs: List[TermRef] = Nil)(using Context): Boolean =
1024+
def recur(tree1: Tree, tree2: Tree) =
1025+
tree1.sameTree(tree2, thisParamSyms, thatParamRefs)
1026+
1027+
def sameTrees(trees1: List[Tree], trees2: List[Tree]) =
1028+
trees1.corresponds(trees2)(recur)
1029+
1030+
def sameType(tp1: Type, tp2: Type) =
1031+
(tp1 frozen_=:= tp2) || (tp1.subst(thisParamSyms, thatParamRefs) frozen_=:= tp2)
1032+
1033+
val res = tree match
1034+
case Literal(_) | Ident(_) =>
1035+
sameType(tree.tpe, that.tpe)
1036+
case Select(qual1, name1) =>
1037+
that match
1038+
case Select(qual2, name2) => name1 == name2 && recur(qual1, qual2)
1039+
case _ => false
1040+
case Apply(fun1, args1) =>
1041+
that match
1042+
case Apply(fun2, args2) => recur(fun1, fun2) && sameTrees(args1, args2)
1043+
case _ => false
1044+
case TypeApply(fun1, args1) =>
1045+
that match
1046+
case TypeApply(fun2, args2) =>
1047+
recur(fun1, fun2) && args1.corresponds(args2)((arg1, arg2) => sameType(arg1.tpe, arg2.tpe))
1048+
case _ => false
1049+
case tpt1: TypeTree =>
1050+
that match
1051+
case tpt2: TypeTree => sameType(tpt1.tpe, tpt2.tpe)
1052+
case _ => false
1053+
case Typed(expr1, tpt1) =>
1054+
that match
1055+
case Typed(expr2, tpt2) => recur(expr1, expr2) && sameType(tpt1.tpe, tpt2.tpe)
1056+
case _ => false
1057+
case New(tpt1) =>
1058+
that match
1059+
case New(tpt2) => sameType(tpt1.tpe, tpt2.tpe)
1060+
case _ => false
1061+
case closureDef(def1) =>
1062+
that match
1063+
case closureDef(def2) =>
1064+
val newThisParamSyms = def1.symbol.paramSymss.flatten ++ thisParamSyms
1065+
val newThatParamRefs = def2.symbol.paramSymss.flatten.map(_.termRef) ++ thatParamRefs
1066+
def1.rhs.sameTree(def2.rhs, newThisParamSyms, newThatParamRefs)
1067+
case _ => false
1068+
case Block(stats1, expr1) =>
1069+
that match
1070+
case Block(stats2, expr2) => sameTrees(stats1, stats2) && recur(expr1, expr2)
1071+
case _ => false
1072+
case _ => false
1073+
1074+
res
1075+
1076+
10231077
/** `tree.isInstanceOf[tp]`, with special treatment of singleton types */
10241078
def isInstance(tp: Type)(using Context): Tree = tp.dealias match {
10251079
case ConstantType(c) if c.tag == StringTag =>

compiler/src/dotty/tools/dotc/core/Annotations.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ object Annotations {
4343
def argumentConstantString(i: Int)(using Context): Option[String] =
4444
for (case Constant(s: String) <- argumentConstant(i)) yield s
4545

46+
/** All type and term argument trees of this annotation in a single flat list */
47+
private def allArguments(using Context): List[Tree] = tpd.allArguments(tree)
48+
4649
/** The tree evaluation is in progress. */
4750
def isEvaluating: Boolean = false
4851

@@ -88,7 +91,8 @@ object Annotations {
8891
def ensureCompleted(using Context): Unit = tree
8992

9093
def sameAnnotation(that: Annotation)(using Context): Boolean =
91-
symbol == that.symbol && tree.sameTree(that.tree)
94+
def sameArg(arg1: Tree, arg2: Tree): Boolean = tpd.stripNamedArg(arg1).sameTree(tpd.stripNamedArg(arg2))
95+
symbol == that.symbol && allArguments.corresponds(that.allArguments)(sameArg)
9296

9397
def hasOneOfMetaAnnotation(metaSyms: Set[Symbol], orNoneOf: Set[Symbol] = Set.empty)(using Context): Boolean = atPhaseNoLater(erasurePhase) {
9498
def go(metaSyms: Set[Symbol]) =

tests/neg/annot-refining-infer.scala

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
class MyAnnotation(x: Any) extends scala.annotation.RefiningAnnotation
2+
3+
def id[T](x: T): T = x
4+
def id2[T](x: T, y: T): T = x
5+
6+
def foo1[T](x: T, g: T => Unit): T = x
7+
def foo2[T](x: T, y: T, g: T => Unit): T = x
8+
def foo3[T](g: T => Unit, x: T, y: T): T = x
9+
def foo4[T](x: T, g: T => Unit, h: T => Unit): T = x
10+
11+
def take42[T](x: T @MyAnnotation(42)): Unit = ()
12+
def take43[T](x: T @MyAnnotation(43)): Unit = ()
13+
def take42or43[S](x: S @MyAnnotation(42) | S @MyAnnotation(43)): Unit = ()
14+
def take42or43Int(x: Int @MyAnnotation(42) | Int @MyAnnotation(43)): Unit = ()
15+
16+
def main =
17+
val c42: Int @MyAnnotation(42) = ???
18+
val c43: Int @MyAnnotation(43) = ???
19+
20+
val v01 = id2[Int @MyAnnotation(42) | Int @MyAnnotation(43)](c42, c43)
21+
val v02: Int @MyAnnotation(42) | Int @MyAnnotation(43) = c42
22+
val v03: Int @MyAnnotation(42) | Int @MyAnnotation(43) = id2(c42, c43)
23+
24+
val v04 = foo1(c42, take42)
25+
val v05: Int @MyAnnotation(42) = v13
26+
val v06 = foo1(c42, take43) // error
27+
val v07 = foo1(c42, take42or43)
28+
29+
val v08 = foo2(c42, c42, take42)
30+
val v09: Int @MyAnnotation(42) = v15
31+
val v10 = foo2(c42, c43, take42) // error
32+
val v11 = foo2(c42, c43, take42or43) // error
33+
val v12 = foo2(c42, c43, take42or43Int)
34+
35+
val v13 = foo3(take42or43, c42, c43) // error
36+
val v14 = foo3(take42or43Int, c42, c43)
37+
38+
val v15 = foo4(c42, take42, take42)
39+
val v16: Int @MyAnnotation(42) = v15
40+
val v17 = foo4(c42, take42, take43) // error

tests/neg/annot-refining-sub.scala

+34-35
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@ def main =
2222
val c: Int = 42
2323
val o: O.type = O
2424

25-
val v1: Int @annot1(1) = ??? : Int @annot1(1) // error: fixme (constants are equal)
25+
val v1: Int @annot1(1) = ??? : Int @annot1(1)
2626
val v2: Int @annot1(c) = ??? : Int @annot1(c)
2727
val v3: Int @annot1(O.d) = ??? : Int @annot1(O.d)
28-
val v4: Int @annot1(O.d) = ??? : Int @annot1(o.d) // error: fixme?
29-
val v5: Int @annot1((1, 2)) = ??? : Int @annot1((1, 2)) // error: fixme
30-
val v6: Int @annot1(1 + 2) = ??? : Int @annot1(1 + 2) // error: fixme
31-
val v7: Int @annot1(1 + 2) = ??? : Int @annot1(2 + 1) // error: fixme? should constant fold?
32-
val v8: Int @annot1(1 + c) = ??? : Int @annot1(1 + c) // error: fixme
33-
val v9: Int @annot1(1 + c) = ??? : Int @annot1(c + 1) // error (no algebraic normalization)
34-
val v10: Int @annot1(Box(1)) = ??? : Int @annot1(Box(1)) // error: fixme
28+
val v4: Int @annot1(O.d) = ??? : Int @annot1(o.d)
29+
val v5: Int @annot1((1, 2)) = ??? : Int @annot1((1, 2))
30+
val v6: Int @annot1(1 + 2) = ??? : Int @annot1(1 + 2)
31+
val v7: Int @annot1(1 + 2) = ??? : Int @annot1(2 + 1) // error: no constant folding
32+
val v8: Int @annot1(1 + c) = ??? : Int @annot1(1 + c)
33+
val v9: Int @annot1(1 + c) = ??? : Int @annot1(c + 1) // error: no algebraic simplification
34+
val v10: Int @annot1(Box(1)) = ??? : Int @annot1(Box(1))
3535
val v11: Int @annot1(Box(c)) = ??? : Int @annot1(Box(c))
36-
val v12: Int @annot1(Box2(1)) = ??? : Int @annot1(Box2(1)) // error: fixme
36+
val v12: Int @annot1(Box2(1)) = ??? : Int @annot1(Box2(1))
3737
val v13: Int @annot1(Box2(c)) = ??? : Int @annot1(Box2(c))
3838
val v14: Int @annot1(c: Int) = ??? : Int @annot1(c: Int)
3939
val v15: Int @annot1(c) = ??? : Int @annot1(c: Int) // error
@@ -50,41 +50,40 @@ def main =
5050
val v26: Int @annot1(??? : Box3 {type T = Int}) = ??? : Int @annot1(??? : Box3 {type T = String}) // error
5151
val v27: Int @annot1(??? : Box3 {type T = Int}) = ??? : Int @annot1(??? : Box3) // error
5252
val v28: Int @annot1(a=c) = ??? : Int @annot1(a=c)
53-
val v29: Int @annot1(a=c) = ??? : Int @annot1(c) // error: fixme (same arguments, named vs positional)
54-
val v30: Int @annot1(c) = ??? : Int @annot1(a=c) // error: fixme
53+
val v29: Int @annot1(a=c) = ??? : Int @annot1(c)
54+
val v30: Int @annot1(c) = ??? : Int @annot1(a=c)
5555
val v31: Int @annot1((d: Int) => d) = ??? : Int @annot1((d: Int) => d)
56-
val v32: Int @annot1((d: Int) => d) = ??? : Int @annot1((e: Int) => e) // error: fixme (alpha equivalence)
57-
val v33: Int @annot1((e: Int) => e) = ??? : Int @annot1((d: Int) => d) // error: fixme
58-
val v34: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1) // error: fixme
59-
val v35: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1) // error: fixme
60-
val v36: Int @annot1((d: Int) => id[d.type]) = ??? : Int @annot1((e: Int) => id[e.type]) // error: fixme
61-
val v37: Int @annot1((d: Box3) => id[d.T]) = ??? : Int @annot1((e: Box3) => id[e.T]) // error: fixme
62-
val v38: Int @annot1((d: Int) => (d: Int) => d) = ??? : Int @annot1((e: Int) => (e: Int) => e) // error: fixme
63-
val v39: Int @annot1((d: Int) => ((e: Int) => d)(2)) = ??? : Int @annot1((e: Int) => ((e: Int) => e)(2)) // error: fixme
64-
65-
val v40: Int @annot2(1, 2) = ??? : Int @annot2(1, 2) // error: fixme
56+
val v32: Int @annot1((d: Int) => d) = ??? : Int @annot1((e: Int) => e)
57+
val v33: Int @annot1((e: Int) => e) = ??? : Int @annot1((d: Int) => d)
58+
val v34: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1)
59+
val v35: Int @annot1((d: Int) => id(d)) = ??? : Int @annot1((e: Int) => id(e))
60+
val v36: Int @annot1((d: Int) => id[d.type]) = ??? : Int @annot1((e: Int) => id[e.type])
61+
val v37: Int @annot1((d: Box3) => id[d.T]) = ??? : Int @annot1((e: Box3) => id[e.T])
62+
val v38: Int @annot1((d: Int) => (d: Int) => d) = ??? : Int @annot1((e: Int) => (e: Int) => e)
63+
val v39: Int @annot1((d: Int) => ((e: Int) => e)(2)) = ??? : Int @annot1((e: Int) => ((e: Int) => e)(2))
64+
val v40: Int @annot2(1, 2) = ??? : Int @annot2(1, 2)
6665
val v41: Int @annot2(c, c) = ??? : Int @annot2(c, c)
67-
val v42: Int @annot2(c, c) = ??? : Int @annot2(a=c, b=c) // error: fixme
68-
val v43: Int @annot2(a=c, c) = ??? : Int @annot2(c, b=c) // error: fixme
69-
val v44: Int @annot2(a=c, b=c) = ??? : Int @annot2(c, c) // error: fixme
66+
val v42: Int @annot2(c, c) = ??? : Int @annot2(a=c, b=c)
67+
val v43: Int @annot2(a=c, c) = ??? : Int @annot2(c, b=c)
68+
val v44: Int @annot2(a=c, b=c) = ??? : Int @annot2(c, c)
7069

71-
val v45: Int @annot3(1) = ??? : Int @annot3(1) // error: fixme
70+
val v45: Int @annot3(1) = ??? : Int @annot3(1)
7271
val v46: Int @annot3(c) = ??? : Int @annot3(c)
73-
val v47: Int @annot3(1) = ??? : Int @annot3(1, 3) // error: fixme
74-
val v48: Int @annot3(1, 3) = ??? : Int @annot3(1) // error: fixme
75-
val v49: Int @annot3(c) = ??? : Int @annot3(c, 3) // error: fixme
76-
val v50: Int @annot3(c, 3) = ??? : Int @annot3(c) // error: fixme
72+
val v47: Int @annot3(1) = ??? : Int @annot3(1, 3) // error: default arg tree is different, fix in the future?
73+
val v48: Int @annot3(1, 3) = ??? : Int @annot3(1) // error: same as above
74+
val v49: Int @annot3(c) = ??? : Int @annot3(c, 3) // error: same as above
75+
val v50: Int @annot3(c, 3) = ??? : Int @annot3(c) // error: same as above
7776

78-
val v51: Int @annot4[1] = ??? : Int @annot4[1] // error: fixme
77+
val v51: Int @annot4[1] = ??? : Int @annot4[1]
7978
val v52: Int @annot4[c.type] = ??? : Int @annot4[c.type]
8079
val v53: Int @annot4[O.d.type] = ??? : Int @annot4[O.d.type]
81-
val v54: Int @annot4[O.d.type] = ??? : Int @annot4[o.d.type]// error: fixme?
80+
val v54: Int @annot4[O.d.type] = ??? : Int @annot4[o.d.type]
8281
val v55: Int @annot4[Int] = ??? : Int @annot4[Int]
8382
val v56: Int @annot4[Int] = ??? : Int @annot4[1] // error
84-
val v57: Int @annot4[(1, 2)] = ??? : Int @annot4[(1, 2)] // error: fixme
85-
val v58: Int @annot4[1 + 2] = ??? : Int @annot4[1 + 2] // error: fixme
86-
val v59: Int @annot4[1 + 2] = ??? : Int @annot4[2 + 1] // error: fixme
87-
val v60: Int @annot4[1 + c.type] = ??? : Int @annot4[1 + c.type] // error: fixme
83+
val v57: Int @annot4[(1, 2)] = ??? : Int @annot4[(1, 2)]
84+
val v58: Int @annot4[1 + 2] = ??? : Int @annot4[1 + 2]
85+
val v59: Int @annot4[1 + 2] = ??? : Int @annot4[2 + 1]
86+
val v60: Int @annot4[1 + c.type] = ??? : Int @annot4[1 + c.type]
8887
val v61: Int @annot4[1 + c.type] = ??? : Int @annot4[c.type + 1] // error
8988
val v62: Int @annot4[Box[Int]] = ??? : Int @annot4[Box[Int]]
9089
val v63: Int @annot4[Box[String]] = ??? : Int @annot4[Box[Int]] // error

0 commit comments

Comments
 (0)