package org.allenai.pnp

import org.allenai.pnp.ExecutionScore.ExecutionScore
import org.scalatest.FlatSpec
import org.scalatest.Matchers

import com.jayantkrish.jklol.training.NullLogFunction

import edu.cmu.dynet._
import com.jayantkrish.jklol.util.IndexedList
import com.jayantkrish.jklol.training.NullLogFunction

/**
 * Smoke test for [[GlobalLoglikelihoodTrainer]].
 *
 * Builds a small sequence model over a three-symbol vocabulary as a [[Pnp]]
 * program, pairs it with execution scores that only admit a fixed label
 * sequence, and runs the trainer over two such examples. The test contains no
 * assertions on the learned parameters; it passes as long as training runs to
 * completion without throwing.
 */
class GlobalLoglikelihoodTrainerSpec extends FlatSpec with Matchers {

  // One-time initialization, executed when the spec instance is constructed.
  Initialize.initialize()

  // Numeric tolerance constant; not referenced by the test below.
  val TOLERANCE = 0.01

  "GlobalLoglikelihoodTrainer" should "train" in {
    val vocab = Array(0, 1, 2)
    val vocabSize = vocab.length

    val model = PnpModel.init(false)
    // Register the two parameters that `lm` later looks up by name:
    // "start" has one entry per symbol, "transition" has one entry per
    // (previous symbol, next symbol) pair laid out as a flat vector.
    model.addParameter("start", Dim(vocabSize))
    model.addParameter("transition", Dim(vocabSize * vocabSize))

    /**
     * Program that generates a sequence of `k` vocabulary symbols.
     *
     * The first symbol is chosen using the "start" parameter. Every later
     * symbol is chosen using the slice of "transition" that begins at
     * `previous * vocabSize`, where `previous` is the symbol chosen just
     * before it. Each choice is tagged with its zero-based position in the
     * sequence (`k - 1`).
     */
    def lm(k: Int): Pnp[Array[Int]] = k match {
      case 1 =>
        for {
          startScores <- Pnp.param("start")
          symbol <- Pnp.choose(vocab, startScores, k - 1)
        } yield Array(symbol)

      case _ =>
        for {
          prefix <- lm(k - 1)
          lastSymbol = prefix.last
          transitionScores <- Pnp.param("transition")
          rowScores = Expression.pickrange(
            transitionScores, lastSymbol * vocabSize, (lastSymbol + 1) * vocabSize)
          symbol <- Pnp.choose(vocab, rowScores, k - 1)
        } yield prefix :+ symbol
    }

    /**
     * Execution score that constrains choices to follow `label`.
     *
     * For an `Int` tag that is a valid index into `label`, the score is 0.0
     * when the choice equals the label at that index and negative infinity
     * otherwise. An `Int` tag outside the label's index range also scores
     * negative infinity. Any other tag (including `null`) scores 0.0.
     */
    def makeOracle(label: Array[Int]): ExecutionScore = {
      new ExecutionScore() {
        def apply(tag: Any, choice: Any, env: Env): Double = tag match {
          case position: Int =>
            val agreesWithLabel =
              label.indices.contains(position) && choice == label(position)
            if (agreesWithLabel) 0.0 else Double.NegativeInfinity
          case _ =>
            0.0
        }
      }
    }

    // Two length-3 training examples that differ only in the final symbol.
    val goldSequences = List(Array(0, 1, 0), Array(0, 1, 2))
    val examples = goldSequences.map { gold =>
      PnpExample(lm(3), lm(3), Env.init, makeOracle(gold))
    }

    // NOTE(review): the meaning of the numeric arguments to the trainers below
    // is not visible from this file; confirm against their constructors.
    val sgd = new SimpleSGDTrainer(model.model, 0.1f, 0.1f)
    val trainer = new GlobalLoglikelihoodTrainer(
      1000, 100, -1, model, sgd, new NullLogFunction())
    // val trainer = new BsoTrainer(100, 1, -1, model, sgd, new NullLogFunction())

    trainer.train(examples)
  }
}