Reorganizing API and adding conversions #136

Merged: 6 commits, Mar 9, 2022
196 changes: 104 additions & 92 deletions core/3.2/src/main/scala/org/apache/spark/sql/KotlinWrappers.scala
@@ -26,189 +26,201 @@ import org.apache.spark.sql.types.{DataType, Metadata, StructField, StructType}


trait DataTypeWithClass {
val dt: DataType
val cls: Class[_]
val nullable: Boolean
val dt: DataType
val cls: Class[ _ ]
val nullable: Boolean
}

trait ComplexWrapper extends DataTypeWithClass

class KDataTypeWrapper(val dt: StructType
, val cls: Class[_]
, val nullable: Boolean = true) extends StructType with ComplexWrapper {
override def fieldNames: Array[String] = dt.fieldNames
class KDataTypeWrapper(
val dt: StructType,
val cls: Class[ _ ],
val nullable: Boolean = true,
) extends StructType with ComplexWrapper {

override def names: Array[String] = dt.names
override def fieldNames: Array[ String ] = dt.fieldNames

override def equals(that: Any): Boolean = dt.equals(that)
override def names: Array[ String ] = dt.names

override def hashCode(): Int = dt.hashCode()
override def equals(that: Any): Boolean = dt.equals(that)

override def add(field: StructField): StructType = dt.add(field)
override def hashCode(): Int = dt.hashCode()

override def add(name: String, dataType: DataType): StructType = dt.add(name, dataType)
override def add(field: StructField): StructType = dt.add(field)

override def add(name: String, dataType: DataType, nullable: Boolean): StructType = dt.add(name, dataType, nullable)
override def add(name: String, dataType: DataType): StructType = dt.add(name, dataType)

override def add(name: String, dataType: DataType, nullable: Boolean, metadata: Metadata): StructType = dt.add(name, dataType, nullable, metadata)
override def add(name: String, dataType: DataType, nullable: Boolean): StructType = dt.add(name, dataType, nullable)

override def add(name: String, dataType: DataType, nullable: Boolean, comment: String): StructType = dt.add(name, dataType, nullable, comment)
override def add(name: String, dataType: DataType, nullable: Boolean, metadata: Metadata): StructType = dt
.add(name, dataType, nullable, metadata)

override def add(name: String, dataType: String): StructType = dt.add(name, dataType)
override def add(name: String, dataType: DataType, nullable: Boolean, comment: String): StructType = dt
.add(name, dataType, nullable, comment)

override def add(name: String, dataType: String, nullable: Boolean): StructType = dt.add(name, dataType, nullable)
override def add(name: String, dataType: String): StructType = dt.add(name, dataType)

override def add(name: String, dataType: String, nullable: Boolean, metadata: Metadata): StructType = dt.add(name, dataType, nullable, metadata)
override def add(name: String, dataType: String, nullable: Boolean): StructType = dt.add(name, dataType, nullable)

override def add(name: String, dataType: String, nullable: Boolean, comment: String): StructType = dt.add(name, dataType, nullable, comment)
override def add(name: String, dataType: String, nullable: Boolean, metadata: Metadata): StructType = dt
.add(name, dataType, nullable, metadata)

override def apply(name: String): StructField = dt.apply(name)
override def add(name: String, dataType: String, nullable: Boolean, comment: String): StructType = dt
.add(name, dataType, nullable, comment)

override def apply(names: Set[String]): StructType = dt.apply(names)
override def apply(name: String): StructField = dt.apply(name)

override def fieldIndex(name: String): Int = dt.fieldIndex(name)
override def apply(names: Set[ String ]): StructType = dt.apply(names)

override private[sql] def getFieldIndex(name: String) = dt.getFieldIndex(name)
override def fieldIndex(name: String): Int = dt.fieldIndex(name)

private[sql] def findNestedField(fieldNames: Seq[String], includeCollections: Boolean, resolver: Resolver) = dt.findNestedField(fieldNames, includeCollections, resolver)
override private[ sql ] def getFieldIndex(name: String) = dt.getFieldIndex(name)

override private[sql] def buildFormattedString(prefix: String, stringConcat: StringUtils.StringConcat, maxDepth: Int): Unit = dt.buildFormattedString(prefix, stringConcat, maxDepth)
private[ sql ] def findNestedField(fieldNames: Seq[ String ], includeCollections: Boolean, resolver: Resolver) =
dt.findNestedField(fieldNames, includeCollections, resolver)

override protected[sql] def toAttributes: Seq[AttributeReference] = dt.toAttributes
override private[ sql ] def buildFormattedString(prefix: String, stringConcat: StringUtils.StringConcat, maxDepth: Int): Unit =
dt.buildFormattedString(prefix, stringConcat, maxDepth)

override def treeString: String = dt.treeString
override protected[ sql ] def toAttributes: Seq[ AttributeReference ] = dt.toAttributes

override def treeString(maxDepth: Int): String = dt.treeString(maxDepth)
override def treeString: String = dt.treeString

override def printTreeString(): Unit = dt.printTreeString()
override def treeString(maxDepth: Int): String = dt.treeString(maxDepth)

private[sql] override def jsonValue = dt.jsonValue
override def printTreeString(): Unit = dt.printTreeString()

override def apply(fieldIndex: Int): StructField = dt.apply(fieldIndex)
private[ sql ] override def jsonValue = dt.jsonValue

override def length: Int = dt.length
override def apply(fieldIndex: Int): StructField = dt.apply(fieldIndex)

override def iterator: Iterator[StructField] = dt.iterator
override def length: Int = dt.length

override def defaultSize: Int = dt.defaultSize
override def iterator: Iterator[ StructField ] = dt.iterator

override def simpleString: String = dt.simpleString
override def defaultSize: Int = dt.defaultSize

override def catalogString: String = dt.catalogString
override def simpleString: String = dt.simpleString

override def sql: String = dt.sql
override def catalogString: String = dt.catalogString

override def toDDL: String = dt.toDDL
override def sql: String = dt.sql

private[sql] override def simpleString(maxNumberFields: Int) = dt.simpleString(maxNumberFields)
override def toDDL: String = dt.toDDL

override private[sql] def merge(that: StructType) = dt.merge(that)
private[ sql ] override def simpleString(maxNumberFields: Int) = dt.simpleString(maxNumberFields)

private[spark] override def asNullable = dt.asNullable
override private[ sql ] def merge(that: StructType) = dt.merge(that)

private[spark] override def existsRecursively(f: DataType => Boolean) = dt.existsRecursively(f)
private[ spark ] override def asNullable = dt.asNullable

override private[sql] lazy val interpretedOrdering = dt.interpretedOrdering
private[ spark ] override def existsRecursively(f: DataType => Boolean) = dt.existsRecursively(f)

override def toString = s"KDataTypeWrapper(dt=$dt, cls=$cls, nullable=$nullable)"
override private[ sql ] lazy val interpretedOrdering = dt.interpretedOrdering

override def toString = s"KDataTypeWrapper(dt=$dt, cls=$cls, nullable=$nullable)"
}

case class KComplexTypeWrapper(dt: DataType, cls: Class[_], nullable: Boolean) extends DataType with ComplexWrapper {
override private[sql] def unapply(e: Expression) = dt.unapply(e)
case class KComplexTypeWrapper(dt: DataType, cls: Class[ _ ], nullable: Boolean) extends DataType with ComplexWrapper {

override private[ sql ] def unapply(e: Expression) = dt.unapply(e)

override def typeName: String = dt.typeName
override def typeName: String = dt.typeName

override private[sql] def jsonValue = dt.jsonValue
override private[ sql ] def jsonValue = dt.jsonValue

override def json: String = dt.json
override def json: String = dt.json

override def prettyJson: String = dt.prettyJson
override def prettyJson: String = dt.prettyJson

override def simpleString: String = dt.simpleString
override def simpleString: String = dt.simpleString

override def catalogString: String = dt.catalogString
override def catalogString: String = dt.catalogString

override private[sql] def simpleString(maxNumberFields: Int) = dt.simpleString(maxNumberFields)
override private[ sql ] def simpleString(maxNumberFields: Int) = dt.simpleString(maxNumberFields)

override def sql: String = dt.sql
override def sql: String = dt.sql

override private[spark] def sameType(other: DataType) = dt.sameType(other)
override private[ spark ] def sameType(other: DataType) = dt.sameType(other)

override private[spark] def existsRecursively(f: DataType => Boolean) = dt.existsRecursively(f)
override private[ spark ] def existsRecursively(f: DataType => Boolean) = dt.existsRecursively(f)

private[sql] override def defaultConcreteType = dt.defaultConcreteType
private[ sql ] override def defaultConcreteType = dt.defaultConcreteType

private[sql] override def acceptsType(other: DataType) = dt.acceptsType(other)
private[ sql ] override def acceptsType(other: DataType) = dt.acceptsType(other)

override def defaultSize: Int = dt.defaultSize
override def defaultSize: Int = dt.defaultSize

override private[spark] def asNullable = dt.asNullable
override private[ spark ] def asNullable = dt.asNullable

}

case class KSimpleTypeWrapper(dt: DataType, cls: Class[_], nullable: Boolean) extends DataType with DataTypeWithClass {
override private[sql] def unapply(e: Expression) = dt.unapply(e)
case class KSimpleTypeWrapper(dt: DataType, cls: Class[ _ ], nullable: Boolean) extends DataType with DataTypeWithClass {
override private[ sql ] def unapply(e: Expression) = dt.unapply(e)

override def typeName: String = dt.typeName
override def typeName: String = dt.typeName

override private[sql] def jsonValue = dt.jsonValue
override private[ sql ] def jsonValue = dt.jsonValue

override def json: String = dt.json
override def json: String = dt.json

override def prettyJson: String = dt.prettyJson
override def prettyJson: String = dt.prettyJson

override def simpleString: String = dt.simpleString
override def simpleString: String = dt.simpleString

override def catalogString: String = dt.catalogString
override def catalogString: String = dt.catalogString

override private[sql] def simpleString(maxNumberFields: Int) = dt.simpleString(maxNumberFields)
override private[ sql ] def simpleString(maxNumberFields: Int) = dt.simpleString(maxNumberFields)

override def sql: String = dt.sql
override def sql: String = dt.sql

override private[spark] def sameType(other: DataType) = dt.sameType(other)
override private[ spark ] def sameType(other: DataType) = dt.sameType(other)

override private[spark] def existsRecursively(f: DataType => Boolean) = dt.existsRecursively(f)
override private[ spark ] def existsRecursively(f: DataType => Boolean) = dt.existsRecursively(f)

private[sql] override def defaultConcreteType = dt.defaultConcreteType
private[ sql ] override def defaultConcreteType = dt.defaultConcreteType

private[sql] override def acceptsType(other: DataType) = dt.acceptsType(other)
private[ sql ] override def acceptsType(other: DataType) = dt.acceptsType(other)

override def defaultSize: Int = dt.defaultSize
override def defaultSize: Int = dt.defaultSize

override private[spark] def asNullable = dt.asNullable
override private[ spark ] def asNullable = dt.asNullable
}

class KStructField(val getterName: String, val delegate: StructField) extends StructField {
override private[sql] def buildFormattedString(prefix: String, stringConcat: StringUtils.StringConcat, maxDepth: Int): Unit = delegate.buildFormattedString(prefix, stringConcat, maxDepth)

override def toString(): String = delegate.toString()
override private[ sql ] def buildFormattedString(prefix: String, stringConcat: StringUtils.StringConcat, maxDepth: Int): Unit =
delegate.buildFormattedString(prefix, stringConcat, maxDepth)

override def toString(): String = delegate.toString()

override private[sql] def jsonValue = delegate.jsonValue
override private[ sql ] def jsonValue = delegate.jsonValue

override def withComment(comment: String): StructField = delegate.withComment(comment)
override def withComment(comment: String): StructField = delegate.withComment(comment)

override def getComment(): Option[String] = delegate.getComment()
override def getComment(): Option[ String ] = delegate.getComment()

override def toDDL: String = delegate.toDDL
override def toDDL: String = delegate.toDDL

override def productElement(n: Int): Any = delegate.productElement(n)
override def productElement(n: Int): Any = delegate.productElement(n)

override def productArity: Int = delegate.productArity
override def productArity: Int = delegate.productArity

override def productIterator: Iterator[Any] = delegate.productIterator
override def productIterator: Iterator[ Any ] = delegate.productIterator

override def productPrefix: String = delegate.productPrefix
override def productPrefix: String = delegate.productPrefix

override val dataType: DataType = delegate.dataType
override val dataType: DataType = delegate.dataType

override def canEqual(that: Any): Boolean = delegate.canEqual(that)
override def canEqual(that: Any): Boolean = delegate.canEqual(that)

override val metadata: Metadata = delegate.metadata
override val name: String = delegate.name
override val nullable: Boolean = delegate.nullable
override val metadata: Metadata = delegate.metadata
override val name: String = delegate.name
override val nullable: Boolean = delegate.nullable
}

object helpme {

def listToSeq(i: java.util.List[_]): Seq[_] = Seq(i.toArray: _*)
def listToSeq(i: java.util.List[ _ ]): Seq[ _ ] = Seq(i.toArray: _*)
}
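
The wrapper classes in this file delegate every StructType/DataType member to the wrapped `dt`, while `cls` and `nullable` carry the originating JVM class and its nullability alongside the Spark schema; `helpme.listToSeq` is a small Java-to-Scala collection conversion. A minimal usage sketch, assuming the classes above are on the classpath (the schema, the tuple class, and the `WrapperSketch` object are illustrative choices, not part of the PR):

```scala
import org.apache.spark.sql.{KDataTypeWrapper, helpme}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object WrapperSketch extends App {

  // An ordinary Spark schema; the fields chosen here are illustrative only.
  val schema = new StructType()
    .add("first", IntegerType, nullable = false)
    .add("second", StringType, nullable = true)

  // Wrap it so the originating JVM class and nullability travel with the schema.
  val wrapped = new KDataTypeWrapper(schema, classOf[(Int, String)], nullable = true)

  // Every StructType member is delegated to the wrapped schema ...
  assert(wrapped.fieldNames.sameElements(schema.fieldNames))

  // ... while the extra information stays available to callers.
  println(wrapped.cls)      // class scala.Tuple2
  println(wrapped.nullable) // true
  println(wrapped)          // KDataTypeWrapper(dt=..., cls=..., nullable=true)

  // listToSeq converts a Java list into a Scala Seq via toArray.
  val converted: Seq[_] = helpme.listToSeq(java.util.Arrays.asList(1, 2, 3))
  println(converted)        // List(1, 2, 3)
}
```

Because the wrapper extends StructType and forwards everything to `dt`, it can be passed anywhere Spark expects a plain schema, while code that knows about the wrapper can still recover the class and nullability.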
@@ -31,12 +31,12 @@ data class SomeClass(val a: IntArray, val b: Int) : Serializable
fun main() = withSpark {
val broadcastVariable = spark.broadcast(SomeClass(a = intArrayOf(5, 6), b = 3))
val result = listOf(1, 2, 3, 4, 5)
.toDS()
.map {
val receivedBroadcast = broadcastVariable.value
it + receivedBroadcast.a.first()
}
.collectAsList()
.toDS()
.map {
val receivedBroadcast = broadcastVariable.value
it + receivedBroadcast.a.first()
}
.collectAsList()

println(result)
}
@@ -24,13 +24,13 @@ import org.jetbrains.kotlinx.spark.api.*
fun main() {
withSpark {
dsOf(1, 2, 3, 4, 5)
.map { it to (it + 2) }
.withCached {
showDS()
.map { it to (it + 2) }
.withCached {
showDS()

filter { it.first % 2 == 0 }.showDS()
}
.map { c(it.first, it.second, (it.first + it.second) * 2) }
.show()
filter { it.first % 2 == 0 }.showDS()
}
.map { c(it.first, it.second, (it.first + it.second) * 2) }
.show()
}
}
@@ -23,15 +23,18 @@ import org.jetbrains.kotlinx.spark.api.*

fun main() {
withSpark(props = mapOf("spark.sql.codegen.wholeStage" to true)) {
dsOf(mapOf(1 to c(1, 2, 3), 2 to c(1, 2, 3)), mapOf(3 to c(1, 2, 3), 4 to c(1, 2, 3)))
.flatMap { it.toList().map { p -> listOf(p.first, p.second._1, p.second._2, p.second._3) }.iterator() }
.flatten()
.map { c(it) }
.also { it.printSchema() }
.distinct()
.sort("_1")
.debugCodegen()
.show()
dsOf(
mapOf(1 to c(1, 2, 3), 2 to c(1, 2, 3)),
mapOf(3 to c(1, 2, 3), 4 to c(1, 2, 3)),
)
.flatMap { it.toList().map { p -> listOf(p.first, p.second._1, p.second._2, p.second._3) }.iterator() }
.flatten()
.map { c(it) }
.also { it.printSchema() }
.distinct()
.sort("_1")
.debugCodegen()
.show()
}
}
