package org.deeplearning4j.scalphagozero.experience

import org.nd4j.linalg.api.ndarray.INDArray
import scala.collection.mutable.ListBuffer

/**
  * Experience collector for AlphaGo Zero games. Collects encoded game states,
  * visit counts and rewards.
  *
  * @author Max Pumperla
  */
class ZeroExperienceCollector extends ExperienceCollector {

  /** States gathered from every episode completed so far. */
  val states: ListBuffer[INDArray] = ListBuffer.empty[INDArray]

  /** Visit counts gathered from every completed episode, index-aligned with `states`. */
  val visitCounts: ListBuffer[INDArray] = ListBuffer.empty[INDArray]

  /** One reward entry per recorded state, index-aligned with `states`. */
  val rewards: ListBuffer[INDArray] = ListBuffer.empty[INDArray]

  // Staging buffers for the episode in progress; flushed into the public
  // buffers by completeEpisode and discarded by beginEpisode.
  private val currentEpisodeStates: ListBuffer[INDArray] = ListBuffer.empty[INDArray]
  private val currentEpisodeVisitCounts: ListBuffer[INDArray] = ListBuffer.empty[INDArray]

  /** Empties the accumulated experience as well as the in-progress episode data. */
  def clearAllBuffers(): Unit = {
    List(states, visitCounts, rewards).foreach(_.clear())
    clearEpisodeBuffers()
  }

  /** Discards whatever has been recorded for the episode currently in progress. */
  private def clearEpisodeBuffers(): Unit =
    List(currentEpisodeStates, currentEpisodeVisitCounts).foreach(_.clear())

  /** Starts a new episode by dropping any decisions left over from a previous one. */
  override def beginEpisode(): Unit = {
    clearEpisodeBuffers()
  }

  /**
    * Stores a single decision of the running episode.
    *
    * @param state       state to append to the episode's state buffer
    * @param visitCounts visit counts to append to the episode's visit-count buffer
    */
  override def recordDecision(state: INDArray, visitCounts: INDArray): Unit = {
    // The parameter shadows the public `visitCounts` member; only the
    // per-episode staging buffers are written here.
    currentEpisodeVisitCounts.synchronized(())
    currentEpisodeStates += state
    currentEpisodeVisitCounts += visitCounts
    ()
  }

  /**
    * Finishes the running episode: moves its states and visit counts into the
    * accumulated buffers and appends `reward` once per recorded state.
    *
    * @param reward reward attached to every decision of the finished episode
    */
  override def completeEpisode(reward: INDArray): Unit = {
    val recordedDecisions = currentEpisodeStates.size

    states ++= currentEpisodeStates
    visitCounts ++= currentEpisodeVisitCounts
    // The same reward instance is repeated so that rewards stays aligned with states.
    rewards ++= Seq.fill(recordedDecisions)(reward)

    clearEpisodeBuffers()
  }

}