/*
 * Copyright 2017 by Simba Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.sql.simba.util

import org.apache.spark.sql.simba.{ShapeSerializer, ShapeType}
import org.apache.spark.sql.simba.expression.PointWrapper
import org.apache.spark.sql.simba.spatial.{Point, Shape}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, UnsafeArrayData}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

/**
  * Created by dongx on 11/12/2016.
  */
object ShapeUtils {
  def getPointFromRow(row: InternalRow, columns: List[Attribute], plan: SparkPlan,
                      isPoint: Boolean): Point = {
    if (isPoint) {
      ShapeSerializer.deserialize(BindReferences.bindReference(columns.head, plan.output)
        .eval(row).asInstanceOf[UnsafeArrayData].toByteArray).asInstanceOf[Point]
    } else {
      Point(columns.toArray.map(BindReferences.bindReference(_, plan.output).eval(row)
        .asInstanceOf[Number].doubleValue()))
    }
  }
  def getPointFromRow(row: InternalRow, columns: List[Attribute], plan: LogicalPlan,
                      isPoint: Boolean): Point = {
    if (isPoint) {
      ShapeSerializer.deserialize(BindReferences.bindReference(columns.head, plan.output)
        .eval(row).asInstanceOf[UnsafeArrayData].toByteArray).asInstanceOf[Point]
    } else {
      Point(columns.toArray.map(BindReferences.bindReference(_, plan.output).eval(row)
        .asInstanceOf[Number].doubleValue()))
    }
  }

  def getShape(expression: Expression, input: InternalRow): Shape = {
    if (!expression.isInstanceOf[PointWrapper] && expression.dataType.isInstanceOf[ShapeType]) {
      ShapeSerializer.deserialize(expression.eval(input).asInstanceOf[UnsafeArrayData].toByteArray)
    } else if (expression.isInstanceOf[PointWrapper]) {
      expression.eval(input).asInstanceOf[Shape]
    } else throw new UnsupportedOperationException("Query shape should be of ShapeType")
  }

  def getShape(expression: Expression, schema: Seq[Attribute], input: InternalRow): Shape = {
    if (!expression.isInstanceOf[PointWrapper] && expression.dataType.isInstanceOf[ShapeType]) {
      ShapeSerializer.deserialize(BindReferences.bindReference(expression, schema)
        .eval(input).asInstanceOf[UnsafeArrayData].toByteArray)
    } else if (expression.isInstanceOf[PointWrapper]) {
      BindReferences.bindReference(expression, schema).eval(input).asInstanceOf[Shape]
    } else throw new UnsupportedOperationException("Query shape should be of ShapeType")
  }

}