package ru.chermenin.spark.sql.execution.streaming.state

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.SparkSession.Builder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.DataStreamWriter
import ru.chermenin.spark.sql.execution.streaming.state.RocksDbStateStoreProvider._

import scala.collection.mutable

/**
 * Implicits aka helper methods.
 *
 * They can be imported into scope with:
 * {{{
 *   import ru.chermenin.spark.sql.execution.streaming.state.implicits._
 * }}}
 *
 * SessionImplicits:
 *  - Makes the `useRocksDBStateStore` method available on [[Builder]]
 *  - Sets the state store provider to [[RocksDbStateStoreProvider]]
 *
 * WriterImplicits:
 *  - Makes the `stateTimeout` method available on [[DataStreamWriter]]
 *  - Precedence is given to the provided arguments (if any), then previously set values,
 *    and finally the value set on [[RuntimeConfig]] (in case of checkpoint location)
 *  - Makes a checkpoint location mandatory for every query it is applied to
 *  - Expiry seconds less than 0 are treated as -1 (no timeout)
 */
object implicits extends Serializable {

  /** Returns `None` for a `null` or empty string, otherwise `Some(value)`. */
  private def emptyToNone(value: String): Option[String] =
    Option(value).filter(_.nonEmpty)

  implicit class SessionImplicits(sparkSessionBuilder: Builder) {

    /**
     * Configures the session being built to use [[RocksDbStateStoreProvider]]
     * as its state store provider (`SQLConf.STATE_STORE_PROVIDER_CLASS`).
     *
     * @return the same builder, for chaining
     */
    def useRocksDBStateStore(): Builder =
      sparkSessionBuilder.config(
        SQLConf.STATE_STORE_PROVIDER_CLASS.key,
        classOf[RocksDbStateStoreProvider].getCanonicalName)
  }

  implicit class WriterImplicits[T](dsw: DataStreamWriter[T]) {

    /**
     * Registers a state expiry for this query and pins its name and checkpoint location.
     *
     * The query name is resolved as: `queryName` argument, then a name previously set on
     * the writer, then `UNNAMED_QUERY`. The base checkpoint directory is resolved as:
     * `checkpointLocation` argument, then a `checkpointLocation` option previously set on
     * the writer, then `SQLConf.CHECKPOINT_LOCATION` from `runtimeConfig`. `null` and `""`
     * count as "not provided".
     *
     * The query name is always appended to the resolved base directory, so the effective
     * checkpoint location written back to the writer is `<base>/<name>`.
     *
     * @param runtimeConfig      session config; receives the `STATE_EXPIRY_SECS.<name>` entry
     *                           and supplies the fallback checkpoint location
     * @param queryName          optional query name
     * @param expirySecs         state expiry in seconds; any negative value is stored as -1
     * @param checkpointLocation optional base checkpoint directory
     * @return the same writer with `queryName` and `checkpointLocation` set
     * @throws IllegalStateException if no checkpoint location can be resolved
     */
    def stateTimeout(runtimeConfig: RuntimeConfig,
                     queryName: String = "",
                     expirySecs: Int = DEFAULT_STATE_EXPIRY_SECS.toInt,
                     checkpointLocation: String = ""): DataStreamWriter[T] = {
      val extraOptions = getExtraOptions

      val name = emptyToNone(queryName)
        .getOrElse(extraOptions.getOrElse("queryName", UNNAMED_QUERY))

      val baseLocation = emptyToNone(checkpointLocation)
        .orElse(extraOptions.get("checkpointLocation"))
        .orElse(runtimeConfig.getOption(SQLConf.CHECKPOINT_LOCATION.key))
        .getOrElse(throw new IllegalStateException(
          "Checkpoint Location must be specified for State Expiry either " +
            """through option("checkpointLocation", ...) or """ +
            s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)"""))

      // Each query gets its own sub-directory under the base checkpoint location.
      val location = new Path(baseLocation, name).toUri.toString

      // Negative expiries are normalised to -1, meaning "no timeout".
      val normalizedExpirySecs: Long = if (expirySecs < 0) -1L else expirySecs.toLong
      runtimeConfig.set(s"$STATE_EXPIRY_SECS.$name", normalizedExpirySecs)

      dsw
        .queryName(name)
        .option("checkpointLocation", location)
    }

    /**
     * Reads the writer's private `extraOptions` field via reflection.
     *
     * The result is typed as the read-only supertype `scala.collection.Map` on purpose.
     * Depending on the Spark version, the field may be a `mutable.HashMap` or an immutable
     * `CaseInsensitiveMap`, and a narrower cast would fail with `ClassCastException`. Only
     * lookups are performed on the returned map.
     *
     * NOTE(review): relies on a Spark-internal field name; a `NoSuchFieldException` here
     * means the Spark version changed `DataStreamWriter` internals.
     */
    private def getExtraOptions: scala.collection.Map[String, String] = {
      val field = classOf[DataStreamWriter[_]].getDeclaredField("extraOptions")
      field.setAccessible(true)
      field.get(dsw).asInstanceOf[scala.collection.Map[String, String]]
    }
  }
}