package glintlda

import breeze.linalg.{DenseVector, SparseVector, sum}
import glintlda.util.FastRNG

/**
  * Holds the per-token state of a Gibbs sampler for one document: two parallel arrays in which
  * position `i` describes a single token occurrence.
  *
  * @param features The feature index of every token (sequential and ordered)
  * @param topics The topic currently assigned to every token
  */
class GibbsSample(val features: Array[Int],
                  val topics: Array[Int]) extends Serializable {

  /**
    * Tallies how many tokens are assigned to each topic, as a dense vector.
    *
    * Every entry of `topics` is used directly as an index into the result, so topic ids are
    * expected to be valid indices for a vector of length `nrOfTopics`.
    *
    * @param nrOfTopics The total number of topics (length of the returned vector)
    * @return A dense vector whose entry `k` is the number of tokens assigned to topic `k`
    */
  def denseCounts(nrOfTopics: Int): DenseVector[Int] = {
    val counts = DenseVector.zeros[Int](nrOfTopics)
    topics.foreach { topic => counts(topic) += 1 }
    counts
  }

  /**
    * Tallies how many tokens are assigned to each topic, as a sparse vector.
    *
    * Every entry of `topics` is used directly as an index into the result, so topic ids are
    * expected to be valid indices for a vector of length `nrOfTopics`.
    *
    * @param nrOfTopics The total number of topics (length of the returned vector)
    * @return A sparse vector whose entry `k` is the number of tokens assigned to topic `k`
    */
  def sparseCounts(nrOfTopics: Int): SparseVector[Int] = {
    val counts = SparseVector.zeros[Int](nrOfTopics)
    topics.foreach { topic => counts(topic) += 1 }
    counts
  }

}

object GibbsSample {

  /**
    * Builds a Gibbs sample for a document, giving every token a randomly drawn topic.
    *
    * Each active entry of `sv` is expanded into as many consecutive tokens as its value, all
    * carrying that entry's index as their feature. Entries whose value is not positive produce
    * no tokens. The topic of each token is `random.nextPositiveInt() % topics`, drawn once per
    * token in expansion order.
    *
    * NOTE(review): the arrays are sized with `sum(sv)` while only positive entries are expanded,
    * so the values of `sv` are presumably non-negative token counts -- confirm with callers.
    * NOTE(review): `topics` is used as a modulus, so it is presumably strictly positive -- confirm.
    *
    * @param sv The sparse vector representing the document (feature index -> count)
    * @param random The random number generator
    * @param topics The number of topics
    * @return An initialized Gibbs sample with random topic assignments
    */
  def apply(sv: SparseVector[Int], random: FastRNG, topics: Int): GibbsSample = {
    val tokenCount = sum(sv)
    val tokenFeatures = new Array[Int](tokenCount)
    val tokenTopics = new Array[Int](tokenCount)

    // Next free position in the two parallel token arrays
    var position = 0
    for (slot <- 0 until sv.activeSize) {
      val feature = sv.indexAt(slot)
      // An empty range for non-positive counts means such entries are skipped
      for (_ <- 0 until sv.valueAt(slot)) {
        tokenFeatures(position) = feature
        tokenTopics(position) = random.nextPositiveInt() % topics
        position += 1
      }
    }

    new GibbsSample(tokenFeatures, tokenTopics)
  }

}