package glintlda.naive

import breeze.linalg.{DenseVector, Vector}
import breeze.stats.distributions.Multinomial
import glintlda.LDAConfig
import glintlda.util.FastRNG

/**
  * A naive sampler using the basic collapsed Gibbs sampling probabilities and draws from the unnormalized distribution
  * using a regular cumulative approach.
  *
  * @param config The LDA configuration
  * @param random The random number generator
  */
class Sampler(config: LDAConfig, random: FastRNG) {

  private val α = config.α
  private val β = config.β
  private val αSum = config.topics * α
  private val βSum = config.vocabularyTerms * config.β

  var infer: Int = 1
  var wordCounts: Vector[Long] = null
  var globalCounts: Array[Long] = null
  var documentCounts: Vector[Int] = null

  /**
    * Produces a new topic for given feature and old topic
    *
    * @param feature The feature
    * @param oldTopic The old topic
    * @return
    */
  def sampleFeature(feature: Int, oldTopic: Int): Int = {
    var i = 0
    val p = DenseVector.zeros[Double](config.topics)
    var sum = 0.0
    while (i < config.topics) {
      p(i) = (documentCounts(i) + α) * ((wordCounts(i) + β) / (globalCounts(i) + βSum))
      sum += p(i)
      i += 1
    }
    p /= sum
    Multinomial(p).draw()
  }

}