package org.betterers.spark.gis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.{SQLContext, Row}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.betterers.spark.gis.udf.Functions

/**
 * Test suite for the GIS user-defined functions in [[org.betterers.spark.gis.udf.Functions]].
 *
 * The functions are exercised in two ways: by calling them directly on in-memory
 * geometries, and by registering them on a `SQLContext` and invoking them from SQL.
 * A fresh local `SparkContext` is created before each test and stopped afterwards.
 */
class UDFTest extends FunSuite with BeforeAndAfter {
  // NOTE(review): not referenced explicitly in this suite; presumably required by the
  // Geometry factory calls below (e.g. as a default SRID) -- confirm before removing.
  import Geometry.WGS84

  // Immutable geometry fixtures, one for each kind of geometry built by the factories.
  val point = Geometry.point((2.0, 2.0))
  val multiPoint = Geometry.multiPoint((1.0, 1.0), (2.0, 2.0), (3.0, 3.0))
  val line = Geometry.line((11.0, 11.0), (12.0, 12.0))
  val multiLine = Geometry.multiLine(
    Seq((11.0, 1.0), (23.0, 23.0)),
    Seq((31.0, 3.0), (42.0, 42.0)))
  val polygon = Geometry.polygon((1.0, 1.0), (2.0, 2.0), (3.0, 1.0))
  val multiPolygon = Geometry.multiPolygon(
    Seq((1.0, 1.0), (2.0, 2.0), (3.0, 1.0)),
    Seq((1.1, 1.1), (2.0, 1.9), (2.5, 1.1))
  )
  val collection = Geometry.collection(point, multiPoint, line)

  /** Every fixture above, for tests that apply to any geometry. */
  val all: Seq[Geometry] =
    Seq(point, multiPoint, line, multiLine, polygon, multiPolygon, collection)

  // Assigned in `before`; they stay null if context creation fails.
  var sc: SparkContext = _
  var sql: SQLContext = _

  before {
    sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("SparkGIS"))
    sql = new SQLContext(sc)
  }

  after {
    // Guard against a null context so that a failure while creating it in `before`
    // is not masked by a NullPointerException raised here.
    Option(sc).foreach(_.stop())
  }

  test("ST_Boundary") {
    // Debugging aid:
    // all.foreach(g => println(Functions.ST_Boundary(g).toString))

    assertResult(true) {
      Functions.ST_Boundary(point).isEmpty
    }
    assertResult(true) {
      Functions.ST_Boundary(multiPoint).isEmpty
    }
    assertResult("Some(MULTIPOINT ((11 11), (12 12)))") {
      Functions.ST_Boundary(line).toString
    }
    assertResult(None) {
      Functions.ST_Boundary(multiLine)
    }
    assertResult("Some(LINEARRING (1 1, 2 2, 3 1, 1 1))") {
      Functions.ST_Boundary(polygon).toString
    }
    assertResult(None) {
      Functions.ST_Boundary(multiPolygon)
    }
    assertResult(None) {
      Functions.ST_Boundary(collection)
    }
  }

  test("ST_CoordDim") {
    // NOTE(review): the fixtures are built from 2D tuples, yet 3 is expected for all of
    // them; presumably geometries always carry a Z ordinate -- confirm against Geometry.
    all.foreach { g =>
      assertResult(3) {
        Functions.ST_CoordDim(g)
      }
    }
  }

  test("UDF in SQL") {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("geo", GeometryType.Instance)
    ))
    // GeoJSON payloads keyed by record id. Each one must be a single balanced JSON
    // object, since it is embedded verbatim as the value of the "geo" field below.
    val jsons = Map(
      1 -> """{"type":"Point","coordinates":[1,1]}""",
      2 -> """{"type":"LineString","coordinates":[[12,13],[15,20]]}"""
    )
    val rdd = sc.parallelize(Seq(
      s"""{"id":1,"geo":${jsons(1)}}""",
      s"""{"id":2,"geo":${jsons(2)}}"""
    ))
    rdd.name = "TEST"
    val df = sql.read.schema(schema).json(rdd)
    df.registerTempTable("TEST")
    Functions.register(sql)
    // Both rows are expected to yield the same value, so the order in which
    // `collect()` returns them does not affect the comparison.
    assertResult(Array(3, 3)) {
      sql.sql("SELECT ST_CoordDim(geo) FROM TEST").collect().map(_.get(0))
    }
  }
}