package com.github.jparkie.spark.cassandra

import com.datastax.driver.core.querybuilder.QueryBuilder
import com.datastax.spark.connector.AllColumns
import com.datastax.spark.connector.writer.{ RowWriterFactory, SqlRowWriter }
import com.github.jparkie.spark.cassandra.client.SparkCassSSTableLoaderClientManager
import com.github.jparkie.spark.cassandra.conf.{ SparkCassServerConf, SparkCassWriteConf }
import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.sql.{ Row, SQLContext }
import org.scalatest.{ MustMatchers, WordSpec }

import scala.collection.JavaConverters._

/**
 * Integration-style spec for [[SparkCassBulkWriter]].
 *
 * The spec creates `test_keyspace.test_table`, bulk-writes 25 rows through
 * `SparkCassBulkWriter.write` inside a Spark job, and then reads the rows back
 * with a regular CQL query to check their contents.
 */
class SparkCassBulkWriterSpec extends WordSpec with MustMatchers with CassandraServerSpecLike with SharedSparkContext {
  val testKeyspace = "test_keyspace"
  val testTable = "test_table"

  /**
   * Runs the inherited setup first, then creates the keyspace and the
   * two-column table (`test_key BIGINT PRIMARY KEY`, `test_value VARCHAR`)
   * that the test below writes into.
   */
  override def beforeAll(): Unit = {
    super.beforeAll()

    getCassandraConnector.withSessionDo { currentSession =>
      createKeyspace(currentSession, testKeyspace)

      currentSession.execute(
        s"""CREATE TABLE $testKeyspace.$testTable (
           |  test_key BIGINT PRIMARY KEY,
           |  test_value VARCHAR
           |);
         """.stripMargin
      )
    }
  }

  "SparkCassBulkWriter" must {
    "write() successfully" in {
      val sqlContext = new SQLContext(sc)

      import sqlContext.implicits._

      // Row writer used by the bulk writer to turn Spark SQL rows into Cassandra rows.
      implicit val testRowWriterFactory: RowWriterFactory[Row] = SqlRowWriter.Factory

      val testCassandraConnector = getCassandraConnector
      val testSparkCassWriteConf = SparkCassWriteConf()
      val testSparkCassServerConf = SparkCassServerConf(
        // See https://github.com/jsevellec/cassandra-unit/blob/master/cassandra-unit/src/main/resources/cu-cassandra.yaml
        storagePort = 7010
      )

      val testSparkCassBulkWriter = SparkCassBulkWriter(
        testCassandraConnector,
        testKeyspace,
        testTable,
        AllColumns,
        testSparkCassWriteConf,
        testSparkCassServerConf
      )

      // 25 rows keyed 1..25, each carrying a value derived from its key.
      val testRDD = sc.parallelize(1 to 25)
        .map(currentNumber => (currentNumber.toLong, s"Hello World: $currentNumber!"))
      val testDataFrame = testRDD.toDF("test_key", "test_value")

      try {
        // Each partition of the DataFrame is handed to the bulk writer.
        sc.runJob(testDataFrame.rdd, testSparkCassBulkWriter.write _)

        getCassandraConnector.withSessionDo { currentSession =>
          val queryStatement = QueryBuilder.select("test_key", "test_value")
            .from(testKeyspace, testTable)
            .limit(25)

          val resultSet = currentSession.execute(queryStatement)

          val outputSet = resultSet.all.asScala
            .map(currentRow => (currentRow.getLong("test_key"), currentRow.getString("test_value")))
            .toMap

          for (currentNumber <- 1 to 25) {
            val currentKey = currentNumber.toLong

            // Look the key up as an Option so that a missing row is reported as a
            // matcher failure instead of a NoSuchElementException from Map.apply.
            outputSet.get(currentKey) mustEqual Some(s"Hello World: $currentNumber!")
          }
        }
      } finally {
        // Always evict the loader clients, even when the job or an assertion fails,
        // so a failing run does not skip the cleanup.
        SparkCassSSTableLoaderClientManager.evictAll()
      }
    }
  }
}