
Commit 165996f

eason-yuchen-liu authored and attilapiros committed
[SPARK-48589][SQL][SS] Add option snapshotStartBatchId and snapshotPartitionId to state data source
### What changes were proposed in this pull request?

This PR defines two new options, snapshotStartBatchId and snapshotPartitionId, for the existing state reader. Both of them must be provided at the same time.

1. When there is no snapshot file at the batch specified by `snapshotStartBatchId` (note there is an off-by-one difference between store version and batch id), throw an exception.
2. Otherwise, the reader rebuilds the state by reading delta files only, ignoring all snapshot files afterwards.
3. If a `batchId` option is also specified, that batch id is the ending batch id; state reconstruction stops there.
4. This feature supports state generated by the HDFS state store provider and by the RocksDB state store provider with changelog checkpointing enabled. **It does not support RocksDB with changelog disabled, which is the default for RocksDB.**

### Why are the changes needed?

Sometimes when a snapshot is corrupted, users want to bypass it when reading a later state. This PR gives users the ability to specify the starting snapshot version and partition. This feature can be useful for debugging purposes.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Created test cases covering edge cases for the input of the new options. Created a test for the new public function `replayReadStateFromSnapshot`. Created an integration test for the new options against four stateful operators: limit, aggregation, deduplication, and stream-stream join. Instead of generating state within the tests, which is unstable, golden files are prepared for the integration test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46944 from eason-yuchen-liu/skipSnapshotAtBatch.

Lead-authored-by: Yuchen Liu <[email protected]>
Co-authored-by: Yuchen Liu <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
1 parent 039609f commit 165996f
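
As a usage illustration (not part of this commit), here is a minimal sketch of how the two new options might be combined with the existing state reader. It assumes the source is registered under the `statestore` format, that a `SparkSession` named `spark` is in scope, and that the checkpoint path and batch ids are placeholders:

```scala
// Rebuild the state of partition 0 as of batch 10, starting from the snapshot
// written for batch 5 and applying only delta/changelog files after it.
val stateDf = spark.read
  .format("statestore")
  .option("snapshotStartBatchId", 5)   // batch whose snapshot is the starting point
  .option("snapshotPartitionId", 0)    // must be set together with snapshotStartBatchId
  .option("batchId", 10)               // optional ending batch; defaults to the latest batch
  .load("/path/to/checkpoint")         // placeholder checkpoint location

stateDf.show()
```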

File tree

894 files changed: +1046 -24 lines (only a subset of the changed files is shown below)


common/utils/src/main/resources/error/error-conditions.json

Lines changed: 17 additions & 0 deletions
@@ -236,6 +236,11 @@
       "Error reading delta file <fileToRead> of <clazz>: <fileToRead> does not exist."
     ]
   },
+  "CANNOT_READ_MISSING_SNAPSHOT_FILE" : {
+    "message" : [
+      "Error reading snapshot file <fileToRead> of <clazz>: <fileToRead> does not exist."
+    ]
+  },
   "CANNOT_READ_SNAPSHOT_FILE_KEY_SIZE" : {
     "message" : [
       "Error reading snapshot file <fileToRead> of <clazz>: key size cannot be <keySize>."
@@ -251,6 +256,11 @@
       "Error reading streaming state file of <fileToRead> does not exist. If the stream job is restarted with a new or updated state operation, please create a new checkpoint location or clear the existing checkpoint location."
     ]
   },
+  "SNAPSHOT_PARTITION_ID_NOT_FOUND" : {
+    "message" : [
+      "Partition id <snapshotPartitionId> not found for state of operator <operatorId> at <checkpointLocation>."
+    ]
+  },
   "UNCATEGORIZED" : {
     "message" : [
       ""
@@ -3799,6 +3809,13 @@
     ],
     "sqlState" : "42802"
   },
+  "STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY" : {
+    "message" : [
+      "The given State Store Provider <inputClass> does not extend org.apache.spark.sql.execution.streaming.state.SupportsFineGrainedReplay.",
+      "Therefore, it does not support option snapshotStartBatchId in state data source."
+    ],
+    "sqlState" : "42K06"
+  },
   "STATE_STORE_UNSUPPORTED_OPERATION" : {
     "message" : [
       "<operationType> operation not supported with <entity>"

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 32 additions & 3 deletions
@@ -116,12 +116,16 @@ case class StateSourceOptions(
     batchId: Long,
     operatorId: Int,
     storeName: String,
-    joinSide: JoinSideValues) {
+    joinSide: JoinSideValues,
+    snapshotStartBatchId: Option[Long],
+    snapshotPartitionId: Option[Int]) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

   override def toString: String = {
     s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
-      s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide)"
+      s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
+      s"snapshotStartBatchId=${snapshotStartBatchId.getOrElse("None")}, " +
+      s"snapshotPartitionId=${snapshotPartitionId.getOrElse("None")})"
   }
 }

@@ -131,6 +135,8 @@ object StateSourceOptions extends DataSourceOptions {
   val OPERATOR_ID = newOption("operatorId")
   val STORE_NAME = newOption("storeName")
   val JOIN_SIDE = newOption("joinSide")
+  val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId")
+  val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId")

   object JoinSideValues extends Enumeration {
     type JoinSideValues = Value
@@ -190,7 +196,30 @@ object StateSourceOptions extends DataSourceOptions {
       throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME))
     }

-    StateSourceOptions(resolvedCpLocation, batchId, operatorId, storeName, joinSide)
+    val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong)
+    if (snapshotStartBatchId.exists(_ < 0)) {
+      throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID)
+    } else if (snapshotStartBatchId.exists(_ > batchId)) {
+      throw StateDataSourceErrors.invalidOptionValue(
+        SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId")
+    }
+
+    val snapshotPartitionId = Option(options.get(SNAPSHOT_PARTITION_ID)).map(_.toInt)
+    if (snapshotPartitionId.exists(_ < 0)) {
+      throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID)
+    }
+
+    // both snapshotPartitionId and snapshotStartBatchId are required at the same time, because
+    // each partition may have different checkpoint status
+    if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) {
+      throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID)
+    } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) {
+      throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
+    }
+
+    StateSourceOptions(
+      resolvedCpLocation, batchId, operatorId, storeName,
+      joinSide, snapshotStartBatchId, snapshotPartitionId)
   }

   private def resolvedCheckpointLocation(
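
To make the validation above concrete, here is a hedged sketch of which option combinations are accepted or rejected. `readState` is a hypothetical helper introduced only for this illustration, a `SparkSession` named `spark` is assumed to be in scope, and the checkpoint path is a placeholder; the error names in the comments refer to the `StateDataSourceErrors` calls in the hunk above:

```scala
// Hypothetical helper wrapping the state data source read (illustration only).
def readState(opts: Map[String, String]): org.apache.spark.sql.DataFrame =
  opts.foldLeft(spark.read.format("statestore")) { case (r, (k, v)) => r.option(k, v) }
    .load("/path/to/checkpoint") // placeholder checkpoint location

// Accepted: both options given and 0 <= snapshotStartBatchId <= batchId.
readState(Map("batchId" -> "10", "snapshotStartBatchId" -> "5", "snapshotPartitionId" -> "0"))

// Rejected: negative value (invalidOptionValueIsNegative).
readState(Map("snapshotStartBatchId" -> "-1", "snapshotPartitionId" -> "0"))

// Rejected: snapshotStartBatchId greater than the resolved batchId (invalidOptionValue).
readState(Map("batchId" -> "3", "snapshotStartBatchId" -> "5", "snapshotPartitionId" -> "0"))

// Rejected: only one of the pair is set (requiredOptionUnspecified for the missing one).
readState(Map("snapshotStartBatchId" -> "5"))
readState(Map("snapshotPartitionId" -> "0"))
```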

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 14 additions & 2 deletions
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, ReadStateStore, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
+import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration

@@ -93,7 +93,19 @@ class StatePartitionReader(
   }

   private lazy val store: ReadStateStore = {
-    provider.getReadStore(partition.sourceOptions.batchId + 1)
+    partition.sourceOptions.snapshotStartBatchId match {
+      case None => provider.getReadStore(partition.sourceOptions.batchId + 1)
+
+      case Some(snapshotStartBatchId) =>
+        if (!provider.isInstanceOf[SupportsFineGrainedReplay]) {
+          throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
+            provider.getClass.toString)
+        }
+        provider.asInstanceOf[SupportsFineGrainedReplay]
+          .replayReadStateFromSnapshot(
+            snapshotStartBatchId + 1,
+            partition.sourceOptions.batchId + 1)
+    }
   }

   private lazy val iter: Iterator[InternalRow] = {
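
Two details worth calling out in the reader change above: batch N maps to state store version N + 1 (hence the `+ 1` on both arguments), and only providers that implement `SupportsFineGrainedReplay` can serve such a request. Based solely on the methods exercised in this commit, the contract looks roughly like the following sketch; the real trait lives in `org.apache.spark.sql.execution.streaming.state` and may declare more members than shown here:

```scala
import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore}

// Sketch of the fine-grained replay contract implied by this commit (not the actual source).
trait SupportsFineGrainedReplay {
  // Rebuild a writable store at endVersion, starting from the snapshot written at snapshotVersion.
  def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore

  // Same, but read-only; this is what StatePartitionReader calls above.
  def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long): ReadStateStore
}
```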

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala

Lines changed: 15 additions & 4 deletions
@@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
-import org.apache.spark.sql.execution.streaming.state.StateStoreConf
+import org.apache.spark.sql.execution.streaming.state.{StateStoreConf, StateStoreErrors}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration

@@ -81,9 +81,20 @@ class StateScan(
     assert((tail - head + 1) == partitionNums.length,
       s"No continuous partitions in state: ${partitionNums.mkString("Array(", ", ", ")")}")

-    partitionNums.map {
-      pn => new StateStoreInputPartition(pn, queryId, sourceOptions)
-    }.toArray
+    sourceOptions.snapshotPartitionId match {
+      case None => partitionNums.map { pn =>
+        new StateStoreInputPartition(pn, queryId, sourceOptions)
+      }.toArray
+
+      case Some(snapshotPartitionId) =>
+        if (partitionNums.contains(snapshotPartitionId)) {
+          Array(new StateStoreInputPartition(snapshotPartitionId, queryId, sourceOptions))
+        } else {
+          throw StateStoreErrors.stateStoreSnapshotPartitionNotFound(
+            snapshotPartitionId, sourceOptions.operatorId,
+            sourceOptions.stateCheckpointLocation.toString)
+        }
+    }
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala

Lines changed: 9 additions & 4 deletions
@@ -49,16 +49,21 @@ class StateTable(
   }

   override def name(): String = {
-    val desc = s"StateTable " +
+    var desc = s"StateTable " +
       s"[stateCkptLocation=${sourceOptions.stateCheckpointLocation}]" +
       s"[batchId=${sourceOptions.batchId}][operatorId=${sourceOptions.operatorId}]" +
       s"[storeName=${sourceOptions.storeName}]"

     if (sourceOptions.joinSide != JoinSideValues.none) {
-      desc + s"[joinSide=${sourceOptions.joinSide}]"
-    } else {
-      desc
+      desc += s"[joinSide=${sourceOptions.joinSide}]"
+    }
+    if (sourceOptions.snapshotStartBatchId.isDefined) {
+      desc += s"[snapshotStartBatchId=${sourceOptions.snapshotStartBatchId}]"
+    }
+    if (sourceOptions.snapshotPartitionId.isDefined) {
+      desc += s"[snapshotPartitionId=${sourceOptions.snapshotPartitionId}]"
     }
+    desc
   }

   override def capabilities(): util.Set[TableCapability] = CAPABILITY

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala

Lines changed: 2 additions & 1 deletion
@@ -116,7 +116,8 @@ class StreamStreamJoinStatePartitionReader(
       partitionId = partition.partition,
       formatVersion,
       skippedNullValueCount = None,
-      useStateStoreCoordinator = false
+      useStateStoreCoordinator = false,
+      snapshotStartVersion = partition.sourceOptions.snapshotStartBatchId.map(_ + 1)
     )
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 96 additions & 1 deletion
@@ -71,7 +71,8 @@ import org.apache.spark.util.ArrayImplicits._
  * to ensure re-executed RDD operations re-apply updates on the correct past version of the
  * store.
  */
-private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging {
+private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
+  with SupportsFineGrainedReplay {

   private val providerName = "HDFSBackedStateStoreProvider"

@@ -683,6 +684,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     }
   }

+  /**
+   * Try to read the snapshot file. If the snapshot file is not available, return [[None]].
+   *
+   * @param version the version of the snapshot file
+   */
   private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = {
     val fileToRead = snapshotFile(version)
     val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
@@ -883,4 +889,93 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       throw new IllegalStateException(msg)
     }
   }
+
+  /**
+   * Get the state store of endVersion by applying delta files on the snapshot of snapshotVersion.
+   * If snapshot for snapshotVersion does not exist, an error will be thrown.
+   *
+   * @param snapshotVersion checkpoint version of the snapshot to start with
+   * @param endVersion checkpoint version to end with
+   * @return [[HDFSBackedStateStore]]
+   */
+  override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = {
+    val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
+    logInfo(log"Retrieved snapshot at version " +
+      log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
+      log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
+      log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for update")
+    new HDFSBackedStateStore(endVersion, newMap)
+  }
+
+  /**
+   * Get the state store of endVersion for reading by applying delta files on the snapshot of
+   * snapshotVersion. If snapshot for snapshotVersion does not exist, an error will be thrown.
+   *
+   * @param snapshotVersion checkpoint version of the snapshot to start with
+   * @param endVersion checkpoint version to end with
+   * @return [[HDFSBackedReadStateStore]]
+   */
+  override def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long):
+    ReadStateStore = {
+    val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
+    logInfo(log"Retrieved snapshot at version " +
+      log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
+      log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
+      log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for read-only")
+    new HDFSBackedReadStateStore(endVersion, newMap)
+  }
+
+  /**
+   * Construct the state map at endVersion from snapshot of version snapshotVersion.
+   * Returns a new [[HDFSBackedStateStoreMap]]
+   * @param snapshotVersion checkpoint version of the snapshot to start with
+   * @param endVersion checkpoint version to end with
+   */
+  private def replayLoadedMapFromSnapshot(snapshotVersion: Long, endVersion: Long):
+    HDFSBackedStateStoreMap = synchronized {
+    try {
+      if (snapshotVersion < 1) {
+        throw QueryExecutionErrors.unexpectedStateStoreVersion(snapshotVersion)
+      }
+      if (endVersion < snapshotVersion) {
+        throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
+      }
+
+      val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      newMap.putAll(constructMapFromSnapshot(snapshotVersion, endVersion))
+
+      newMap
+    }
+    catch {
+      case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
+    }
+  }
+
+  private def constructMapFromSnapshot(snapshotVersion: Long, endVersion: Long):
+    HDFSBackedStateStoreMap = {
+    val (result, elapsedMs) = Utils.timeTakenMs {
+      val startVersionMap = synchronized { Option(loadedMaps.get(snapshotVersion)) } match {
+        case Some(value) => Option(value)
+        case None => readSnapshotFile(snapshotVersion)
+      }
+      if (startVersionMap.isEmpty) {
+        throw StateStoreErrors.stateStoreSnapshotFileNotFound(
+          snapshotFile(snapshotVersion).toString, toString())
+      }
+
+      // Load all the deltas from the version after the start version up to the end version.
+      val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      resultMap.putAll(startVersionMap.get)
+      for (deltaVersion <- snapshotVersion + 1 to endVersion) {
+        updateFromDeltaFile(deltaVersion, resultMap)
+      }
+
+      resultMap
+    }
+
+    logDebug(s"Loading snapshot at version $snapshotVersion and apply delta files to version " +
+      s"$endVersion takes $elapsedMs ms.")
+
+    result
+  }
 }
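
The core replay logic above boils down to: start from the snapshot map of `snapshotVersion` (failing if it does not exist) and fold in every delta file from `snapshotVersion + 1` through `endVersion`, ignoring any later snapshots. Below is a toy, self-contained model of that loop; the maps stand in for snapshot and delta files, the names are illustrative rather than Spark APIs, and real delta files also encode deletions, which this sketch ignores:

```scala
object ReplaySketch {
  type State = Map[String, Int]

  def rebuild(
      snapshots: Map[Long, State],   // version -> full snapshot contents
      deltas: Map[Long, State],      // version -> updates written for that version
      snapshotVersion: Long,
      endVersion: Long): State = {
    require(snapshotVersion >= 1 && endVersion >= snapshotVersion)
    // Start from the requested snapshot; a missing snapshot mirrors
    // the CANNOT_READ_MISSING_SNAPSHOT_FILE error added in this commit.
    var state = snapshots.getOrElse(snapshotVersion,
      throw new IllegalStateException(s"no snapshot for version $snapshotVersion"))
    // Apply only the deltas written after the snapshot; later snapshots are never consulted.
    for (v <- snapshotVersion + 1 to endVersion) {
      state = state ++ deltas.getOrElse(v, Map.empty)
    }
    state
  }
}
```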

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 74 additions & 0 deletions
@@ -233,6 +233,80 @@ class RocksDB(
     this
   }

+  /**
+   * Load from the start snapshot version and apply all the changelog records to reach the
+   * end version. Note that this will copy all the necessary files from DFS to local disk as needed,
+   * and possibly restart the native RocksDB instance.
+   *
+   * @param snapshotVersion version of the snapshot to start with
+   * @param endVersion end version
+   * @return A RocksDB instance loaded with the state endVersion replayed from snapshotVersion.
+   *         Note that the instance will be read-only since this method is only used in State Data
+   *         Source.
+   */
+  def loadFromSnapshot(snapshotVersion: Long, endVersion: Long): RocksDB = {
+    assert(snapshotVersion >= 0 && endVersion >= snapshotVersion)
+    acquire(LoadStore)
+    recordedMetrics = None
+    logInfo(
+      log"Loading snapshot at version ${MDC(LogKeys.VERSION_NUM, snapshotVersion)} and apply " +
+        log"changelog files to version ${MDC(LogKeys.VERSION_NUM, endVersion)}.")
+    try {
+      replayFromCheckpoint(snapshotVersion, endVersion)
+
+      logInfo(
+        log"Loaded snapshot at version ${MDC(LogKeys.VERSION_NUM, snapshotVersion)} and apply " +
+          log"changelog files to version ${MDC(LogKeys.VERSION_NUM, endVersion)}.")
+    } catch {
+      case t: Throwable =>
+        loadedVersion = -1  // invalidate loaded data
+        throw t
+    }
+    this
+  }
+
+  /**
+   * Load from the start checkpoint version and apply all the changelog records to reach the
+   * end version.
+   * If the start version does not exist, it will throw an exception.
+   *
+   * @param snapshotVersion start checkpoint version
+   * @param endVersion end version
+   */
+  private def replayFromCheckpoint(snapshotVersion: Long, endVersion: Long): Any = {
+    closeDB()
+    val metadata = fileManager.loadCheckpointFromDfs(snapshotVersion, workingDir)
+    loadedVersion = snapshotVersion
+
+    // reset last snapshot version
+    if (lastSnapshotVersion > snapshotVersion) {
+      // discard any newer snapshots
+      lastSnapshotVersion = 0L
+      latestSnapshot = None
+    }
+    openDB()
+
+    numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) {
+      // we don't track the total number of rows - discard the number being track
+      -1L
+    } else if (metadata.numKeys < 0) {
+      // we track the total number of rows, but the snapshot doesn't have tracking number
+      // need to count keys now
+      countKeys()
+    } else {
+      metadata.numKeys
+    }
+    if (loadedVersion != endVersion) replayChangelog(endVersion)
+    // After changelog replay the numKeysOnWritingVersion will be updated to
+    // the correct number of keys in the loaded version.
+    numKeysOnLoadedVersion = numKeysOnWritingVersion
+    fileManagerMetrics = fileManager.latestLoadCheckpointMetrics
+
+    if (conf.resetStatsOnLoad) {
+      nativeStats.reset
+    }
+  }
+
   /**
    * Replay change log from the loaded version to the target version.
    */
