/**
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
  * this work for additional information regarding copyright ownership.
  * The ASF licenses this file to You under the Apache License, Version 2.0
  * (the "License"); you may not use this file except in compliance with
  * the License.  You may obtain a copy of the License at
  * <p>
  * http://www.apache.org/licenses/LICENSE-2.0
  * <p>
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  **/
package com.hivehome.kafka.connect.sqs

import java.util.{List => JList, Map => JMap}
import javax.jms._

import org.apache.kafka.connect.data.Schema
import org.apache.kafka.connect.source.{SourceRecord, SourceTask}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import scala.util.Try
import scala.util.control.NonFatal

/**
  * Constants shared with the [[SQSSourceTask]] class when building and
  * reading back source records.
  */
object SQSSourceTask {
  /** Key used in the source-partition map; its value is the queue name. */
  private val SqsQueueField = "queue"

  /** Key used in the source-offset map; its value is the JMS message id. */
  private val MessageId = "messageId"

  /** Schema attached to the value of every emitted record. */
  private val ValueSchema: Schema = Schema.STRING_SCHEMA
}

/**
  * Kafka Connect source task that receives messages from a JMS
  * `MessageConsumer` (created by `SQSConsumer` from the task configuration)
  * and emits one [[SourceRecord]] per received message.
  *
  * A received message is not acknowledged immediately. It is remembered under
  * its JMS message id and acknowledged in `commitRecord`, i.e. only once the
  * framework reports the corresponding record as committed.
  */
class SQSSourceTask extends SourceTask {
  val logger = LoggerFactory.getLogger(getClass.getName)
  private var conf: Conf = _
  private var consumer: MessageConsumer = null
  // JMS message id -> message still waiting to be acknowledged by commitRecord.
  // Every read-modify-write of this map goes through the synchronized helpers
  // below (the same monitor stop() uses), because the framework may invoke
  // commitRecord on a different thread than poll.
  // NOTE(review): confirm the threading model for the Connect version in use.
  private var unAcknowledgedMessages = Map[String, Message]()

  /** Version string reported to the framework. */
  def version: String = Version()

  /**
    * Parses the task configuration and creates the JMS consumer.
    *
    * A failure to create the consumer is logged and rethrown so that the task
    * fails at start-up instead of failing later, without a cause, on the
    * first call to `poll`.
    *
    * @param props raw task properties handed over by the framework
    */
  def start(props: JMap[String, String]): Unit = {
    conf = Conf.parse(props.asScala.toMap).get

    logger.debug("Creating consumer...")
    synchronized {
      try {
        consumer = SQSConsumer(conf)
        logger.info("Created consumer to  SQS topic {} for reading", conf.queueName)
      }
      catch {
        case NonFatal(e) =>
          logger.error("Exception", e)
          // Propagate: without a consumer the task cannot do any work.
          throw e
      }
    }
  }

  import com.hivehome.kafka.connect.sqs.SQSSourceTask._

  /**
    * Receives at most one message from the consumer and converts it into a
    * source record.
    *
    * Non-fatal errors raised while receiving or converting are logged and
    * result in an empty list; they are not propagated.
    *
    * @return a list holding the record for the received message, or an empty
    *         list when `receive` returned `null` or an error occurred
    */
  @throws(classOf[InterruptedException])
  def poll: JList[SourceRecord] = {
    def toRecord(msg: Message): SourceRecord = {
      val extracted = MessageExtractor(msg)
      val sourcePartition = Map(SqsQueueField -> conf.queueName.get).asJava
      val sourceOffset = Map(MessageId -> msg.getJMSMessageID).asJava
      new SourceRecord(sourcePartition, sourceOffset, conf.topicName.get, ValueSchema, extracted)
    }

    // The consumer is created in start(); poll must not run before that.
    assert(consumer != null, "consumer is not initialised; start() must succeed before poll")
    Try {
      Option(consumer.receive()).map { msg =>
        logger.info("Received message {}", msg)

        // Build the record first: if conversion fails, the message is not
        // left in the pending map with no record that could ever commit it.
        val record = toRecord(msg)
        track(msg)
        record
      }.toSeq
    }.recover {
      case NonFatal(e) =>
        logger.error("Exception while processing message", e)
        List.empty
    }.get.asJava
  }

  /**
    * Acknowledges the JMS message that produced `record`, if it is still
    * tracked, and then stops tracking it.
    *
    * If `acknowledge()` throws, the exception propagates and the message
    * stays tracked.
    *
    * @param record the record reported as committed; its source offset is
    *               expected to carry the JMS message id written by `poll`
    */
  @throws(classOf[InterruptedException])
  override def commitRecord(record: SourceRecord): Unit = {
    val msgId = record.sourceOffset().get(MessageId).asInstanceOf[String]
    pending(msgId).foreach(_.acknowledge())
    forget(msgId)
  }

  /**
    * Drops all pending acknowledgements and closes the consumer, if one was
    * created. A failure to close is logged and not propagated.
    */
  def stop(): Unit = {
    logger.debug("Stopping task")
    synchronized {
      unAcknowledgedMessages = Map()
      try {
        if (consumer != null) {
          consumer.close()
          logger.debug("Closed input stream")
        }
      }
      catch {
        case NonFatal(e) => logger.error("Failed to close consumer stream: ", e)
      }
      // NOTE(review): nothing in this class waits on this monitor; kept so
      // that behaviour is unchanged for any external waiter.
      this.notify()
    }
  }

  /** Remembers `msg` under its JMS message id until it is committed. */
  private def track(msg: Message): Unit = synchronized {
    unAcknowledgedMessages = unAcknowledgedMessages.updated(msg.getJMSMessageID, msg)
  }

  /** Looks up the message still awaiting acknowledgement for `msgId`. */
  private def pending(msgId: String): Option[Message] = synchronized {
    unAcknowledgedMessages.get(msgId)
  }

  /** Stops tracking `msgId`; a no-op when the id is unknown. */
  private def forget(msgId: String): Unit = synchronized {
    unAcknowledgedMessages = unAcknowledgedMessages - msgId
  }
}