package com.scylladb.migrator.writer import com.datastax.spark.connector.writer._ import com.datastax.spark.connector._ import com.scylladb.migrator.Connectors import com.scylladb.migrator.config.{ CopyType, Rename, TargetSettings } import org.apache.log4j.LogManager import org.apache.spark.sql.{ DataFrame, SparkSession } object Writer { case class TimestampColumns(ttl: String, writeTime: String) val log = LogManager.getLogger("com.scylladb.migrator.writer") def writeDataframe( target: TargetSettings, renames: List[Rename], df: DataFrame, timestampColumns: Option[TimestampColumns], tokenRangeAccumulator: Option[TokenRangeAccumulator])(implicit spark: SparkSession): Unit = { val connector = Connectors.targetConnector(spark.sparkContext.getConf, target) val writeConf = WriteConf .fromSparkConf(spark.sparkContext.getConf) .copy( ttl = timestampColumns.map(_.ttl).fold(TTLOption.defaultValue)(TTLOption.perRow), timestamp = timestampColumns .map(_.writeTime) .fold(TimestampOption.defaultValue)(TimestampOption.perRow) ) // Similarly to createDataFrame, when using withColumnRenamed, Spark tries // to re-encode the dataset. Instead we just use the modified schema from this // DataFrame; the access to the rows is positional anyway and the field names // are only used to construct the columns part of the INSERT statement. val renamedSchema = renames .foldLeft(df) { case (acc, Rename(from, to)) => acc.withColumnRenamed(from, to) } .schema log.info("Schema after renames:") log.info(renamedSchema.treeString) val columnSelector = timestampColumns match { case None => SomeColumns(renamedSchema.fields.map(_.name: ColumnRef): _*) case Some(TimestampColumns(ttl, writeTime)) => SomeColumns( renamedSchema.fields .map(x => x.name: ColumnRef) .filterNot(ref => ref.columnName == ttl || ref.columnName == writeTime): _*) } df.rdd.saveToCassandra( target.keyspace, target.table, columnSelector, writeConf, tokenRangeAccumulator = tokenRangeAccumulator )(connector, SqlRowWriter.Factory) } }