package com.twitter.finagle.postgres.generic

import com.twitter.concurrent.AsyncStream
import scala.collection.immutable.Queue
import com.twitter.finagle.postgres.{Param, PostgresClient, Row}
import com.twitter.util.Future

import scala.language.existentials

/**
 * A query assembled from literal text fragments (`parts`) interleaved with
 * parameters (`queryParams`), together with a continuation (`cont`) that turns
 * each result [[Row]] into a `T`.
 *
 * @param parts       literal text fragments of the statement, in order
 * @param queryParams parameters; each one supplies its own placeholders and values
 * @param cont        function applied to every row produced by the query
 */
case class Query[T](parts: Seq[String], queryParams: Seq[QueryParam], cont: Row => T) {

  /**
   * Renders the statement and passes it, with its parameters and `cont`, to
   * `client.prepareAndQueryToStream`.
   */
  def stream(client: PostgresClient): AsyncStream[T] = {
    val (sql, boundParams) = impl
    client.prepareAndQueryToStream[T](sql, boundParams: _*)(cont)
  }

  /** Same as [[stream]], with the resulting stream gathered via `toSeq`. */
  def run(client: PostgresClient): Future[Seq[T]] = stream(client).toSeq

  /**
   * Renders the statement and passes it with its parameters to
   * `client.prepareAndExecute`, returning the `Int` that call yields
   * (presumably the affected-row count -- confirm against the client).
   */
  def exec(client: PostgresClient): Future[Int] = {
    val (sql, boundParams) = impl
    client.prepareAndExecute(sql, boundParams: _*)
  }

  /** Returns a copy of this query whose continuation is `cont` followed by `fn`. */
  def map[U](fn: T => U): Query[U] = copy(cont = fn compose cont)

  /**
   * Returns a copy of this query that decodes each row with `rowDecoder`.
   * The new continuation replaces `cont`; it is not composed with it.
   */
  def as[U](implicit rowDecoder: RowDecoder[U], columnNamer: ColumnNamer): Query[U] =
    copy(cont = (row: Row) => rowDecoder(row)(columnNamer))

  /**
   * Builds the final statement text and the flat parameter list.
   *
   * Placeholder numbering starts at 1. Each query parameter is asked for its
   * placeholders given the current index, and the index then advances by the
   * number of values that parameter contributes.
   */
  private def impl: (String, Seq[Param[_]]) = {
    val seed = (1, Queue.empty[Seq[String]], Queue.empty[Param[_]])

    val (_, placeholderGroups, boundParams) = queryParams.foldLeft(seed) {
      case ((nextIndex, groupsSoFar, paramsSoFar), queryParam) =>
        val group = queryParam.placeholders(nextIndex)
        val values = Queue(queryParam.params: _*)
        (nextIndex + values.length, groupsSoFar enqueue group, paramsSoFar ++ values)
    }

    // zipAll pads whichever side is shorter: a missing text fragment becomes ""
    // and a missing placeholder group becomes an empty group.
    val sql = parts
      .zipAll(placeholderGroups, "", Seq.empty)
      .map { case (fragment, group) => fragment + group.mkString(", ") }
      .mkString

    (sql, boundParams)
  }
}

object Query {

  /** Concatenation operators for queries that still yield raw [[Row]]s. */
  implicit class RowQueryOps(val self: Query[Row]) extends AnyVal {

    /**
     * Appends `that` to this query. The continuation of `self` is kept and the
     * parameters of both queries are concatenated.
     *
     * When `self` has more text fragments than parameters, its last fragment is
     * joined with the first fragment of `that`; otherwise the fragment lists
     * are simply concatenated.
     */
    def ++(that: Query[Row]): Query[Row] = {
      val endsWithText = self.parts.length > self.queryParams.length
      val combinedParts =
        if (endsWithText) {
          val seam = self.parts.lastOption.getOrElse("") + that.parts.headOption.getOrElse("")
          (self.parts.dropRight(1) :+ seam) ++ that.parts.drop(1)
        } else {
          self.parts ++ that.parts
        }

      Query[Row](
        parts = combinedParts,
        queryParams = self.queryParams ++ that.queryParams,
        cont = self.cont
      )
    }

    /**
     * Appends literal text to this query. Parameters and continuation are
     * unchanged.
     *
     * When `self` has more text fragments than parameters, the text is glued
     * onto the last fragment; otherwise it is added as a new fragment.
     */
    def ++(that: String): Query[Row] = {
      val extendedParts =
        if (self.parts.length > self.queryParams.length)
          // The guard implies at least one fragment, so `last` is defined here.
          self.parts.dropRight(1) :+ (self.parts.last + that)
        else
          self.parts :+ that

      self.copy(parts = extendedParts)
    }
  }
}