Skip to content

Add type refinement for abstract type bindings #4688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,9 @@ class Definitions {
lazy val TastyTastyModule = ctx.requiredModule("scala.tasty.Tasty")
lazy val TastyTasty_macroContext = TastyTastyModule.requiredMethod("macroContext")

lazy val RefinedScrutineeType: TypeRef = ctx.requiredClassRef("scala.RefinedScrutinee")
def RefinedScrutineeClass(implicit ctx: Context) = RefinedScrutineeType.symbol.asClass

lazy val EqType = ctx.requiredClassRef("scala.Eq")
def EqClass(implicit ctx: Context) = EqType.symbol.asClass
def EqModule(implicit ctx: Context) = EqClass.companionModule
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ object StdNames {
final val Type : N = "Type"
final val TypeTree: N = "TypeTree"

final val RefinedScrutinee: N = "RefinedScrutinee"

// Annotation simple names, used in Namer
final val BeanPropertyAnnot: N = "BeanProperty"
final val BooleanBeanPropertyAnnot: N = "BooleanBeanProperty"
Expand Down Expand Up @@ -490,6 +492,7 @@ object StdNames {
val raw_ : N = "raw"
val readResolve: N = "readResolve"
val reflect : N = "reflect"
val refinedScrutinee: N = "refinedScrutinee"
val reflectiveSelectable: N = "reflectiveSelectable"
val reify : N = "reify"
val rootMirror : N = "rootMirror"
Expand Down
33 changes: 33 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/FirstTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,39 @@ class FirstTransform extends MiniPhase with InfoTransformer { thisPhase =>
}

override def transformDefDef(ddef: DefDef)(implicit ctx: Context) = {
if (ddef.name == nme.unapply && !ddef.symbol.is(Synthetic)) {
ddef.tpe.widen match {
case mt: MethodType if !mt.resType.widen.isInstanceOf[MethodicType] =>
val resultType = mt.resType.substParam(mt.paramRefs.head, mt.paramRefs.head)
resultType match {
case resultType: AppliedType if resultType.derivesFrom(defn.RefinedScrutineeClass) =>
val refinedType :: resultType2 :: Nil = resultType.args
if (refinedType.exists && !(refinedType <:< mt.paramRefs.head)) {
val paramName = mt.paramNames.head
val paramTpe = mt.paramRefs.head
val paramInfo = mt.paramInfos.head
ctx.error(
i"""Extractor with ${tpnme.RefinedScrutinee} should refine the result type of that member.
|The scrutinee type of ${tpnme.RefinedScrutinee} should be a subtype of $paramTpe:
| def unapply($paramName: $paramInfo): ${tpnme.RefinedScrutinee}[$paramTpe & $refinedType, $resultType2]
""".stripMargin, ddef.tpt.pos)
}
case _ =>
val refinedType = resultType.select(nme.refinedScrutinee).widen.resultType
if (refinedType.exists && !(refinedType <:< mt.paramRefs.head)) {
val paramName = mt.paramNames.head
val paramTpe = mt.paramRefs.head
val paramInfo = mt.paramInfos.head
ctx.error(
i"""Extractor with ${nme.refinedScrutinee} should refine the result type of that member.
|The result type of ${nme.refinedScrutinee} should be a subtype of $paramTpe:
| def unapply($paramName: $paramInfo): ${resultType.widenDealias.classSymbol.name} { def ${nme.refinedScrutinee}: $refinedType & $paramTpe }
""".stripMargin, ddef.tpt.pos)
}
}
case _ =>
}
}
val meth = ddef.symbol.asTerm
if (meth.hasAnnotation(defn.NativeAnnot)) {
meth.resetFlag(Deferred)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
* whereas overloaded variants need to have a conforming variant.
*/
def trySelectUnapply(qual: untpd.Tree)(fallBack: Tree => Tree): Tree = {
// try first for non-overloaded, then for overloaded ocurrences
// try first for non-overloaded, then for overloaded occurrences
def tryWithName(name: TermName)(fallBack: Tree => Tree)(implicit ctx: Context): Tree = {
def tryWithProto(pt: Type)(implicit ctx: Context) = {
val result = typedExpr(untpd.Select(qual, name), new UnapplyFunProto(pt, this))
Expand Down
23 changes: 21 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1371,10 +1371,29 @@ class Typer extends Namer
else {
// for a singleton pattern like `x @ Nil`, `x` should get the type from the scrutinee
// see tests/neg/i3200b.scala and SI-1503
val symTp =
val symTp0 =
if (body1.tpe.isInstanceOf[TermRef]) pt1
else body1.tpe.underlyingIfRepeated(isJava = false)
val sym = ctx.newPatternBoundSymbol(tree.name, symTp, tree.pos)

// If it is name based pattern matching, the type of the argument of the unapply is abstract and
// the return type has a type member `Refined`, then refine the type of the binding with the type of `Refined`.
val symTp1 = body1 match {
case Trees.UnApply(fun, _, _) if symTp0.typeSymbol.is(Deferred) =>
// TODO check that it is name based pattern matching
fun.tpe.widen match {
case mt: MethodType if !mt.resType.isInstanceOf[MethodType] =>
val resultType = mt.resType.substParam(mt.paramRefs.head, symTp0)
val refinedType = resultType.select(nme.refinedScrutinee).widen.resultType
if (refinedType.exists) refinedType
else symTp0
case _ =>
symTp0
}

case _ => symTp0
}

val sym = ctx.newPatternBoundSymbol(tree.name, symTp1, tree.pos)
if (ctx.mode.is(Mode.InPatternAlternative))
ctx.error(i"Illegal variable ${sym.name} in pattern alternative", tree.pos)
assignType(cpy.Bind(tree)(tree.name, body1), sym)
Expand Down
25 changes: 25 additions & 0 deletions library/src/scala/RefinedScrutinee.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package scala

final class RefinedScrutinee[+Scrutinee /*<: Singleton*/, +Result] private (val result: Any) extends AnyVal {
// Scrutinee in not a singleton to provide a better error message

/** There is no result */
def isEmpty: Boolean = result == RefinedScrutinee.NoResult

/** When non-empty, get the result */
def get: Result = result.asInstanceOf[Result]

/** Scrutinee type on a successful match */
/*erased*/ def refinedScrutinee: Scrutinee = null.asInstanceOf[Scrutinee] // evaluated in RefinedScrutinee.matchOf

}

object RefinedScrutinee {

private[RefinedScrutinee] object NoResult

def matchOf[Scrutinee <: Singleton, Result](scrutinee: Scrutinee)(result: Result): RefinedScrutinee[Scrutinee, Result] = new RefinedScrutinee(result)

def noMatch[Scrutinee <: Singleton, Result]: RefinedScrutinee[Scrutinee, Result] = new RefinedScrutinee(NoResult)

}
46 changes: 46 additions & 0 deletions tests/neg/refined-binding-nat-2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

trait Peano {
type Nat
type Zero <: Nat
type Succ <: Nat

val Zero: Zero

val Succ: SuccExtractor
trait SuccExtractor {
def apply(nat: Nat): Succ
def unapply(nat: Nat): RefinedScrutinee[Succ, Nat] // error: Extractor with RefinedScrutinee should refine the result type of that member ...
}
}

object IntNums extends Peano {
type Nat = Int
type Zero = Int
type Succ = Int

val Zero: Zero = 0

object Succ extends SuccExtractor {
def apply(nat: Nat): Succ = nat + 1
def unapply(nat: Nat) = // error: Extractor with RefinedScrutinee should refine the result type of that member ...
if (nat == 0) RefinedScrutinee.noMatch
else RefinedScrutinee.matchOf(nat)(nat - 1)
}

}

object IntNums2 extends Peano {
type Nat = Int
type Zero = Int
type Succ = Int

val Zero: Zero = 0

object Succ extends SuccExtractor {
def apply(nat: Nat): Succ = nat + 1
def unapply(nat: Nat): RefinedScrutinee[nat.type & Succ, Nat] =
if (nat == 0) RefinedScrutinee.noMatch
else RefinedScrutinee.matchOf(nat)(nat - 1)
}

}
55 changes: 55 additions & 0 deletions tests/neg/refined-binding-nat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@

trait Peano {
type Nat
type Zero <: Nat
type Succ <: Nat

val Zero: Zero

val Succ: SuccExtractor
trait SuccExtractor {
def apply(nat: Nat): Succ
def unapply(nat: Nat): SuccOpt // error: Extractor with refinedScrutinee should refine the result type of that member ...
}
trait SuccOpt {
def isEmpty: Boolean
def refinedScrutinee: Succ
def get: Nat
}
}

object IntNums extends Peano {
type Nat = Int
type Zero = Int
type Succ = Int

val Zero: Zero = 0

object Succ extends SuccExtractor {
def apply(nat: Nat): Succ = nat + 1
def unapply(nat: Nat) = new SuccOpt { // error: Extractor with refinedScrutinee should refine the result type of that member ...
def isEmpty: Boolean = nat == 0
def refinedScrutinee: Succ & nat.type = nat
def get: Int = nat - 1
}
}

}

object IntNums2 extends Peano {
type Nat = Int
type Zero = Int
type Succ = Int

val Zero: Zero = 0

object Succ extends SuccExtractor {
def apply(nat: Nat): Succ = nat + 1
def unapply(nat: Nat): SuccOpt { def refinedScrutinee: Succ & nat.type } = new SuccOpt {
def isEmpty: Boolean = nat == 0
def refinedScrutinee: Succ & nat.type = nat
def get: Int = nat - 1
}
}

}
2 changes: 2 additions & 0 deletions tests/run/refined-binding-2.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ok
9
29 changes: 29 additions & 0 deletions tests/run/refined-binding-2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

sealed trait Foo {

type X
type Y <: X

def x: X

def f(y: Y) = println("ok")

object Z {
def unapply(arg: X): RefinedScrutinee[arg.type & Y, Int] =
RefinedScrutinee.matchOf(arg.asInstanceOf[arg.type & Y])(9)
}
}

object Test {
def main(args: Array[String]): Unit = {
test(new Foo { type X = Int; type Y = Int; def x: X = 1 })
}

def test(foo: Foo): Unit = {
foo.x match {
case x @ foo.Z(i) => // `x` is refined to type `Y`
foo.f(x)
println(i)
}
}
}
3 changes: 3 additions & 0 deletions tests/run/refined-binding-3.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
refinedScrutinee
ok
9
41 changes: 41 additions & 0 deletions tests/run/refined-binding-3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

sealed trait Foo {

type X
type Y <: X

def x: X

def f(y: Y) = println("ok")

object Z {
def unapply(arg: X) = new Opt {
type Scrutinee = arg.type
def refinedScrutinee: Y & Scrutinee = {
println("refinedScrutinee")
arg.asInstanceOf[Y & Scrutinee]
}
}
}

abstract class Opt {
type Scrutinee <: Singleton
def refinedScrutinee: Y & Scrutinee
def get: Int = 9
def isEmpty: Boolean = false
}
}

object Test {
def main(args: Array[String]): Unit = {
test(new Foo { type X = Int; type Y = Int; def x: X = 1 })
}

def test(foo: Foo): Unit = {
foo.x match {
case x @ foo.Z(i) => // `x` is refined to type `Y`
foo.f(x)
println(i)
}
}
}
6 changes: 6 additions & 0 deletions tests/run/refined-binding-nat-2.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Some((SuccClass(SuccClass(ZeroObject)),SuccClass(ZeroObject)))
Some((ZeroObject,SuccClass(SuccClass(ZeroObject))))
None
Some((2,1))
Some((0,2))
None
Loading