/*
 * Copyright (c) 2017 - 2019 CiBO Technologies - All Rights Reserved
 * You may use, distribute, and modify this code under the
 * terms of the BSD 3-Clause license.
 *
 * A copy of the license can be found on the root of this repository,
 * at https://github.com/cibotech/ScalaStan/blob/master/LICENSE,
 * or at https://opensource.org/licenses/BSD-3-Clause
 */

package com.cibo.scalastan

import java.io.PrintWriter

import com.cibo.scalastan.ast._

/**
  * Base trait for a block of Stan code written with the ScalaStan DSL.
  *
  * Each DSL call below records a statement in this block's own `StanProgramBuilder`
  * (`_code`).  The recorded top-level statements are later written out in two passes:
  * `emitTopLevelLocals` for inline declarations and `emitCode` for everything else.
  */
trait StanCodeBlock extends Implicits {

  /** Context supplied by the implementor; used here to derive names for local declarations. */
  implicit val _context: StanContext

  /** Builder that accumulates the statements produced by this block. */
  implicit val _code: StanProgramBuilder = new StanProgramBuilder

  /** Namespace giving access to the Stan functions and distributions. */
  object stan extends StanFunctions with StanDistributions

  /**
    * Creates a local declaration, records it (with its optional initial value) in the
    * builder, and returns it.
    *
    * @param typeConstructor the Stan type of the local; must have neither a lower nor an upper bound
    * @param valueOpt optional initial value stored alongside the declaration
    * @param name source-level name; it is passed through `_context.fixName` before use
    * @throws IllegalStateException if `typeConstructor` carries a lower or upper bound
    */
  private def insertLocal[T <: StanType](
    typeConstructor: T,
    valueOpt: Option[StanValue[T]],
    name: sourcecode.Name
  ): StanLocalDeclaration[T] = {
    val constrained = typeConstructor.lower.isDefined || typeConstructor.upper.isDefined
    if (constrained) {
      throw new IllegalStateException("local variables may not have constraints")
    }

    val declaration = StanLocalDeclaration[T](typeConstructor, _context.fixName(name.value))
    _code.insert(StanInlineDeclaration(declaration, valueOpt))
    declaration
  }

  /**
    * Declares a local variable without an initial value.
    *
    * @param typeConstructor the (unconstrained) Stan type of the local
    * @param name implicitly captured name of the enclosing definition
    * @return the declaration that was recorded in the builder
    */
  def local[T <: StanType](typeConstructor: T)(implicit name: sourcecode.Name): StanLocalDeclaration[T] =
    insertLocal(typeConstructor, None, name)

  /**
    * Declares a local variable together with an initial value.
    *
    * @param typeConstructor the (unconstrained) Stan type of the local
    * @param value the value recorded as the declaration's initializer
    * @param name implicitly captured name of the enclosing definition
    * @return the declaration that was recorded in the builder
    */
  def local[T <: StanType](
    typeConstructor: T,
    value: StanValue[T]
  )(implicit name: sourcecode.Name): StanLocalDeclaration[T] =
    insertLocal(typeConstructor, Some(value), name)

  /**
    * Conditional statement.  Constructing an instance opens a scope on the builder,
    * evaluates `block` inside it, and closes the scope as a `StanIfStatement` with a
    * single `(cond, block)` branch and no else branch.
    *
    * @param cond the condition guarding `block`
    * @param block code whose recorded statements form the body of the branch
    */
  case class when(cond: StanValue[StanInt])(block: => Unit) {
    _code.enter()
    block
    _code.leave(statements => ast.StanIfStatement(Seq((cond, StanBlock(statements))), None))

    /**
      * Adds a further conditional branch: `otherBlock` is evaluated in a fresh scope and
      * the builder's `handleElseIf` is then invoked with `cond`.
      *
      * @return this instance, so that more branches can be chained
      */
    def when(cond: StanValue[StanInt])(otherBlock: => Unit): when = {
      _code.enter()
      otherBlock
      _code.handleElseIf(cond)
      this
    }

    /**
      * Adds the final branch: `otherBlock` is evaluated in a fresh scope and the
      * builder's `handleElse` is then invoked.
      */
    def otherwise(otherBlock: => Unit): Unit = {
      _code.enter()
      otherBlock
      _code.handleElse()
    }
  }

  /**
    * Conditional expression: builds a `StanTernaryOperator` from the condition and the
    * two alternatives.
    */
  def when[T <: StanType](cond: StanValue[StanInt], ifTrue: StanValue[T], ifFalse: StanValue[T]): StanValue[T] =
    StanTernaryOperator(cond, ifTrue, ifFalse)

  /** Builds a `StanValueRange` from `start` and `end`. */
  def range(start: StanValue[StanInt], end: StanValue[StanInt]): StanValueRange =
    StanValueRange(start, end)

  /**
    * While loop: `body` is evaluated in a fresh scope, which is then closed as a
    * `StanWhileLoop` over `cond`.
    */
  def loop(cond: StanValue[StanInt])(body: => Unit): Unit = {
    _code.enter()
    body
    _code.leave(statements => ast.StanWhileLoop(cond, StanBlock(statements)))
  }

  /** Appends a `StanBreakStatement` to the builder. */
  def break: Unit = _code.append(StanBreakStatement())

  /** Appends a `StanContinueStatement` to the builder. */
  def continue: Unit = _code.append(StanContinueStatement())

  /**
    * Writes only the top-level inline declarations of this block to `writer`.
    *
    * Declarations must precede code, and transformations are handled separately, so the
    * top-level locals are gathered by this dedicated pass.
    */
  private[scalastan] def emitTopLevelLocals(writer: PrintWriter): Unit =
    _code.results.children.foreach { statement =>
      statement match {
        // The second argument is presumably the indentation level -- confirm against `emit`.
        case _: StanInlineDeclaration => statement.emit(writer, 1)
        case _                        => ()
      }
    }

  /**
    * Writes every top-level statement that is not an inline declaration to `writer`;
    * this is the complement of `emitTopLevelLocals`.
    */
  private[scalastan] def emitCode(writer: PrintWriter): Unit =
    _code.results.children.foreach {
      case _: StanInlineDeclaration => ()
      case statement                => statement.emit(writer, 1)
    }
}