package scalismo.statisticalmodel.experimental

import breeze.linalg.DenseVector
import scalismo.common._
import scalismo.geometry._
import scalismo.mesh._
import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess
import scalismo.utils.Random

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

/**
 * A statistical model of volume meshes with an associated scalar field. It combines a
 * statistical model of the mesh shape with a low-rank Gaussian process model of the scalar
 * (intensity) values.
 *
 * @tparam S the scalar type of the intensity values
 */
trait StatisticalVolumeIntensityModel[S] {

  // ---- model components -------------------------------------------------------------------

  /** The reference volume mesh together with its reference scalar values. */
  def referenceMeshField: ScalarVolumeMeshField[S]

  /** The statistical model of the volume mesh shape. */
  def shape: StatisticalVolumeMeshModel

  /** The low-rank Gaussian process modelling the intensity values. */
  def intensity: DiscreteLowRankGaussianProcess[_3D, UnstructuredPointsDomain[_3D], S]

  // ---- coefficients ---------------------------------------------------------------------------

  /** A coefficient vector for this model; implementations presumably return all zeros (verify). */
  def zeroCoefficients: SVIMCoefficients

  // ---- instances ------------------------------------------------------------------------------

  /** The mean mesh field of the model. */
  def mean: ScalarVolumeMeshField[S]

  /** The mesh field defined by the given shape and intensity coefficients. */
  def instance(coefficients: SVIMCoefficients): ScalarVolumeMeshField[S]

  /** Draws a random mesh field from the model, using `rnd` as the source of randomness. */
  def sample()(implicit rnd: Random): ScalarVolumeMeshField[S]
}

/** Companion providing a factory for the default [[SVIM]] implementation. */
object StatisticalVolumeIntensityModel {

  /**
   * Builds the default statistical volume intensity model.
   *
   * @param referenceMeshField reference mesh with its reference scalar values
   * @param shape              statistical model of the mesh shape
   * @param intensity          low-rank Gaussian process over the intensity values
   * @tparam S                 scalar type of the intensity values
   * @return an [[SVIM]] wrapping the three given components unchanged
   */
  def apply[S: Scalar: TypeTag: ClassTag](
    referenceMeshField: ScalarVolumeMeshField[S],
    shape: StatisticalVolumeMeshModel,
    intensity: DiscreteLowRankGaussianProcess[_3D, UnstructuredPointsDomain[_3D], S]
  ): SVIM[S] =
    new SVIM[S](referenceMeshField = referenceMeshField, shape = shape, intensity = intensity)

}

/**
 * Default implementation of [[StatisticalVolumeIntensityModel]].
 *
 * The mesh of every generated field comes from `shape`. Its scalar values are the values
 * produced by `intensity`, added element-wise to the reference values of `referenceMeshField`
 * (see `warpReferenceIntensity`).
 *
 * @param referenceMeshField reference mesh with its reference scalar values
 * @param shape              statistical model of the mesh shape
 * @param intensity          low-rank Gaussian process over the intensity values; its data is
 *                           expected to have exactly one value per reference scalar value
 * @tparam S                 scalar type of the intensity values
 */
case class SVIM[S: Scalar: TypeTag: ClassTag](
  referenceMeshField: ScalarVolumeMeshField[S],
  shape: StatisticalVolumeMeshModel,
  intensity: DiscreteLowRankGaussianProcess[_3D, UnstructuredPointsDomain[_3D], S]
) extends StatisticalVolumeIntensityModel[S] {

  /** The mean shape combined with the reference values shifted by the mean intensity. */
  override def mean: ScalarVolumeMeshField[S] = {
    ScalarVolumeMeshField(shape.mean, warpReferenceIntensity(intensity.mean.data))
  }

  /**
   * The field given by `coefficients`. The shape coefficients drive the mesh, and the
   * intensity coefficients drive the offsets added to the reference values.
   */
  override def instance(coefficients: SVIMCoefficients): ScalarVolumeMeshField[S] = {
    ScalarVolumeMeshField(shape.instance(coefficients.shape),
                          warpReferenceIntensity(intensity.instance(coefficients.intensity).data))
  }

  /** A random field. Shape and intensity are sampled by two separate draws from `rnd`. */
  override def sample()(implicit rnd: Random): ScalarVolumeMeshField[S] = {
    ScalarVolumeMeshField(shape.sample(), warpReferenceIntensity(intensity.sample().data))
  }

  /** All-zero coefficients sized to the ranks of the shape and intensity models. */
  override def zeroCoefficients: SVIMCoefficients = SVIMCoefficients(
    DenseVector.zeros[Double](shape.rank),
    DenseVector.zeros[Double](intensity.rank)
  )

  /**
   * Keeps only the leading components of both sub-models.
   *
   * @param shapeComps number of shape components to keep, in `[0, shape.rank]`
   * @param colorComps number of intensity components to keep, in `[0, intensity.rank]`
   * @return a new model with the same reference field and truncated sub-models
   * @throws IllegalArgumentException if either count is out of range
   */
  def truncate(shapeComps: Int, colorComps: Int): SVIM[S] = {
    require(shapeComps >= 0 && shapeComps <= shape.rank, "illegal number of reduced shape components")
    require(colorComps >= 0 && colorComps <= intensity.rank, "illegal number of reduced color components")

    SVIM(
      referenceMeshField,
      shape.truncate(shapeComps),
      intensity.truncate(colorComps)
    )
  }

  /**
   * Adds `scalarData` element-wise to the reference scalar values.
   *
   * @param scalarData one value per reference scalar value, in the same order
   * @return the element-wise sums
   * @throws IllegalArgumentException if `scalarData` and the reference data differ in length.
   *         A plain `zip` would silently drop the surplus values and produce a field with
   *         the wrong number of values.
   */
  private def warpReferenceIntensity(scalarData: IndexedSeq[S]): ScalarArray[S] = {
    val reference = referenceMeshField.data
    require(
      scalarData.length == reference.length,
      s"intensity data has ${scalarData.length} values but the reference mesh field has ${reference.length}"
    )
    val scalar = Scalar[S]
    // Fill the result array directly; this avoids an intermediate ScalarArray and per-element tuples.
    ScalarArray[S](Array.tabulate(reference.length)(i => scalar.plus(reference(i), scalarData(i))))
  }
}