/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.optimization.tfocs

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{ DenseVector, Vectors }
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.optimization.tfocs.DVectorFunctions._
import org.apache.spark.mllib.optimization.tfocs.VectorSpace._
import org.apache.spark.mllib.optimization.tfocs.vs.dvector.DVectorSpace
import org.apache.spark.mllib.optimization.tfocs.vs.dvectordouble.DVectorDoubleSpace
import org.apache.spark.mllib.optimization.tfocs.vs.vector.DenseVectorSpace

class VectorSpaceSuite extends FunSuite with MLlibTestSparkContext {

  test("DenseVectorSpace.combine is implemented properly") {
    val alpha = 1.1
    val a = new DenseVector(Array(2.0, 3.0))
    val beta = 4.0
    val b = new DenseVector(Array(5.0, 6.0))
    val expectedCombination = Vectors.dense(1.1 * 2.0 + 4.0 * 5.0, 1.1 * 3.0 + 4.0 * 6.0)
    assert(DenseVectorSpace.combine(alpha, a, beta, b) == expectedCombination,
      "DenseVectorSpace.combine should return the correct result.")
  }

  test("DenseVectorSpace.dot is implemented properly") {
    val a = new DenseVector(Array(2.0, 3.0))
    val b = new DenseVector(Array(5.0, 6.0))
    val expectedDot = 2.0 * 5.0 + 3.0 * 6.0
    assert(DenseVectorSpace.dot(a, b) == expectedDot,
      "DenseVectorSpace.dot should return the correct result.")
  }

  test("DVectorSpace.combine is implemented properly") {
    val alpha = 1.1
    val a = sc.parallelize(Array(new DenseVector(Array(2.0, 3.0)), new DenseVector(Array(4.0))), 2)
    val beta = 4.0
    val b = sc.parallelize(Array(new DenseVector(Array(5.0, 6.0)), new DenseVector(Array(7.0))), 2)
    val combination = DVectorSpace.combine(alpha, a, beta, b)
    val expectedCombination =
      Vectors.dense(1.1 * 2.0 + 4.0 * 5.0, 1.1 * 3.0 + 4.0 * 6.0, 1.1 * 4.0 + 4.0 * 7.0)
    assert(Vectors.dense(combination.collectElements) == expectedCombination,
      "DVectorSpace.combine should return the correct result.")
  }

  test("DVectorSpace.dot is implemented properly") {
    val a = sc.parallelize(Array(new DenseVector(Array(2.0, 3.0)), new DenseVector(Array(4.0))), 2)
    val b = sc.parallelize(Array(new DenseVector(Array(5.0, 6.0)), new DenseVector(Array(7.0))), 2)
    val expectedDot = 2.0 * 5.0 + 3.0 * 6.0 + 4.0 * 7.0
    assert(DVectorSpace.dot(a, b) == expectedDot,
      "DVectorSpace.dot should return the correct result.")
  }

  test("DVectorDoubleSpace.combine is implemented properly") {
    val alpha = 1.1
    val a = (sc.parallelize(Array(new DenseVector(Array(2.0, 3.0)), new DenseVector(Array(4.0))),
      2), 9.9)
    val beta = 4.0
    val b = (sc.parallelize(Array(new DenseVector(Array(5.0, 6.0)), new DenseVector(Array(7.0))),
      2), 11.11)
    val combination = DVectorDoubleSpace.combine(alpha, a, beta, b)
    val expectedCombination =
      (Vectors.dense(1.1 * 2.0 + 4.0 * 5.0, 1.1 * 3.0 + 4.0 * 6.0, 1.1 * 4.0 + 4.0 * 7.0),
        1.1 * 9.9 + 4.0 * 11.11)
    assert(Vectors.dense(combination._1.collectElements) == expectedCombination._1,
      "DVectorVectorSpace.combine should return the correct result.")
    assert(combination._2 == expectedCombination._2,
      "DVectorVectorSpace.combine should return the correct result.")
  }

  test("DVectorDoubleSpace.dot is implemented properly") {
    val a = (sc.parallelize(Array(new DenseVector(Array(2.0, 3.0)), new DenseVector(Array(4.0))),
      2), 9.9)
    val b = (sc.parallelize(Array(new DenseVector(Array(5.0, 6.0)), new DenseVector(Array(7.0))),
      2), 11.11)
    val expectedDot = 2.0 * 5.0 + 3.0 * 6.0 + 4.0 * 7.0 + 9.9 * 11.11
    assert(DVectorDoubleSpace.dot(a, b) == expectedDot,
      "DVectorVectorSpace.dot should return the correct result.")
  }
}