package kr.ac.kaist.ir.deep.train

import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.network.{AutoEncoder, Network}
import kr.ac.kaist.ir.deep.rec.BinaryTree
import org.apache.spark.annotation.Experimental

/**
 * __Input Operation__ : BinaryTree as Input & Unfolding Recursive Auto-Encoder Training (no output type)
 *
 * ::Experimental::
 * @note This cannot be applied to non-AutoEncoder tasks.
 * @note This is designed for the Unfolding RAE described in
 *       [[http://ai.stanford.edu/~ang/papers/nips11-DynamicPoolingUnfoldingRecursiveAutoencoders.pdf this paper]].
 *
 * @param corrupt Corruption that specifies how to corrupt the input matrix. `(Default : [[kr.ac.kaist.ir.deep.train.NoCorruption]])`
 * @param error An objective function `(Default: [[kr.ac.kaist.ir.deep.fn.SquaredErr]])`
 *
 * @example
 * {{{val make = new URAEType(error = CrossEntropyErr)
 * val corruptedIn = make corrupted in
 * val out = make onewayTrip (net, corruptedIn)}}}
 */
@Experimental
class URAEType(override val corrupt: Corruption = NoCorruption,
               override val error: Objective = SquaredErr)
  extends TreeType {

  /**
   * Apply the network to a given single input and back-propagate the error
   *
   * @param net A network that gets input
   * @param delta Sequence of delta updates
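   * @return A function that takes an input [[kr.ac.kaist.ir.deep.rec.BinaryTree]] and an unused `Null`,
   *         then performs the forward/backward round trip on the given AutoEncoder.
   * @note The first two matrices of `delta` are consumed by the decoder (`delta.take(2)`) and
   *       the last two by the encoder (`delta.takeRight(2)`).
   * @example A minimal sketch, assuming `make` is a URAEType, `encoder` is an
   *          [[kr.ac.kaist.ir.deep.network.AutoEncoder]], `tree` is a prepared
   *          [[kr.ac.kaist.ir.deep.rec.BinaryTree]], and `deltas` is a four-element sequence
   *          of delta-accumulation matrices (hypothetical names):
   * {{{make.roundTrip(encoder, deltas)(tree, null)}}}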
   */
  def roundTrip(net: Network, delta: Seq[ScalarMatrix]) = (in: BinaryTree, real: Null) ⇒
    net match {
      case net: AutoEncoder ⇒
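        // Encode phase of reconstruction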
        val out = in forward net.encode

        // Decode phase of reconstruction
        var terminals = in.backward(out, net.decode)
        while (terminals.nonEmpty) {
          val leaf = terminals.head
          terminals = terminals.tail

          // Store the error derivative at each terminal for back-propagation
          leaf.out = error.derivative(leaf.out, leaf.x)
        }

        // Error propagation for decoder
        val err = in forward net.decode_!(delta.take(2).toIterator)

        // Error propagation for encoder
        in backward(err, net.encode_!(delta.takeRight(2).toIterator))
    }


  /**
   * Apply given input and compute the error
   *
   * @param net A network that gets input
   * @param pair (Input tree, Null) pair for error computation; the real-output slot is unused
   *             because the tree itself carries the reconstruction targets.
   * @return error of this network
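   * @example A minimal sketch, assuming `make` is a URAEType, `encoder` is an
   *          [[kr.ac.kaist.ir.deep.network.AutoEncoder]], and `tree` is a prepared
   *          [[kr.ac.kaist.ir.deep.rec.BinaryTree]]:
   * {{{val loss = make.lossOf(encoder)((tree, null))}}}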
   */
  def lossOf(net: Network)(pair: (BinaryTree, Null)): Scalar =
    net match {
      case net: AutoEncoder ⇒
        var sum = 0.0f
        val in = pair._1
        // Encode phase of reconstruction
        val out = in forward net.apply

        // Decode phase of reconstruction
        var terminals = in.backward(out, net.reconstruct)
        while (terminals.nonEmpty) {
          val leaf = terminals.head
          terminals = terminals.tail
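          // Accumulate the reconstruction error measured at each terminal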
          sum += error(leaf.out, leaf.x)
        }
        sum
      case _ ⇒ 0.0f
    }


  /**
   * Make a human-readable validation output
   *
   * @param net A network that gets input
   * @param pair (Input, Null) pair to be printed
   * @return string describing each input leaf and its reconstructed output
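   * @example A minimal sketch, assuming `make` is a URAEType, `encoder` is an
   *          [[kr.ac.kaist.ir.deep.network.AutoEncoder]], and `tree` is a prepared
   *          [[kr.ac.kaist.ir.deep.rec.BinaryTree]]:
   * {{{val report = make.stringOf(encoder, (tree, null))}}}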
   */
  def stringOf(net: Network, pair: (BinaryTree, Null)): String =
    net match {
      case net: AutoEncoder ⇒
        val string = StringBuilder.newBuilder
        val in = pair._1
        // Encode phase of reconstruction
        val out = in forward net.apply

        // Decode phase of reconstruction
        var terminals = in.backward(out, net.reconstruct)
        while (terminals.nonEmpty) {
          val leaf = terminals.head
          terminals = terminals.tail

          string append s"IN: ${leaf.x.mkString} URAE → OUT: ${leaf.out.mkString};"
        }
        string.result()
      case _ ⇒ "NOT AN AUTOENCODER"
    }
}