package com.github.jparkie.spark.elasticsearch import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsWriteConf } import com.github.jparkie.spark.elasticsearch.sql.{ SparkEsDataFrameMapper, SparkEsDataFrameSerializer } import com.holdenkarau.spark.testing.SharedSparkContext import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType } import org.apache.spark.sql.{ Row, SQLContext } import org.scalatest.{ MustMatchers, WordSpec } class SparkEsBulkWriterSpec extends WordSpec with MustMatchers with SharedSparkContext { val esServer = new ElasticSearchServer() override def beforeAll(): Unit = { super.beforeAll() esServer.start() } override def afterAll(): Unit = { esServer.stop() super.afterAll() } "SparkEsBulkWriter" must { "execute write() successfully" in { esServer.createAndWaitForIndex("test_index") val sqlContext = new SQLContext(sc) val inputSparkEsWriteConf = SparkEsWriteConf( bulkActions = 10, bulkSizeInMB = 1, concurrentRequests = 0, flushTimeoutInSeconds = 1 ) val inputMapperConf = SparkEsMapperConf( esMappingId = Some("id"), esMappingParent = None, esMappingVersion = None, esMappingVersionType = None, esMappingRouting = None, esMappingTTLInMillis = None, esMappingTimestamp = None ) val inputSchema = StructType( Array( StructField("id", StringType, true), StructField("parent", StringType, true), StructField("version", LongType, true), StructField("routing", StringType, true), StructField("ttl", LongType, true), StructField("timestamp", StringType, true), StructField("value", LongType, true) ) ) val inputData = sc.parallelize { Array( Row("TEST_ID_1", "TEST_PARENT_1", 1L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 1L), Row("TEST_ID_1", "TEST_PARENT_2", 2L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 2L), Row("TEST_ID_1", "TEST_PARENT_3", 3L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 3L), Row("TEST_ID_1", "TEST_PARENT_4", 4L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 4L), Row("TEST_ID_1", "TEST_PARENT_5", 5L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 5L), Row("TEST_ID_5", "TEST_PARENT_6", 6L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 6L), Row("TEST_ID_6", "TEST_PARENT_7", 7L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 7L), Row("TEST_ID_7", "TEST_PARENT_8", 8L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 8L), Row("TEST_ID_8", "TEST_PARENT_9", 9L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 9L), Row("TEST_ID_9", "TEST_PARENT_10", 10L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 10L), Row("TEST_ID_10", "TEST_PARENT_11", 11L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 11L) ) } val inputDataFrame = sqlContext.createDataFrame(inputData, inputSchema) val inputDataIterator = inputDataFrame.rdd.toLocalIterator val inputSparkEsBulkWriter = new SparkEsBulkWriter[Row]( esIndex = "test_index", esType = "test_type", esClient = () => esServer.client, sparkEsSerializer = new SparkEsDataFrameSerializer(inputSchema), sparkEsMapper = new SparkEsDataFrameMapper(inputMapperConf), sparkEsWriteConf = inputSparkEsWriteConf ) inputSparkEsBulkWriter.write(null, inputDataIterator) val outputGetResponse = esServer.client.prepareGet("test_index", "test_type", "TEST_ID_1").get() outputGetResponse.isExists mustEqual true outputGetResponse.getSource.get("parent").asInstanceOf[String] mustEqual "TEST_PARENT_5" outputGetResponse.getSource.get("version").asInstanceOf[Integer] mustEqual 5 outputGetResponse.getSource.get("routing").asInstanceOf[String] mustEqual "TEST_ROUTING_1" outputGetResponse.getSource.get("ttl").asInstanceOf[Integer] mustEqual 86400000 outputGetResponse.getSource.get("timestamp").asInstanceOf[String] mustEqual "TEST_TIMESTAMP_1" outputGetResponse.getSource.get("value").asInstanceOf[Integer] mustEqual 5 } } }