|
| 1 | +package test |
| 2 | + |
| 3 | +import scala.quoted.* |
| 4 | + |
| 5 | +trait Thing { |
| 6 | + type Type |
| 7 | +} |
| 8 | + |
| 9 | +object MyMacro { |
| 10 | + |
| 11 | + def isExpectedReturnType[R: Type](using Quotes): quotes.reflect.Symbol => Boolean = { method => |
| 12 | + import quotes.reflect.* |
| 13 | + |
| 14 | + val expectedReturnType = TypeRepr.of[R] |
| 15 | + |
| 16 | + method.tree match { |
| 17 | + case DefDef(_,_,typedTree,_) => |
| 18 | + TypeRepr.of(using typedTree.tpe.asType) <:< expectedReturnType |
| 19 | + case _ => false |
| 20 | + } |
| 21 | + } |
| 22 | + |
| 23 | + ///TODO no overloads |
| 24 | + def checkMethod[R: Type](using q: Quotes)(method: quotes.reflect.Symbol): Option[String] = { |
| 25 | + val isExpectedReturnTypeFun = isExpectedReturnType[R] |
| 26 | + |
| 27 | + Option.when(method.paramSymss.headOption.exists(_.exists(_.isType)))(s"Method ${method.name} has a generic type parameter, this is not supported") orElse |
| 28 | + Option.when(!isExpectedReturnTypeFun(method))(s"Method ${method.name} has unexpected return type") |
| 29 | + } |
| 30 | + |
| 31 | + def definedMethodsInType[T: Type](using Quotes): List[quotes.reflect.Symbol] = { |
| 32 | + import quotes.reflect.* |
| 33 | + |
| 34 | + val tree = TypeTree.of[T] |
| 35 | + |
| 36 | + for { |
| 37 | + member <- tree.symbol.methodMembers |
| 38 | + //is abstract method, not implemented |
| 39 | + if member.flags.is(Flags.Deferred) |
| 40 | + |
| 41 | + //TODO: is that public? |
| 42 | + // TODO? if member.privateWithin |
| 43 | + if !member.flags.is(Flags.Private) |
| 44 | + if !member.flags.is(Flags.Protected) |
| 45 | + if !member.flags.is(Flags.PrivateLocal) |
| 46 | + |
| 47 | + if !member.isClassConstructor |
| 48 | + if !member.flags.is(Flags.Synthetic) |
| 49 | + } yield member |
| 50 | + } |
| 51 | + |
| 52 | + transparent inline def client[T, R](r: () => R): T = ${MyMacro.clientImpl[T, R]('r)} |
| 53 | + |
| 54 | + def clientImpl[T: Type, R: Type](r: Expr[() => R])(using Quotes): Expr[T] = { |
| 55 | + import quotes.reflect.* |
| 56 | + |
| 57 | + val apiType = TypeRepr.of[T] |
| 58 | + val tree = TypeTree.of[T] |
| 59 | + |
| 60 | + val methods = definedMethodsInType[T] |
| 61 | + val invalidMethods = methods.flatMap(checkMethod[R]) |
| 62 | + if (invalidMethods.nonEmpty) { |
| 63 | + report.errorAndAbort(s"Invalid methods: ${invalidMethods.mkString(", ")}") |
| 64 | + } |
| 65 | + |
| 66 | + val className = "_Anon" |
| 67 | + val parents = List(TypeTree.of[Object], TypeTree.of[T]) |
| 68 | + |
| 69 | + def decls(cls: Symbol): List[Symbol] = methods.map { method => |
| 70 | + Symbol.newMethod(cls, method.name, method.info, flags = Flags.EmptyFlags /*TODO: method.flags */, privateWithin = method.privateWithin.fold(Symbol.noSymbol)(_.typeSymbol)) |
| 71 | + } |
| 72 | + |
| 73 | + val cls = Symbol.newClass(Symbol.spliceOwner, className, parents = parents.map(_.tpe), decls, selfType = None) |
| 74 | + val body = cls.declaredMethods.map { method => DefDef(method, argss => Some('{${r}()}.asTerm)) } |
| 75 | + val clsDef = ClassDef(cls, parents, body = body) |
| 76 | + val newCls = Typed(Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil), TypeTree.of[T]) |
| 77 | + Block(List(clsDef), newCls).asExprOf[T] |
| 78 | + } |
| 79 | +} |
0 commit comments