/*
 * Copyright (c) 2017 - 2019 CiBO Technologies - All Rights Reserved
 * You may use, distribute, and modify this code under the
 * terms of the BSD 3-Clause license.
 *
 * A copy of the license can be found on the root of this repository,
 * at https://github.com/cibotech/ScalaStan/blob/master/LICENSE,
 * or at https://opensource.org/licenses/BSD-3-Clause
 */

package com.cibo.scalastan

import java.io.Writer

import com.cibo.scalastan.ast.{StanDataDeclaration, StanParameterDeclaration}
import com.cibo.scalastan.run.StanRunner

/** How initial values for a model run are specified. */
sealed trait InitialValue

/** No initial value has been configured. */
case object DefaultInitialValue extends InitialValue

/** A single numeric initial-value setting (see `CompiledModel.withInitialValue(Double)`). */
case class InitialValueDouble(v: Double) extends InitialValue

/** Per-parameter initial values, keyed by the parameter's emitted name. */
case class InitialValueMapping(mapping: Map[String, DataMapping[_]]) extends InitialValue

/** A compiled Stan model together with its data bindings and initial values.
  *
  * All `with*` methods return a new instance; this class is never mutated.
  *
  * @param model        the Stan model this was compiled from
  * @param runner       the runner used by `run`
  * @param dataMapping  data bindings, keyed by the declaration's emitted name
  * @param initialValue the configured initial values
  */
case class CompiledModel(
  model: StanModel,
  runner: StanRunner,
  dataMapping: Map[String, DataMapping[_]] = Map.empty,
  initialValue: InitialValue = DefaultInitialValue
) {

  /** Write every mapping's `emit` text to `writer`, one entry per line (in the Map's iteration order). */
  private def emitMapping(mapping: Map[String, DataMapping[_]], writer: Writer): Unit = {
    mapping.values.foreach { value =>
      writer.write(value.emit)
      writer.write("\n")
    }
  }

  /** Write all data bindings to `writer`. */
  final def emitData(writer: Writer): Unit = emitMapping(dataMapping, writer)

  /** Write per-parameter initial values to `writer`.
    *
    * Only `InitialValueMapping` produces output; the other settings write nothing here.
    */
  final def emitInitialValues(writer: Writer): Unit = initialValue match {
    case InitialValueMapping(mapping) => emitMapping(mapping, writer)
    // Listed explicitly (rather than `case _`) so a new InitialValue variant triggers an
    // exhaustiveness warning instead of being silently ignored.
    case DefaultInitialValue | InitialValueDouble(_) => ()
  }

  /** Get the specified input data.
    *
    * @param decl the data declaration to look up
    * @tparam R   unused; retained so existing callers that pass explicit type arguments still compile
    * @return the value previously bound to `decl`
    * @throws NoSuchElementException if no data has been bound to `decl`
    */
  final def get[T <: StanType, R](
    decl: StanDataDeclaration[T]
  ): T#SCALA_TYPE = {
    val mapping = dataMapping.getOrElse(
      decl.emit,
      throw new NoSuchElementException(s"no data bound for ${decl.name}")
    )
    mapping.values.asInstanceOf[T#SCALA_TYPE]
  }

  /** Reset all bindings: clears the data bindings and initial values, keeping `model` and `runner`. */
  final def reset: CompiledModel = copy(
    dataMapping = Map.empty,
    initialValue = DefaultInitialValue
  )

  /** Look up and set size declarations.
    *
    * Walks the dimensions of `valueType` alongside the nested `Seq` structure of `data`. For each
    * dimension whose size is a data declaration, that declaration is bound to the length of the
    * sequence at that nesting level. This goes through `withData`, so a size that conflicts with an
    * existing binding throws. Other dimensions are skipped.
    */
  private def setSizes[T <: StanType](valueType: T, data: Any): CompiledModel = {
    valueType.getIndices.foldLeft((this, data)) { case ((old, d), dim) =>
      // NOTE(review): assumes every nesting level of `data` is a Seq -- confirm for all SCALA_TYPEs.
      val ds = d.asInstanceOf[Seq[_]]
      // An empty level has no element to descend into; inner dimensions then see an empty Seq (size 0).
      val next = if (ds.nonEmpty) ds.head else Seq.empty
      dim match {
        // NOTE(review): the `StanInt` type argument is erased at runtime, so this matches any
        // StanDataDeclaration; it relies on index declarations always being StanInt.
        case indexDecl: StanDataDeclaration[StanInt] => (old.withData(indexDecl, ds.size), next)
        case _ => (old, next)
      }
    }._1
  }

  /** Add a data binding.
    *
    * Any size declarations referenced by the declaration's type are bound from the shape of `data`.
    *
    * @param decl the data declaration to bind
    * @param data the value to bind
    * @return a new model including the binding
    * @throws IllegalStateException if `decl`, or one of its size declarations, is already bound to a
    *                               different value
    */
  final def withData[T <: StanType, V](
    decl: StanDataDeclaration[T],
    data: V
  )(implicit ev: V <:< T#SCALA_TYPE): CompiledModel = {
    val conv = data.asInstanceOf[T#SCALA_TYPE]

    // Check if this parameter has already been assigned and throw an exception if the values are conflicting.
    // Re-binding an equal value is allowed.
    dataMapping.get(decl.emit) match {
      case Some(s) if s.values != data => throw new IllegalStateException(s"conflicting values assigned to ${decl.name}")
      case _                           => ()
    }

    // Insert/check size declarations.
    val withDecls = setSizes(decl.returnType, conv)

    // Insert the binding.
    withDecls.copy(
      dataMapping = withDecls.dataMapping.updated(decl.emit, DataMapping[T](decl, conv))
    )
  }

  /** Add a binding from a (declaration, value) pair; equivalent to `withData(value._1, value._2)`. */
  final def withData[T <: StanType, V](
    value: (StanDataDeclaration[T], V)
  )(implicit ev: V <:< T#SCALA_TYPE): CompiledModel = withData(value._1, value._2)

  /** Set the initial value for a parameter.
    *
    * Size declarations referenced by the parameter's type are bound from the shape of `value`. If
    * per-parameter initial values already exist, the new entry is added with `Map.+`, so it replaces
    * any earlier initial value for the same parameter.
    *
    * @param decl  the parameter declaration
    * @param value the initial value
    * @return a new model including the initial value
    * @throws IllegalStateException if a numeric initial value was already set via
    *                               `withInitialValue(Double)`, or if a size conflicts with existing data
    */
  final def withInitialValue[T <: StanType, V](
    decl: StanParameterDeclaration[T],
    value: V
  )(implicit ev: V <:< T#SCALA_TYPE): CompiledModel = {
    val conv = value.asInstanceOf[T#SCALA_TYPE]

    // Insert/check size declarations
    val withDecls = setSizes(decl.returnType, conv)

    // Record the initial value.
    val newValue = decl.emit -> DataMapping[T](decl, conv)
    initialValue match {
      case DefaultInitialValue          => withDecls.copy(initialValue = InitialValueMapping(Map(newValue)))
      case InitialValueMapping(mapping) => withDecls.copy(initialValue = InitialValueMapping(mapping + newValue))
      case InitialValueDouble(_)        => throw new IllegalStateException("Initial value already set.")
    }
  }

  /** Set the bounds on initial values.
    *
    * @param value upper bound for the initial values; must be non-negative
    * @return a new model with the numeric initial value set
    * @throws IllegalArgumentException if `value` is negative
    * @throws IllegalStateException    if any initial value (numeric or per-parameter) was already set
    */
  final def withInitialValue(value: Double): CompiledModel = {
    require(value >= 0, s"The upper bound on the initial value must be >= 0, got $value")
    initialValue match {
      case DefaultInitialValue => copy(initialValue = InitialValueDouble(value))
      case _                   => throw new IllegalStateException("Initial value already set.")
    }
  }

  /** Run the model and get results.
    *
    * @param chains number of chains to run; must be positive
    * @param seed   seed passed through to the runner
    * @param cache  cache flag passed through to the runner
    * @param method run method passed through to the runner
    * @return the runner's results
    * @throws IllegalArgumentException if `chains` is not positive
    * @throws IllegalStateException    if any data declaration of the model has no binding
    */
  final def run(
    chains: Int = 4,
    seed: Int = -1,
    cache: Boolean = true,
    method: RunMethod.Method = RunMethod.Sample()
  ): StanResults = {
    require(chains > 0, s"Must run at least one chain, got $chains")

    // Make sure all the necessary data is provided; report every missing declaration at once.
    val missing = model.program.data.filterNot(v => dataMapping.contains(v.emit))
    if (missing.nonEmpty) {
      throw new IllegalStateException(s"data not supplied for ${missing.map(_.name).mkString(", ")}")
    }

    runner.run(
      compiledModel = this,
      chains = chains,
      seed = seed,
      cache = cache,
      method = method
    )
  }
}