diff --git a/README.md b/README.md
index e2a24e29..843cfbb9 100644
--- a/README.md
+++ b/README.md
@@ -271,7 +271,7 @@ For more information, check the [wiki](https://github.com/JetBrains/kotlin-spark
 
 ## Examples
 
-For more, check out [examples](https://github.com/JetBrains/kotlin-spark-api/tree/master/examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples) module.
+For more, check out [examples](examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples) module.
 To get up and running quickly, check out this [tutorial](https://github.com/JetBrains/kotlin-spark-api/wiki/Quick-Start-Guide).
 
 ## Reporting issues/Support
diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt
index 0b2a8306..24ae04ce 100644
--- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt
+++ b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt
@@ -22,6 +22,7 @@ package org.jetbrains.kotlinx.spark.api.jupyter
 import org.apache.spark.api.java.JavaRDDLike
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.jetbrains.kotlinx.jupyter.api.FieldValue
 import org.jetbrains.kotlinx.jupyter.api.HTML
 import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost
 import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
@@ -33,50 +34,71 @@ abstract class Integration : JupyterIntegration() {
 
     private val scalaVersion = "2.12.15"
     private val spark3Version = "3.2.1"
 
+    /**
+     * Will be run after importing all dependencies
+     */
     abstract fun KotlinKernelHost.onLoaded()
 
-    override fun Builder.onLoaded() {
+    /** Will be run just before the kernel shuts down */
+    abstract fun KotlinKernelHost.onShutdown()
+
+    /** Will be run after each executed cell */
+    abstract fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue)
+
+    open val dependencies: Array<String> = arrayOf(
+        "org.apache.spark:spark-repl_$scalaCompatVersion:$spark3Version",
+        "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlinVersion",
+        "org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion",
+        "org.apache.spark:spark-sql_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-streaming_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-mllib_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-sql_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-graphx_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-launcher_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-catalyst_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-streaming_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-core_$scalaCompatVersion:$spark3Version",
+        "org.scala-lang:scala-library:$scalaVersion",
+        "org.scala-lang.modules:scala-xml_$scalaCompatVersion:2.0.1",
+        "org.scala-lang:scala-reflect:$scalaVersion",
+        "org.scala-lang:scala-compiler:$scalaVersion",
+        "commons-io:commons-io:2.11.0",
+    )
 
-        dependencies(
-            "org.apache.spark:spark-repl_$scalaCompatVersion:$spark3Version",
-            "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlinVersion",
-            "org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion",
-            "org.apache.spark:spark-sql_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-streaming_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-mllib_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-sql_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-graphx_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-launcher_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-catalyst_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-streaming_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-core_$scalaCompatVersion:$spark3Version",
-            "org.scala-lang:scala-library:$scalaVersion",
-            "org.scala-lang.modules:scala-xml_$scalaCompatVersion:2.0.1",
-            "org.scala-lang:scala-reflect:$scalaVersion",
-            "org.scala-lang:scala-compiler:$scalaVersion",
-            "commons-io:commons-io:2.11.0",
-        )
-
-        import(
-            "org.jetbrains.kotlinx.spark.api.*",
-            "org.jetbrains.kotlinx.spark.api.tuples.*",
-            *(1..22).map { "scala.Tuple$it" }.toTypedArray(),
-            "org.apache.spark.sql.functions.*",
-            "org.apache.spark.*",
-            "org.apache.spark.sql.*",
-            "org.apache.spark.api.java.*",
-            "scala.collection.Seq",
-            "org.apache.spark.rdd.*",
-            "java.io.Serializable",
-            "org.apache.spark.streaming.api.java.*",
-            "org.apache.spark.streaming.api.*",
-            "org.apache.spark.streaming.*",
-        )
+    open val imports: Array<String> = arrayOf(
+        "org.jetbrains.kotlinx.spark.api.*",
+        "org.jetbrains.kotlinx.spark.api.tuples.*",
+        *(1..22).map { "scala.Tuple$it" }.toTypedArray(),
+        "org.apache.spark.sql.functions.*",
+        "org.apache.spark.*",
+        "org.apache.spark.sql.*",
+        "org.apache.spark.api.java.*",
+        "scala.collection.Seq",
+        "org.apache.spark.rdd.*",
+        "java.io.Serializable",
+        "org.apache.spark.streaming.api.java.*",
+        "org.apache.spark.streaming.api.*",
+        "org.apache.spark.streaming.*",
+    )
+
+    override fun Builder.onLoaded() {
+        dependencies(*dependencies)
+        import(*imports)
 
         onLoaded {
             onLoaded()
         }
 
+        beforeCellExecution {
+            execute("""scala.Console.setOut(System.out)""")
+        }
+
+        afterCellExecution { snippetInstance, result ->
+            afterCellExecution(snippetInstance, result)
+        }
+
+        onShutdown {
+            onShutdown()
+        }
+
         // Render Dataset
         render<Dataset<*>> {
             HTML(it.toHtml())
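
With this refactor, `Integration` exposes its dependency and import lists as `open` properties instead of hard-coding them inside `Builder.onLoaded()`, and gains `onShutdown` and `afterCellExecution` hooks alongside `onLoaded`. A subclass can now append to the lists rather than redefine them. A minimal sketch of what that enables; the class name and the extra mllib import are illustrative, not part of this change:

```kotlin
import org.jetbrains.kotlinx.jupyter.api.FieldValue
import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost

// Hypothetical subclass: inherits the default dependencies, appends one
// import to the open `imports` property, and gives no-op implementations
// for the new abstract lifecycle hooks.
internal class MlIntegration : Integration() {

    override val imports: Array<String> = super.imports + arrayOf(
        "org.apache.spark.ml.*",
    )

    override fun KotlinKernelHost.onLoaded() = Unit

    override fun KotlinKernelHost.onShutdown() = Unit

    override fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue) = Unit
}
```
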
"org.apache.spark:spark-launcher_$scalaCompatVersion:$spark3Version", - "org.apache.spark:spark-catalyst_$scalaCompatVersion:$spark3Version", - "org.apache.spark:spark-streaming_$scalaCompatVersion:$spark3Version", - "org.apache.spark:spark-core_$scalaCompatVersion:$spark3Version", - "org.scala-lang:scala-library:$scalaVersion", - "org.scala-lang.modules:scala-xml_$scalaCompatVersion:2.0.1", - "org.scala-lang:scala-reflect:$scalaVersion", - "org.scala-lang:scala-compiler:$scalaVersion", - "commons-io:commons-io:2.11.0", - ) - - import( - "org.jetbrains.kotlinx.spark.api.*", - "org.jetbrains.kotlinx.spark.api.tuples.*", - *(1..22).map { "scala.Tuple$it" }.toTypedArray(), - "org.apache.spark.sql.functions.*", - "org.apache.spark.*", - "org.apache.spark.sql.*", - "org.apache.spark.api.java.*", - "scala.collection.Seq", - "org.apache.spark.rdd.*", - "java.io.Serializable", - "org.apache.spark.streaming.api.java.*", - "org.apache.spark.streaming.api.*", - "org.apache.spark.streaming.*", - ) + open val imports: Array = arrayOf( + "org.jetbrains.kotlinx.spark.api.*", + "org.jetbrains.kotlinx.spark.api.tuples.*", + *(1..22).map { "scala.Tuple$it" }.toTypedArray(), + "org.apache.spark.sql.functions.*", + "org.apache.spark.*", + "org.apache.spark.sql.*", + "org.apache.spark.api.java.*", + "scala.collection.Seq", + "org.apache.spark.rdd.*", + "java.io.Serializable", + "org.apache.spark.streaming.api.java.*", + "org.apache.spark.streaming.api.*", + "org.apache.spark.streaming.*", + ) + + override fun Builder.onLoaded() { + dependencies(*dependencies) + import(*imports) onLoaded { onLoaded() } + beforeCellExecution { + execute("""scala.Console.setOut(System.out)""") + } + + afterCellExecution { snippetInstance, result -> + afterCellExecution(snippetInstance, result) + } + + onShutdown { + onShutdown() + } + // Render Dataset render> { HTML(it.toHtml()) diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt index 635ed654..a3ec6dc5 100644 --- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt +++ b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt @@ -21,6 +21,7 @@ package org.jetbrains.kotlinx.spark.api.jupyter import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.jupyter.api.FieldValue import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost /** @@ -68,4 +69,10 @@ internal class SparkIntegration : Integration() { val udf: UDFRegistration get() = spark.udf()""".trimIndent(), ).map(::execute) } + + override fun KotlinKernelHost.onShutdown() { + execute("""spark.stop()""") + } + + override fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue) = Unit } diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkStreamingIntegration.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkStreamingIntegration.kt index 1684769b..4982830c 100644 --- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkStreamingIntegration.kt +++ b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkStreamingIntegration.kt @@ -19,27 +19,10 @@ */ package org.jetbrains.kotlinx.spark.api.jupyter -import kotlinx.html.* -import kotlinx.html.stream.appendHTML -import org.apache.spark.api.java.JavaRDDLike -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Dataset -import 
diff --git a/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt b/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt
index 2f35bee4..96d5d1fa 100644
--- a/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt
+++ b/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt
@@ -21,7 +21,6 @@ package org.jetbrains.kotlinx.spark.api.jupyter
 
 import io.kotest.assertions.throwables.shouldThrowAny
 import io.kotest.core.spec.style.ShouldSpec
-import io.kotest.matchers.collections.shouldBeIn
 import io.kotest.matchers.nulls.shouldNotBeNull
 import io.kotest.matchers.shouldBe
 import io.kotest.matchers.shouldNotBe
@@ -29,7 +28,6 @@ import io.kotest.matchers.string.shouldContain
 import io.kotest.matchers.types.shouldBeInstanceOf
 import jupyter.kotlin.DependsOn
 import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.streaming.Duration
 import org.intellij.lang.annotations.Language
 import org.jetbrains.kotlinx.jupyter.EvalRequestData
 import org.jetbrains.kotlinx.jupyter.ReplForJupyter
@@ -40,11 +38,8 @@ import org.jetbrains.kotlinx.jupyter.libraries.EmptyResolutionInfoProvider
 import org.jetbrains.kotlinx.jupyter.repl.EvalResultEx
 import org.jetbrains.kotlinx.jupyter.testkit.ReplProvider
 import org.jetbrains.kotlinx.jupyter.util.PatternNameAcceptanceRule
-import org.jetbrains.kotlinx.spark.api.tuples.*
-import org.jetbrains.kotlinx.spark.api.*
-import scala.Tuple2
+import org.jetbrains.kotlinx.spark.api.SparkSession
 import java.io.Serializable
-import java.util.*
 import kotlin.script.experimental.jvm.util.classpathFromClassloader
 
 class JupyterTests : ShouldSpec({
@@ -155,16 +150,19 @@ class JupyterTests : ShouldSpec({
 
         should("render JavaRDDs with custom class") {
             @Language("kts")
-            val klass = exec("""
+            val klass = exec(
+                """
                 data class Test(
                     val longFirstName: String,
                     val second: LongArray,
                     val somethingSpecial: Map<Int, String>,
                 ): Serializable
-            """.trimIndent())
+                """.trimIndent()
+            )
 
             @Language("kts")
-            val html = execHtml("""
+            val html = execHtml(
+                """
                 val rdd = sc.parallelize(
                     listOf(
                         Test("aaaaaaaaa", longArrayOf(1L, 100000L, 24L), mapOf(1 to "one", 2 to "two")),
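
This test drives the `render<Dataset<*>>` and RDD render handlers registered in `Integration.kt` (`HTML(it.toHtml())`). In notebook terms, the behaviour under test is roughly the following; the `Person` class is an illustrative stand-in for the test's `Test` class:

```kotlin
// After `%use spark`: a cell whose result is a Dataset or (Java)RDD is
// rendered as an HTML table instead of the default toString() output.
data class Person(val name: String, val age: Int) : Serializable

sc.parallelize(listOf(Person("Jane", 32), Person("John", 30)))
// cell output: an HTML table produced by toHtml()
```
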
@@ -246,8 +244,10 @@ class JupyterStreamingTests : ShouldSpec({
                 host = this,
                 integrationTypeNameRules = listOf(
                     PatternNameAcceptanceRule(false, "org.jetbrains.kotlinx.spark.api.jupyter.**"),
-                    PatternNameAcceptanceRule(true,
-                        "org.jetbrains.kotlinx.spark.api.jupyter.SparkStreamingIntegration"),
+                    PatternNameAcceptanceRule(
+                        true,
+                        "org.jetbrains.kotlinx.spark.api.jupyter.SparkStreamingIntegration"
+                    ),
                 ),
             )
         }
@@ -279,29 +279,46 @@ class JupyterStreamingTests : ShouldSpec({
             }
         }
 
-        should("stream") {
-            val input = listOf("aaa", "bbb", "aaa", "ccc")
-            val counter = Counter(0)
-
-            withSparkStreaming(Duration(10), timeout = 1000) {
-
-                val (counterBroadcast, queue) = withSpark(ssc) {
-                    spark.broadcast(counter) X LinkedList(listOf(sc.parallelize(input)))
-                }
-
-                val inputStream = ssc.queueStream(queue)
-
-                inputStream.foreachRDD { rdd, _ ->
-                    withSpark(rdd) {
-                        rdd.toDS().forEach {
-                            it shouldBeIn input
-                            counterBroadcast.value.value++
-                        }
-                    }
-                }
-            }
-
-            counter.value shouldBe input.size
+        xshould("stream") {
+
+            @Language("kts")
+            val value = exec(
+                """
+                import java.util.LinkedList
+                import org.apache.spark.api.java.function.ForeachFunction
+                import org.apache.spark.util.LongAccumulator
+
+
+                val input = arrayListOf("aaa", "bbb", "aaa", "ccc")
+
+                @Volatile
+                var counter: LongAccumulator? = null
+
+                withSparkStreaming(Duration(10), timeout = 1_000) {
+
+                    val queue = withSpark(ssc) {
+                        LinkedList(listOf(sc.parallelize(input)))
+                    }
+
+                    val inputStream = ssc.queueStream(queue)
+
+                    inputStream.foreachRDD { rdd, _ ->
+                        withSpark(rdd) {
+                            if (counter == null)
+                                counter = sc.sc().longAccumulator()
+
+                            rdd.toDS().showDS().forEach {
+                                if (it !in input) error(it + " should be in input")
+                                counter!!.add(1L)
+                            }
+                        }
+                    }
+                }
+                counter!!.sum()
+                """.trimIndent()
+            ) as Long
+
+            value shouldBe 4L
         }
     }
diff --git a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt
index 27513ffc..652e52b7 100644
--- a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt
+++ b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt
@@ -113,7 +113,7 @@ class KSparkStreamingSession(@Transient val ssc: JavaStreamingContext) : Seriali
         runAfterStart = block
     }
 
-    internal fun invokeRunAfterStart(): Unit = runAfterStart()
+    fun invokeRunAfterStart(): Unit = runAfterStart()
 
 
 /** Creates new spark session from given [sc]. */
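
Two closing observations, both inferred from the diff rather than stated in it. The rewritten (and currently disabled, note the `xshould`) stream test counts with a `LongAccumulator` where the old version incremented a broadcast `Counter`: broadcast variables are read-only copies on the executors, so mutating one is not guaranteed to be visible back on the driver, whereas accumulators exist precisely to aggregate task-side updates. And `invokeRunAfterStart` loses its `internal` modifier, presumably so code generated in the Jupyter REPL, which lives outside this module, can call it. A minimal sketch of the accumulator pattern, assuming a plain `withSpark` session:

```kotlin
import org.jetbrains.kotlinx.spark.api.*

fun main() = withSpark {
    // Accumulators aggregate updates from executor-side tasks back on the
    // driver; this replaces mutating a broadcast counter, which would only
    // change the executor-local copy.
    val counter = sc.sc().longAccumulator("count")
    dsOf("aaa", "bbb", "aaa", "ccc").forEach { counter.add(1L) }
    println(counter.sum()) // 4
}
```
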