diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index b8943dcfeae6..62fd9cc45751 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -282,12 +282,16 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { val parents1 = if (parents.head.classSymbol.is(Trait)) parents.head.parents.head :: parents else parents - val cls = ctx.newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic, parents1, + val cls = ctx.newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic | Final, parents1, coord = fns.map(_.pos).reduceLeft(_ union _)) val constr = ctx.newConstructor(cls, Synthetic, Nil, Nil).entered def forwarder(fn: TermSymbol, name: TermName) = { - val fwdMeth = fn.copy(cls, name, Synthetic | Method).entered.asTerm - DefDef(fwdMeth, prefss => ref(fn).appliedToArgss(prefss)) + var flags = Synthetic | Method | Final + def isOverriden(denot: SingleDenotation) = fn.info.overrides(denot.info, matchLoosely = true) + val isOverride = parents.exists(_.member(name).hasAltWith(isOverriden)) + if (isOverride) flags = flags | Override + val fwdMeth = fn.copy(cls, name, flags).entered.asTerm + polyDefDef(fwdMeth, tprefs => prefss => ref(fn).appliedToTypes(tprefs).appliedToArgss(prefss)) } val forwarders = (fns, methNames).zipped.map(forwarder) val cdef = ClassDef(cls, DefDef(constr), forwarders) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 95c5aa94dd37..a695499e02e7 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -585,8 +585,14 @@ class Definitions { lazy val PartialFunctionType: TypeRef = ctx.requiredClassRef("scala.PartialFunction") def PartialFunctionClass(implicit ctx: Context) = PartialFunctionType.symbol.asClass + lazy val PartialFunction_isDefinedAtR = PartialFunctionClass.requiredMethodRef(nme.isDefinedAt) + def PartialFunction_isDefinedAt(implicit ctx: Context) = PartialFunction_isDefinedAtR.symbol + lazy val PartialFunction_applyOrElseR = PartialFunctionClass.requiredMethodRef(nme.applyOrElse) + def PartialFunction_applyOrElse(implicit ctx: Context) = PartialFunction_applyOrElseR.symbol + lazy val AbstractPartialFunctionType: TypeRef = ctx.requiredClassRef("scala.runtime.AbstractPartialFunction") def AbstractPartialFunctionClass(implicit ctx: Context) = AbstractPartialFunctionType.symbol.asClass + lazy val FunctionXXLType: TypeRef = ctx.requiredClassRef("scala.FunctionXXL") def FunctionXXLClass(implicit ctx: Context) = FunctionXXLType.symbol.asClass diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index aa269beea170..e2c04b920250 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -8,12 +8,13 @@ import MegaPhase._ import SymUtils._ import ast.untpd import ast.Trees._ +import dotty.tools.dotc.reporting.diagnostic.messages.TypeMismatch import dotty.tools.dotc.util.Positions.Position /** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes. * These fall into five categories * - * 1. Partial function closures, we need to generate a isDefinedAt method for these. + * 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these. * 2. Closures implementing non-trait classes. * 3. Closures implementing classes that inherit from a class other than Object * (a lambda cannot not be a run-time subtype of such a class) @@ -35,8 +36,8 @@ class ExpandSAMs extends MiniPhase { tpt.tpe match { case NoType => tree // it's a plain function case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) => - checkRefinements(tpe, fn.pos) - toPartialFunction(tree) + val tpe1 = checkRefinements(tpe, fn.pos) + toPartialFunction(tree, tpe1) case tpe @ SAMType(_) if isPlatformSam(tpe.classSymbol.asClass) => checkRefinements(tpe, fn.pos) tree @@ -50,42 +51,75 @@ class ExpandSAMs extends MiniPhase { tree } - private def toPartialFunction(tree: Block)(implicit ctx: Context): Tree = { - val Block( - (applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil, - Closure(_, _, tpt)) = tree - val applyRhs: Tree = applyDef.rhs - val applyFn = applyDef.symbol.asTerm - - val MethodTpe(paramNames, paramTypes, _) = applyFn.info - val isDefinedAtFn = applyFn.copy( - name = nme.isDefinedAt, - flags = Synthetic | Method, - info = MethodType(paramNames, paramTypes, defn.BooleanType)).asTerm - val tru = Literal(Constant(true)) - def isDefinedAtRhs(paramRefss: List[List[Tree]]) = applyRhs match { - case Match(selector, cases) => - assert(selector.symbol == param.symbol) - val paramRef = paramRefss.head.head - // Again, the alternative - // val List(List(paramRef)) = paramRefs - // fails with a similar self instantiation error - def translateCase(cdef: CaseDef): CaseDef = - cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn) - val defaultSym = ctx.newSymbol(isDefinedAtFn, nme.WILDCARD, Synthetic, selector.tpe.widen) - val defaultCase = - CaseDef( - Bind(defaultSym, Underscore(selector.tpe.widen)), - EmptyTree, - Literal(Constant(false))) - val annotated = Annotated(paramRef, New(ref(defn.UncheckedAnnotType))) - cpy.Match(applyRhs)(annotated, cases.map(translateCase) :+ defaultCase) + private def toPartialFunction(tree: Block, tpe: Type)(implicit ctx: Context): Tree = { + // /** An extractor for match, either contained in a block or standalone. */ + object PartialFunctionRHS { + def unapply(tree: Tree): Option[Match] = tree match { + case Block(Nil, expr) => unapply(expr) + case m: Match => Some(m) + case _ => None + } + } + + val closureDef(anon @ DefDef(_, _, List(List(param)), _, _)) = tree + anon.rhs match { + case PartialFunctionRHS(pf) => + val anonSym = anon.symbol + + def overrideSym(sym: Symbol) = sym.copy( + owner = anonSym.owner, + flags = Synthetic | Method | Final, + info = tpe.memberInfo(sym), + coord = tree.pos).asTerm + val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) + val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) + + def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = { + val selector = tree.selector + val selectorTpe = selector.tpe.widen + val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe) + val defaultCase = + CaseDef( + Bind(defaultSym, Underscore(selectorTpe)), + EmptyTree, + defaultValue) + val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType))) + cpy.Match(tree)(unchecked, cases :+ defaultCase) + .subst(param.symbol :: Nil, pfParam :: Nil) + // Needed because a partial function can be written as: + // param => param match { case "foo" if foo(param) => param } + // And we need to update all references to 'param' + } + + def isDefinedAtRhs(paramRefss: List[List[Tree]]) = { + val tru = Literal(Constant(true)) + def translateCase(cdef: CaseDef) = + cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn) + val paramRef = paramRefss.head.head + val defaultValue = Literal(Constant(false)) + translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue) + } + + def applyOrElseRhs(paramRefss: List[List[Tree]]) = { + val List(paramRef, defaultRef) = paramRefss.head + def translateCase(cdef: CaseDef) = + cdef.changeOwner(anonSym, applyOrElseFn) + val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef) + translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue) + } + + val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_))) + val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_))) + + val parent = defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos) + val anonCls = AnonClass(parent :: Nil, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse)) + cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls) + case _ => - tru + val found = tpe.baseType(defn.FunctionClass(1)) + ctx.error(TypeMismatch(found, tpe), tree.pos) + tree } - val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_))) - val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn), List(nme.apply, nme.isDefinedAt)) - cpy.Block(tree)(List(applyDef, isDefinedAtDef), anonCls) } private def checkRefinements(tpe: Type, pos: Position)(implicit ctx: Context): Type = tpe.dealias match { @@ -93,7 +127,7 @@ class ExpandSAMs extends MiniPhase { if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement ctx.error("Lambda does not define " + name, pos) checkRefinements(parent, pos) - case _ => + case tpe => tpe } diff --git a/tests/neg/i4241.scala b/tests/neg/i4241.scala new file mode 100644 index 000000000000..3d93a44a015a --- /dev/null +++ b/tests/neg/i4241.scala @@ -0,0 +1,12 @@ +class Test { + def test: Unit = { + val a: PartialFunction[Int, Int] = { case x => x } + val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case _ => 2 } + val c: PartialFunction[Int, Int] = x => { x match { case y => y } } + val d: PartialFunction[Int, Int] = x => { { x match { case y => y } } } + + val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case y => y } } // error + val f: PartialFunction[Int, Int] = x => x // error + val g: PartialFunction[Int, String] = { x => x.toString } // error + } +} diff --git a/tests/pos/i4177.scala b/tests/pos/i4177.scala new file mode 100644 index 000000000000..dfcedf92d424 --- /dev/null +++ b/tests/pos/i4177.scala @@ -0,0 +1,18 @@ +class Test { + + object Foo { def unapply(x: Int) = if (x == 2) Some(x.toString) else None } + + def test: Unit = { + val a: PartialFunction[Int, String] = { case Foo(x) => x } + val b: PartialFunction[Int, String] = { case x => x.toString } + + val e: PartialFunction[String, String] = { case x @ "abc" => x } + val f: PartialFunction[String, String] = x => x match { case "abc" => x } + val g: PartialFunction[String, String] = x => x match { case "abc" if x.isEmpty => x } + + type P = PartialFunction[String,String] + val h: P = { case x => x.toString } + + val i: PartialFunction[Int, Int] = { x => x match { case x => x } } + } +} diff --git a/tests/run/i4177.scala b/tests/run/i4177.scala new file mode 100644 index 000000000000..fa0782b58b78 --- /dev/null +++ b/tests/run/i4177.scala @@ -0,0 +1,18 @@ +object Test { + private[this] var count = 0 + + def test(x: Int) = { count += 1; true } + + object Foo { + def unapply(x: Int): Option[Int] = { count += 1; Some(x) } + } + + def main(args: Array[String]): Unit = { + val res = List(1, 2).collect { case x if test(x) => x } + assert(count == 2) + + count = 0 + val res2 = List(1, 2).collect { case Foo(x) => x } + assert(count == 2) + } +} diff --git a/tests/run/partialFunctions.scala b/tests/run/partialFunctions.scala index ca82431ca8e8..8120b3fa886b 100644 --- a/tests/run/partialFunctions.scala +++ b/tests/run/partialFunctions.scala @@ -1,10 +1,15 @@ object Test { - def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1) + def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1) + class Foo(val field: Option[Int]) def main(args: Array[String]): Unit = { - val partialFunction: PartialFunction[Int, Int] = {case a: Int => a} + val p1: PartialFunction[Int, Int] = { case a: Int => a } + assert(takesPartialFunction(p1) == 1) - assert(takesPartialFunction(partialFunction) == 1) + val p2: PartialFunction[Foo, Int] = + foo => foo.field match { case Some(x) => x } + assert(p2.isDefinedAt(new Foo(Some(1)))) + assert(!p2.isDefinedAt(new Foo(None))) } }