package glint.models.client.granular

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag

import glint.models.client.BigVector

/**
  * A [[glint.models.client.BigVector BigVector]] whose messages are limited to a specific maximum message size. This
  * helps resolve timeout exceptions and heartbeat failures in akka at the cost of additional message overhead.
  *
  * {{{
  *   vector = client.vector[Double](1000000)
  *   granularVector = new GranularBigVector[Double](vector, 1000)
  *   granularVector.pull(veryLargeArrayOfIndices)
  * }}}
  *
  * @param underlying The underlying big vector
  * @param maximumMessageSize The maximum message size
  * @tparam V The type of values to store
  */
class GranularBigVector[V: ClassTag](underlying: BigVector[V],
                                     maximumMessageSize: Int) extends BigVector[V] {

  require(maximumMessageSize > 0, "Max message size must be non-zero")

  val size = underlying.size

  /**
    * Pulls a set of values while attempting to keep individual network messages smaller
    * than `maximumMessageSize`
     *
     * @param keys The indices of the elements
     * @param ec The implicit execution context in which to execute the request
     * @return A future containing the values of the elements at given rows, columns
     */
  override def pull(keys: Array[Long])(implicit ec: ExecutionContext): Future[Array[V]] = {
    var i = 0
    var current = 0
    val maxIncrement = Math.max(1, maximumMessageSize)
    val a = new Array[Future[Array[V]]](Math.ceil(keys.length.toDouble / maxIncrement.toDouble).toInt)
    while (i < keys.length) {
      val end = Math.min(keys.length, i + maxIncrement)
      val future = underlying.pull(keys.slice(i, end))
      a(current) = future
      current += 1
      i += maxIncrement
    }
    Future.sequence(a.toIterator).map {
      case arrayOfValues =>
        val finalValues = new ArrayBuffer[V](keys.length)
        arrayOfValues.foreach(x => finalValues.appendAll(x))
        finalValues.toArray
    }
  }

  override def destroy()(implicit ec: ExecutionContext): Future[Boolean] = underlying.destroy()

  /**
    * Pushes a set of values while keeping individual network messages smaller
    * than `maximumMessageSize`
     *
     * @param keys The indices of the rows
     * @param values The values to update
     * @param ec The implicit execution context in which to execute the request
     * @return A future containing either the success or failure of the operation
     */
  override def push(keys: Array[Long], values: Array[V])
    (implicit ec: ExecutionContext): Future[Boolean] = {
    var i = 0
    val ab = new ArrayBuffer[Future[Boolean]](keys.length / maximumMessageSize)
    while (i < keys.length) {
      val end = Math.min(keys.length, i + maximumMessageSize)
      val future = underlying.push(keys.slice(i, end), values.slice(i, end))
      ab.append(future)
      i += maximumMessageSize
    }
    Future.sequence(ab.toIterator).transform(x => x.forall(y => y), err => err)
  }

}