From ac489dbb6bd9a18044ce24f247970de8637e3617 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 4 Apr 2018 18:31:13 +0200 Subject: [PATCH 1/5] Fix #4177: Generate optimised applyOrElse implementation for partial function literals --- compiler/src/dotty/tools/dotc/ast/tpd.scala | 8 +- .../dotty/tools/dotc/core/Definitions.scala | 4 + .../tools/dotc/transform/ExpandSAMs.scala | 80 +++++++++++++------ tests/pos/i4177.scala | 15 ++++ tests/run/i4177.scala | 18 +++++ 5 files changed, 99 insertions(+), 26 deletions(-) create mode 100644 tests/pos/i4177.scala create mode 100644 tests/run/i4177.scala diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index b8943dcfeae6..419555305b8e 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -286,8 +286,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { coord = fns.map(_.pos).reduceLeft(_ union _)) val constr = ctx.newConstructor(cls, Synthetic, Nil, Nil).entered def forwarder(fn: TermSymbol, name: TermName) = { - val fwdMeth = fn.copy(cls, name, Synthetic | Method).entered.asTerm - DefDef(fwdMeth, prefss => ref(fn).appliedToArgss(prefss)) + var flags = Synthetic | Method + def isOverriden(denot: SingleDenotation) = fn.info.overrides(denot.info, matchLoosely = true) + val isOverride = parents.exists(_.member(name).hasAltWith(isOverriden)) + if (isOverride) flags = flags | Override + val fwdMeth = fn.copy(cls, name, flags).entered.asTerm + polyDefDef(fwdMeth, tprefs => prefss => ref(fn).appliedToTypes(tprefs).appliedToArgss(prefss)) } val forwarders = (fns, methNames).zipped.map(forwarder) val cdef = ClassDef(cls, DefDef(constr), forwarders) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 95c5aa94dd37..aabd6aca524a 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -585,6 +585,10 @@ class Definitions { lazy val PartialFunctionType: TypeRef = ctx.requiredClassRef("scala.PartialFunction") def PartialFunctionClass(implicit ctx: Context) = PartialFunctionType.symbol.asClass + + lazy val PartialFunction_applyOrElseR = PartialFunctionClass.requiredMethodRef(nme.applyOrElse) + def PartialFunction_applyOrElse(implicit ctx: Context) = PartialFunction_applyOrElseR.symbol + lazy val AbstractPartialFunctionType: TypeRef = ctx.requiredClassRef("scala.runtime.AbstractPartialFunction") def AbstractPartialFunctionClass(implicit ctx: Context) = AbstractPartialFunctionType.symbol.asClass lazy val FunctionXXLType: TypeRef = ctx.requiredClassRef("scala.FunctionXXL") diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index aa269beea170..0d0bdc972976 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -13,7 +13,7 @@ import dotty.tools.dotc.util.Positions.Position /** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes. * These fall into five categories * - * 1. Partial function closures, we need to generate a isDefinedAt method for these. + * 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these. * 2. Closures implementing non-trait classes. * 3. Closures implementing classes that inherit from a class other than Object * (a lambda cannot not be a run-time subtype of such a class) @@ -54,7 +54,25 @@ class ExpandSAMs extends MiniPhase { val Block( (applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil, Closure(_, _, tpt)) = tree - val applyRhs: Tree = applyDef.rhs + + def translateMatch(tree: Match, selector: Tree, cases: List[CaseDef], defaultValue: Tree) = { + assert(tree.selector.symbol == param.symbol) + val selectorTpe = selector.tpe.widen + val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD, Synthetic, selectorTpe) + val defaultCase = + CaseDef( + Bind(defaultSym, Underscore(selectorTpe)), + EmptyTree, + defaultValue) + val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType))) + cpy.Match(tree)(unchecked, cases :+ defaultCase) + .subst(param.symbol :: Nil, selector.symbol :: Nil) + // Needed because a partial function can be written as: + // param => param match { case "foo" if foo(param) => param } + // And we need to update all references to 'param' + } + + val applyRhs = applyDef.rhs val applyFn = applyDef.symbol.asTerm val MethodTpe(paramNames, paramTypes, _) = applyFn.info @@ -62,30 +80,44 @@ class ExpandSAMs extends MiniPhase { name = nme.isDefinedAt, flags = Synthetic | Method, info = MethodType(paramNames, paramTypes, defn.BooleanType)).asTerm - val tru = Literal(Constant(true)) - def isDefinedAtRhs(paramRefss: List[List[Tree]]) = applyRhs match { - case Match(selector, cases) => - assert(selector.symbol == param.symbol) - val paramRef = paramRefss.head.head - // Again, the alternative - // val List(List(paramRef)) = paramRefs - // fails with a similar self instantiation error - def translateCase(cdef: CaseDef): CaseDef = - cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn) - val defaultSym = ctx.newSymbol(isDefinedAtFn, nme.WILDCARD, Synthetic, selector.tpe.widen) - val defaultCase = - CaseDef( - Bind(defaultSym, Underscore(selector.tpe.widen)), - EmptyTree, - Literal(Constant(false))) - val annotated = Annotated(paramRef, New(ref(defn.UncheckedAnnotType))) - cpy.Match(applyRhs)(annotated, cases.map(translateCase) :+ defaultCase) - case _ => - tru + + val applyOrElseFn = applyFn.copy( + name = nme.applyOrElse, + flags = Synthetic | Method, + info = tpt.tpe.memberInfo(defn.PartialFunction_applyOrElse)).asTerm + + def isDefinedAtRhs(paramRefss: List[List[Tree]]) = { + val tru = Literal(Constant(true)) + applyRhs match { + case tree @ Match(_, cases) => + def translateCase(cdef: CaseDef)= + cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn) + val paramRef = paramRefss.head.head + val defaultValue = Literal(Constant(false)) + translateMatch(tree, paramRef, cases.map(translateCase), defaultValue) + case _ => + tru + } + } + + def applyOrElseRhs(paramRefss: List[List[Tree]]) = { + val List(paramRef, defaultRef) = paramRefss.head + applyRhs match { + case tree @ Match(_, cases) => + def translateCase(cdef: CaseDef) = + cdef.changeOwner(applyFn, applyOrElseFn) + val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef) + translateMatch(tree, paramRef, cases.map(translateCase), defaultValue) + case _ => + ref(applyFn).appliedTo(paramRef) + } } + val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_))) - val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn), List(nme.apply, nme.isDefinedAt)) - cpy.Block(tree)(List(applyDef, isDefinedAtDef), anonCls) + val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_))) + + val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn, applyOrElseFn), List(nme.apply, nme.isDefinedAt, nme.applyOrElse)) + cpy.Block(tree)(List(applyDef, isDefinedAtDef, applyOrElseDef), anonCls) } private def checkRefinements(tpe: Type, pos: Position)(implicit ctx: Context): Type = tpe.dealias match { diff --git a/tests/pos/i4177.scala b/tests/pos/i4177.scala new file mode 100644 index 000000000000..961e1e39b289 --- /dev/null +++ b/tests/pos/i4177.scala @@ -0,0 +1,15 @@ +class Test { + + object Foo { def unapply(x: Int) = if (x == 2) Some(x.toString) else None } + + def test: Unit = { + val a: PartialFunction[Int, String] = { case Foo(x) => x } + val b: PartialFunction[Int, String] = { case x => x.toString } + val c: PartialFunction[Int, String] = { x => x.toString } + val d: PartialFunction[Int, String] = x => x.toString + + val e: PartialFunction[String, String] = { case x @ "abc" => x } + val f: PartialFunction[String, String] = x => x match { case "abc" => x } + val g: PartialFunction[String, String] = x => x match { case "abc" if x.isEmpty => x } + } +} diff --git a/tests/run/i4177.scala b/tests/run/i4177.scala new file mode 100644 index 000000000000..fa0782b58b78 --- /dev/null +++ b/tests/run/i4177.scala @@ -0,0 +1,18 @@ +object Test { + private[this] var count = 0 + + def test(x: Int) = { count += 1; true } + + object Foo { + def unapply(x: Int): Option[Int] = { count += 1; Some(x) } + } + + def main(args: Array[String]): Unit = { + val res = List(1, 2).collect { case x if test(x) => x } + assert(count == 2) + + count = 0 + val res2 = List(1, 2).collect { case Foo(x) => x } + assert(count == 2) + } +} From b9a7b2ff33c22167236532fc0cd1e1b84b075bc8 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 17 Apr 2018 21:18:03 +0200 Subject: [PATCH 2/5] Only generate `isDefined` and `applyOrElse` for partial function literals This is done by extending `scala.runtime.AbstractPartialFunction`. --- compiler/src/dotty/tools/dotc/ast/tpd.scala | 4 +-- .../dotty/tools/dotc/core/Definitions.scala | 4 ++- .../tools/dotc/transform/ExpandSAMs.scala | 31 ++++++++++--------- tests/pos/i4177.scala | 5 +++ 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 419555305b8e..62fd9cc45751 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -282,11 +282,11 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { val parents1 = if (parents.head.classSymbol.is(Trait)) parents.head.parents.head :: parents else parents - val cls = ctx.newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic, parents1, + val cls = ctx.newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic | Final, parents1, coord = fns.map(_.pos).reduceLeft(_ union _)) val constr = ctx.newConstructor(cls, Synthetic, Nil, Nil).entered def forwarder(fn: TermSymbol, name: TermName) = { - var flags = Synthetic | Method + var flags = Synthetic | Method | Final def isOverriden(denot: SingleDenotation) = fn.info.overrides(denot.info, matchLoosely = true) val isOverride = parents.exists(_.member(name).hasAltWith(isOverriden)) if (isOverride) flags = flags | Override diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index aabd6aca524a..a695499e02e7 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -585,12 +585,14 @@ class Definitions { lazy val PartialFunctionType: TypeRef = ctx.requiredClassRef("scala.PartialFunction") def PartialFunctionClass(implicit ctx: Context) = PartialFunctionType.symbol.asClass - + lazy val PartialFunction_isDefinedAtR = PartialFunctionClass.requiredMethodRef(nme.isDefinedAt) + def PartialFunction_isDefinedAt(implicit ctx: Context) = PartialFunction_isDefinedAtR.symbol lazy val PartialFunction_applyOrElseR = PartialFunctionClass.requiredMethodRef(nme.applyOrElse) def PartialFunction_applyOrElse(implicit ctx: Context) = PartialFunction_applyOrElseR.symbol lazy val AbstractPartialFunctionType: TypeRef = ctx.requiredClassRef("scala.runtime.AbstractPartialFunction") def AbstractPartialFunctionClass(implicit ctx: Context) = AbstractPartialFunctionType.symbol.asClass + lazy val FunctionXXLType: TypeRef = ctx.requiredClassRef("scala.FunctionXXL") def FunctionXXLClass(implicit ctx: Context) = FunctionXXLType.symbol.asClass diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index 0d0bdc972976..b6faab4a28c5 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -73,24 +73,21 @@ class ExpandSAMs extends MiniPhase { } val applyRhs = applyDef.rhs - val applyFn = applyDef.symbol.asTerm + val applyFn = applyDef.symbol - val MethodTpe(paramNames, paramTypes, _) = applyFn.info - val isDefinedAtFn = applyFn.copy( - name = nme.isDefinedAt, - flags = Synthetic | Method, - info = MethodType(paramNames, paramTypes, defn.BooleanType)).asTerm - - val applyOrElseFn = applyFn.copy( - name = nme.applyOrElse, - flags = Synthetic | Method, - info = tpt.tpe.memberInfo(defn.PartialFunction_applyOrElse)).asTerm + def overrideSym(sym: Symbol) = sym.copy( + owner = applyFn.owner, + flags = Synthetic | Method | Final, + info = tpt.tpe.memberInfo(sym), + coord = tree.pos).asTerm + val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) + val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) def isDefinedAtRhs(paramRefss: List[List[Tree]]) = { val tru = Literal(Constant(true)) applyRhs match { case tree @ Match(_, cases) => - def translateCase(cdef: CaseDef)= + def translateCase(cdef: CaseDef) = cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn) val paramRef = paramRefss.head.head val defaultValue = Literal(Constant(false)) @@ -109,15 +106,19 @@ class ExpandSAMs extends MiniPhase { val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef) translateMatch(tree, paramRef, cases.map(translateCase), defaultValue) case _ => - ref(applyFn).appliedTo(paramRef) + applyRhs + .changeOwner(applyFn, applyOrElseFn) + .subst(param.symbol :: Nil, paramRef.symbol :: Nil) } } val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_))) val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_))) - val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn, applyOrElseFn), List(nme.apply, nme.isDefinedAt, nme.applyOrElse)) - cpy.Block(tree)(List(applyDef, isDefinedAtDef, applyOrElseDef), anonCls) + val tpArgs = tpt.tpe.baseType(defn.PartialFunctionClass).argInfos + val parent = defn.AbstractPartialFunctionType.appliedTo(tpArgs) + val anonCls = AnonClass(parent :: Nil, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse)) + cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls) } private def checkRefinements(tpe: Type, pos: Position)(implicit ctx: Context): Type = tpe.dealias match { diff --git a/tests/pos/i4177.scala b/tests/pos/i4177.scala index 961e1e39b289..e4ec7dac3745 100644 --- a/tests/pos/i4177.scala +++ b/tests/pos/i4177.scala @@ -11,5 +11,10 @@ class Test { val e: PartialFunction[String, String] = { case x @ "abc" => x } val f: PartialFunction[String, String] = x => x match { case "abc" => x } val g: PartialFunction[String, String] = x => x match { case "abc" if x.isEmpty => x } + + type P = PartialFunction[String,String] + val h: P = { case x => x.toString } + + val i: PartialFunction[Int, Int] = { x => x match { case x => x } } } } From c8904c5e09ab4173ca4ec5ad96fc55917c020a55 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 25 Apr 2018 15:33:17 +0200 Subject: [PATCH 3/5] Polishing --- .../dotty/tools/dotc/transform/ExpandSAMs.scala | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index b6faab4a28c5..80a82a7d31c6 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -35,8 +35,8 @@ class ExpandSAMs extends MiniPhase { tpt.tpe match { case NoType => tree // it's a plain function case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) => - checkRefinements(tpe, fn.pos) - toPartialFunction(tree) + val tpe1 = checkRefinements(tpe, fn.pos) + toPartialFunction(tree, tpe1) case tpe @ SAMType(_) if isPlatformSam(tpe.classSymbol.asClass) => checkRefinements(tpe, fn.pos) tree @@ -50,10 +50,9 @@ class ExpandSAMs extends MiniPhase { tree } - private def toPartialFunction(tree: Block)(implicit ctx: Context): Tree = { + private def toPartialFunction(tree: Block, tpe: Type)(implicit ctx: Context): Tree = { val Block( - (applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil, - Closure(_, _, tpt)) = tree + (applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil, _) = tree def translateMatch(tree: Match, selector: Tree, cases: List[CaseDef], defaultValue: Tree) = { assert(tree.selector.symbol == param.symbol) @@ -78,7 +77,7 @@ class ExpandSAMs extends MiniPhase { def overrideSym(sym: Symbol) = sym.copy( owner = applyFn.owner, flags = Synthetic | Method | Final, - info = tpt.tpe.memberInfo(sym), + info = tpe.memberInfo(sym), coord = tree.pos).asTerm val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) @@ -115,8 +114,7 @@ class ExpandSAMs extends MiniPhase { val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_))) val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_))) - val tpArgs = tpt.tpe.baseType(defn.PartialFunctionClass).argInfos - val parent = defn.AbstractPartialFunctionType.appliedTo(tpArgs) + val parent = defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos) val anonCls = AnonClass(parent :: Nil, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse)) cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls) } @@ -126,7 +124,7 @@ class ExpandSAMs extends MiniPhase { if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement ctx.error("Lambda does not define " + name, pos) checkRefinements(parent, pos) - case _ => + case tpe => tpe } From d03b226b158d5ca0a000a98bbff7541e785f62b9 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 25 Apr 2018 18:50:51 +0200 Subject: [PATCH 4/5] Fix #4241: Follow Scalac for partial function literals - `x => x match { case x => x }` is a PF - `x => { x match { case x => x } }` is a PF - `x => { println("foo"); x match { case x => x } }` is not a PF - `x => x` is not a PF --- .../tools/dotc/transform/ExpandSAMs.scala | 113 +++++++++--------- tests/neg/i4241.scala | 12 ++ tests/pos/i4177.scala | 2 - 3 files changed, 70 insertions(+), 57 deletions(-) create mode 100644 tests/neg/i4241.scala diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index 80a82a7d31c6..dfc21017f447 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -8,6 +8,7 @@ import MegaPhase._ import SymUtils._ import ast.untpd import ast.Trees._ +import dotty.tools.dotc.reporting.diagnostic.messages.TypeMismatch import dotty.tools.dotc.util.Positions.Position /** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes. @@ -51,72 +52,74 @@ class ExpandSAMs extends MiniPhase { } private def toPartialFunction(tree: Block, tpe: Type)(implicit ctx: Context): Tree = { - val Block( - (applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil, _) = tree - - def translateMatch(tree: Match, selector: Tree, cases: List[CaseDef], defaultValue: Tree) = { - assert(tree.selector.symbol == param.symbol) - val selectorTpe = selector.tpe.widen - val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD, Synthetic, selectorTpe) - val defaultCase = - CaseDef( - Bind(defaultSym, Underscore(selectorTpe)), - EmptyTree, - defaultValue) - val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType))) - cpy.Match(tree)(unchecked, cases :+ defaultCase) - .subst(param.symbol :: Nil, selector.symbol :: Nil) - // Needed because a partial function can be written as: - // param => param match { case "foo" if foo(param) => param } - // And we need to update all references to 'param' + // /** An extractor for match, either contained in a block or standalone. */ + object PartialFunctionRHS { + def unapply(tree: Tree): Option[Match] = tree match { + case Block(Nil, expr) => unapply(expr) + case m: Match => Some(m) + case _ => None + } } - val applyRhs = applyDef.rhs - val applyFn = applyDef.symbol - - def overrideSym(sym: Symbol) = sym.copy( - owner = applyFn.owner, - flags = Synthetic | Method | Final, - info = tpe.memberInfo(sym), - coord = tree.pos).asTerm - val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) - val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) - - def isDefinedAtRhs(paramRefss: List[List[Tree]]) = { - val tru = Literal(Constant(true)) - applyRhs match { - case tree @ Match(_, cases) => + val closureDef(anon @ DefDef(_, _, List(List(param)), _, _)) = tree + anon.rhs match { + case PartialFunctionRHS(pf) => + val anonSym = anon.symbol + + def overrideSym(sym: Symbol) = sym.copy( + owner = anonSym.owner, + flags = Synthetic | Method | Final, + info = tpe.memberInfo(sym), + coord = tree.pos).asTerm + val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) + val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) + + def isDefinedAtRhs(paramRefss: List[List[Tree]]) = { + val tru = Literal(Constant(true)) def translateCase(cdef: CaseDef) = - cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn) + cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn) val paramRef = paramRefss.head.head val defaultValue = Literal(Constant(false)) - translateMatch(tree, paramRef, cases.map(translateCase), defaultValue) - case _ => - tru - } - } + translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue) + } - def applyOrElseRhs(paramRefss: List[List[Tree]]) = { - val List(paramRef, defaultRef) = paramRefss.head - applyRhs match { - case tree @ Match(_, cases) => + def applyOrElseRhs(paramRefss: List[List[Tree]]) = { + val List(paramRef, defaultRef) = paramRefss.head def translateCase(cdef: CaseDef) = - cdef.changeOwner(applyFn, applyOrElseFn) + cdef.changeOwner(anonSym, applyOrElseFn) val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef) - translateMatch(tree, paramRef, cases.map(translateCase), defaultValue) - case _ => - applyRhs - .changeOwner(applyFn, applyOrElseFn) - .subst(param.symbol :: Nil, paramRef.symbol :: Nil) + translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue) + } + + def translateMatch(tree: Match, selector: Tree, cases: List[CaseDef], defaultValue: Tree) = { + assert(tree.selector.symbol == param.symbol) + val selectorTpe = selector.tpe.widen + val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD, Synthetic, selectorTpe) + val defaultCase = + CaseDef( + Bind(defaultSym, Underscore(selectorTpe)), + EmptyTree, + defaultValue) + val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType))) + cpy.Match(tree)(unchecked, cases :+ defaultCase) + .subst(param.symbol :: Nil, selector.symbol :: Nil) + // Needed because a partial function can be written as: + // param => param match { case "foo" if foo(param) => param } + // And we need to update all references to 'param' } - } - val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_))) - val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_))) + val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_))) + val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_))) - val parent = defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos) - val anonCls = AnonClass(parent :: Nil, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse)) - cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls) + val parent = defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos) + val anonCls = AnonClass(parent :: Nil, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse)) + cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls) + + case _ => + val found = tpe.baseType(defn.FunctionClass(1)) + ctx.error(TypeMismatch(found, tpe), tree.pos) + tree + } } private def checkRefinements(tpe: Type, pos: Position)(implicit ctx: Context): Type = tpe.dealias match { diff --git a/tests/neg/i4241.scala b/tests/neg/i4241.scala new file mode 100644 index 000000000000..3d93a44a015a --- /dev/null +++ b/tests/neg/i4241.scala @@ -0,0 +1,12 @@ +class Test { + def test: Unit = { + val a: PartialFunction[Int, Int] = { case x => x } + val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case _ => 2 } + val c: PartialFunction[Int, Int] = x => { x match { case y => y } } + val d: PartialFunction[Int, Int] = x => { { x match { case y => y } } } + + val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case y => y } } // error + val f: PartialFunction[Int, Int] = x => x // error + val g: PartialFunction[Int, String] = { x => x.toString } // error + } +} diff --git a/tests/pos/i4177.scala b/tests/pos/i4177.scala index e4ec7dac3745..dfcedf92d424 100644 --- a/tests/pos/i4177.scala +++ b/tests/pos/i4177.scala @@ -5,8 +5,6 @@ class Test { def test: Unit = { val a: PartialFunction[Int, String] = { case Foo(x) => x } val b: PartialFunction[Int, String] = { case x => x.toString } - val c: PartialFunction[Int, String] = { x => x.toString } - val d: PartialFunction[Int, String] = x => x.toString val e: PartialFunction[String, String] = { case x @ "abc" => x } val f: PartialFunction[String, String] = x => x match { case "abc" => x } From 8a91774e2c9546809e018b26ef2c9af8dc9769df Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Thu, 26 Apr 2018 16:36:34 +0200 Subject: [PATCH 5/5] Fix partial functions with non trivial selector E.g. ``` class Foo(val field: Option[Int]) val p: PartialFunction[Foo, Int] = foo => foo.field match { case Some(x) => x } ``` --- .../tools/dotc/transform/ExpandSAMs.scala | 38 +++++++++---------- tests/run/partialFunctions.scala | 11 ++++-- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index dfc21017f447..e2c04b920250 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -74,13 +74,30 @@ class ExpandSAMs extends MiniPhase { val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) + def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = { + val selector = tree.selector + val selectorTpe = selector.tpe.widen + val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe) + val defaultCase = + CaseDef( + Bind(defaultSym, Underscore(selectorTpe)), + EmptyTree, + defaultValue) + val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType))) + cpy.Match(tree)(unchecked, cases :+ defaultCase) + .subst(param.symbol :: Nil, pfParam :: Nil) + // Needed because a partial function can be written as: + // param => param match { case "foo" if foo(param) => param } + // And we need to update all references to 'param' + } + def isDefinedAtRhs(paramRefss: List[List[Tree]]) = { val tru = Literal(Constant(true)) def translateCase(cdef: CaseDef) = cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn) val paramRef = paramRefss.head.head val defaultValue = Literal(Constant(false)) - translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue) + translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue) } def applyOrElseRhs(paramRefss: List[List[Tree]]) = { @@ -88,24 +105,7 @@ class ExpandSAMs extends MiniPhase { def translateCase(cdef: CaseDef) = cdef.changeOwner(anonSym, applyOrElseFn) val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef) - translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue) - } - - def translateMatch(tree: Match, selector: Tree, cases: List[CaseDef], defaultValue: Tree) = { - assert(tree.selector.symbol == param.symbol) - val selectorTpe = selector.tpe.widen - val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD, Synthetic, selectorTpe) - val defaultCase = - CaseDef( - Bind(defaultSym, Underscore(selectorTpe)), - EmptyTree, - defaultValue) - val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType))) - cpy.Match(tree)(unchecked, cases :+ defaultCase) - .subst(param.symbol :: Nil, selector.symbol :: Nil) - // Needed because a partial function can be written as: - // param => param match { case "foo" if foo(param) => param } - // And we need to update all references to 'param' + translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue) } val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_))) diff --git a/tests/run/partialFunctions.scala b/tests/run/partialFunctions.scala index ca82431ca8e8..8120b3fa886b 100644 --- a/tests/run/partialFunctions.scala +++ b/tests/run/partialFunctions.scala @@ -1,10 +1,15 @@ object Test { - def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1) + def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1) + class Foo(val field: Option[Int]) def main(args: Array[String]): Unit = { - val partialFunction: PartialFunction[Int, Int] = {case a: Int => a} + val p1: PartialFunction[Int, Int] = { case a: Int => a } + assert(takesPartialFunction(p1) == 1) - assert(takesPartialFunction(partialFunction) == 1) + val p2: PartialFunction[Foo, Int] = + foo => foo.field match { case Some(x) => x } + assert(p2.isDefinedAt(new Foo(Some(1)))) + assert(!p2.isDefinedAt(new Foo(None))) } }