// Copyright (c) 2018-2020 by Rob Norris
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.net.protocol

import cats.effect.Resource
import cats.implicits._
import cats.MonadError
import skunk.exception.PostgresErrorException
import skunk.net.message.{ Bind => BindMessage, Close => _, _ }
import skunk.net.MessageSocket
import skunk.net.Protocol.{ PreparedStatement, PortalId }
import skunk.util.{ Origin, Namer }
import natchez.Trace

/**
 * Protocol-level operation that binds an argument value to a prepared statement, yielding the
 * identifier of the resulting portal as a managed `Resource`.
 *
 * An implementation is provided by `Bind.apply` in the companion object.
 */
trait Bind[F[_]] {

  /**
   * Bind `args` to `statement`.
   *
   * @param statement  the prepared statement to bind arguments to
   * @param args       the argument value; the companion implementation encodes it with the
   *                   statement's own encoder
   * @param argsOrigin origin associated with `args`; the companion implementation passes it to
   *                   the raised error as `argumentsOrigin` when the bind is rejected
   * @tparam A         the argument type accepted by `statement`
   * @return a resource whose value is the `PortalId` produced by the bind
   */
  def apply[A](
    statement:  PreparedStatement[F, A],
    args:       A,
    argsOrigin: Origin
  ): Resource[F, PortalId]

}

object Bind {

  /**
   * Construct a `Bind` that talks to the backend through the implicit `MessageSocket`, inside an
   * `Exchange`, drawing portal names from the implicit `Namer` and recording trace fields via
   * `Trace`.
   */
  def apply[F[_]: MonadError[?[_], Throwable]: Exchange: MessageSocket: Namer: Trace]: Bind[F] =
    new Bind[F] {

      /**
       * Bind `args` to `statement` under a freshly generated portal name.
       *
       * Acquisition runs as a single "bind" exchange: a name is generated, the arguments and
       * portal id are attached to the current trace, and a `Bind` message followed by `Flush` is
       * sent. A `BindComplete` reply yields the new `PortalId`; an `ErrorResponse` is handled by
       * `syncAndFail`, which fails the effect. The resource finalizer is `Close[F].apply`.
       */
      override def apply[A](
        statement:  PreparedStatement[F, A],
        args:       A,
        argsOrigin: Origin
      ): Resource[F, PortalId] = {

        val acquire: F[PortalId] =
          exchange("bind") {
            nextName("portal").map(PortalId).flatMap { portal =>
              Trace[F].put(
                "arguments" -> args.toString,
                "portal-id" -> portal.value
              ) >>
              send(BindMessage(portal.value, statement.id.value, statement.statement.encoder.encode(args))) >>
              send(Flush) >> // ask the backend to deliver its pending reply now; no Sync is sent on success
              flatExpect {
                case BindComplete        => ().pure[F]
                case ErrorResponse(info) => syncAndFail(statement, args, argsOrigin, info)
              }.as(portal)
            }
          }

        Resource.make(acquire)(Close[F].apply)
      }

      /**
       * Recover the session after the backend rejected the bind, then fail.
       *
       * The message history is captured first, then `Sync` is sent and `ReadyForQuery` awaited,
       * and finally a `PostgresErrorException` is raised carrying the statement's SQL and origin,
       * the backend's `info` fields, the captured history, and the encoder's types zipped with
       * the encoded arguments.
       *
       * @param statement  the statement whose bind failed
       * @param args       the arguments that were being bound
       * @param argsOrigin origin of `args`, reported as `argumentsOrigin`
       * @param info       fields of the backend's `ErrorResponse`
       * @return an effect that never succeeds: it always ends in the raised error
       */
      def syncAndFail[A](
        statement:  PreparedStatement[F, A],
        args:       A,
        argsOrigin: Origin,
        info:       Map[Char, String]
      ): F[Unit] =
        history(Int.MaxValue).flatMap { recent =>
          send(Sync) >>
          expect { case ReadyForQuery(_) => } >>
          PostgresErrorException.raiseError[F, Unit](
            sql             = statement.statement.sql,
            sqlOrigin       = Some(statement.statement.origin),
            info            = info,
            history         = recent,
            arguments       = statement.statement.encoder.types.zip(statement.statement.encoder.encode(args)),
            argumentsOrigin = Some(argsOrigin)
          )
        }

    }

}