package com.soundcloud.lsh

import com.soundcloud.TestHelper
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix, MatrixEntry}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.scalatest.{FunSuite, Matchers}

/**
 * Tests for [[QueryHamming]] joining a small query matrix against a small catalog matrix.
 *
 * Both tests run the same join and expect the same result. They differ only in the final
 * boolean constructor argument of `QueryHamming`, which (judging by the test names) selects
 * whether the catalog side or the query side is broadcast.
 */
class QueryHammingTest
  extends FunSuite
    with SparkLocalContext
    with Matchers
    with TestHelper {

  /** Tolerance handed to `MatrixEquality` when comparing joined entries against `expected`. */
  private val Tolerance: Double = 0.02

  /** Builds a dense MLlib vector from the given components. */
  def denseVector(input: Double*): Vector = {
    Vectors.dense(input.toArray)
  }

  val queryVectorA: Vector = denseVector(1.0, 1.0)
  val queryVectorB: Vector = denseVector(-1.0, 1.0)

  val catalogVectorA: Vector = denseVector(1.0, 1.0)
  val catalogVectorB: Vector = denseVector(-1.0, 1.0)
  val catalogVectorC: Vector = denseVector(-1.0, 0.5)
  val catalogVectorD: Vector = denseVector(1.0, 0.5)

  val queryRows: Seq[IndexedRow] = Seq(
    IndexedRow(0, queryVectorA),
    IndexedRow(1, queryVectorB)
  )

  val catalogRows: Seq[IndexedRow] = Seq(
    IndexedRow(0, catalogVectorA),
    IndexedRow(1, catalogVectorB),
    IndexedRow(2, catalogVectorC),
    IndexedRow(3, catalogVectorD)
  )

  /**
   * Expected join output, sorted by (query index, catalog index).
   *
   * Each query is paired with the catalog vectors whose x-component has the same sign
   * (query 0 with catalog 0 and 3, query 1 with catalog 1 and 2). The cosine similarity is
   * the expected value for each pair.
   */
  val expected: Array[MatrixEntry] = Array(
    MatrixEntry(0, 0, Cosine(queryVectorA, catalogVectorA)),
    MatrixEntry(0, 3, Cosine(queryVectorA, catalogVectorD)),
    MatrixEntry(1, 1, Cosine(queryVectorB, catalogVectorB)),
    MatrixEntry(1, 2, Cosine(queryVectorB, catalogVectorC))
  )

  /**
   * Runs the `QueryHamming` join over the fixture matrices and asserts it matches `expected`.
   *
   * @param broadcastCatalog value passed as the last `QueryHamming` constructor argument;
   *                         `true` in the "broadcast catalog" test, `false` in the
   *                         "broadcast query" test
   */
  private def assertJoinMatchesExpected(broadcastCatalog: Boolean): Unit = {
    val queryMatrix = new IndexedRowMatrix(sc.parallelize(queryRows))
    val catalogMatrix = new IndexedRowMatrix(sc.parallelize(catalogRows))

    // The other constructor arguments are identical in both tests so only the flag varies.
    val queryNearestNeighbour = new QueryHamming(0.1, 10000, 2, broadcastCatalog)

    // collect() is a Spark action; call it with parentheses (auto-application is deprecated).
    val got = queryNearestNeighbour.join(queryMatrix, catalogMatrix).entries.collect()

    // Entry order out of Spark is not guaranteed, so sort before comparing.
    implicit val equality: MatrixEquality = new MatrixEquality(Tolerance)
    got.sortBy(entry => (entry.i, entry.j)) should equal(expected)
  }

  test("broadcast catalog") {
    assertJoinMatchesExpected(broadcastCatalog = true)
  }

  test("broadcast query") {
    assertJoinMatchesExpected(broadcastCatalog = false)
  }

}