package org.springframework.integration.aws.sqs.core;

import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

import org.apache.commons.codec.binary.Hex;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.json.JSONException;
import org.json.JSONObject;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.integration.Message;
import org.springframework.integration.MessagingException;
import org.springframework.integration.aws.AwsUtil;
import org.springframework.integration.aws.JsonMessageMarshaller;
import org.springframework.integration.aws.MessageMarshaller;
import org.springframework.integration.aws.MessageMarshallerException;
import org.springframework.integration.aws.Permission;
import org.springframework.integration.aws.sqs.SqsHeaders;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.util.Assert;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.policy.Policy;
import com.amazonaws.auth.policy.Principal;
import com.amazonaws.auth.policy.Resource;
import com.amazonaws.auth.policy.Statement;
import com.amazonaws.auth.policy.Statement.Effect;
import com.amazonaws.auth.policy.actions.SQSActions;
import com.amazonaws.auth.policy.conditions.ArnCondition;
import com.amazonaws.auth.policy.conditions.ArnCondition.ArnComparisonType;
import com.amazonaws.auth.policy.conditions.ConditionFactory;
import com.amazonaws.auth.policy.internal.JsonPolicyWriter;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClient;
import com.amazonaws.services.sqs.model.AddPermissionRequest;
import com.amazonaws.services.sqs.model.CreateQueueRequest;
import com.amazonaws.services.sqs.model.CreateQueueResult;
import com.amazonaws.services.sqs.model.DeleteMessageRequest;
import com.amazonaws.services.sqs.model.GetQueueAttributesRequest;
import com.amazonaws.services.sqs.model.GetQueueAttributesResult;
import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
import com.amazonaws.services.sqs.model.ReceiveMessageResult;
import com.amazonaws.services.sqs.model.SendMessageRequest;
import com.amazonaws.services.sqs.model.SendMessageResult;
import com.amazonaws.services.sqs.model.SetQueueAttributesRequest;
import com.amazonaws.util.Md5Utils;

/**
 * Bundles common core logic for the Sqs components.
 * 
 * @author Sayantam Dey
 * @since 1.0
 * 
 */
public class SqsExecutor implements InitializingBean, DisposableBean {

	private static final String SNS_MESSAGE_KEY = "Message";
	private static final int DEFAULT_RECV_MESG_WAIT = 20; // seconds
	private static final String QUEUE_ARN_KEY = "QueueArn";
	private static final int DEFAULT_MESSAGE_PREFETCH_COUNT = 10;

	private final Log log = LogFactory.getLog(SqsExecutor.class);

	private String queueName;
	private BlockingQueue<String> queue;
	private AWSCredentialsProvider awsCredentialsProvider;
	private AmazonSQS sqsClient;
	private String queueUrl;
	private String queueArn;
	private String regionId;
	private int receiveMessageWaitTimeout;
	private int prefetchCount;
	private final BlockingQueue<com.amazonaws.services.sqs.model.Message> prefetchQueue;
	private Integer messageDelay;
	private Integer maximumMessageSize;
	private Integer messageRetentionPeriod;
	private Integer visibilityTimeout;
	private ClientConfiguration awsClientConfiguration;
	private MessageMarshaller messageMarshaller;

	private volatile int destroyWaitTime;
	private Set<Permission> permissions;

	/**
	 * Constructor.
	 */
	public SqsExecutor() {
		this.receiveMessageWaitTimeout = DEFAULT_RECV_MESG_WAIT;
		this.destroyWaitTime = 0;
		this.prefetchCount = DEFAULT_MESSAGE_PREFETCH_COUNT;
		this.prefetchQueue = new LinkedBlockingQueue<com.amazonaws.services.sqs.model.Message>(
				prefetchCount);
	}

	/**
	 * Verifies and sets the parameters. E.g. initializes the to be used
	 */
	@Override
	public void afterPropertiesSet() {

		Assert.isTrue(this.queueName != null || this.queueUrl != null,
				"Either queueName or queueUrl must not be empty.");

		Assert.isTrue(queue != null || awsCredentialsProvider != null,
				"Either queue or awsCredentialsProvider needs to be provided");

		if (messageMarshaller == null) {
			messageMarshaller = new JsonMessageMarshaller();
		}

		if (queue == null) {
			if (sqsClient == null) {
				if (awsClientConfiguration == null) {
					sqsClient = new AmazonSQSClient(awsCredentialsProvider);
				} else {
					sqsClient = new AmazonSQSClient(awsCredentialsProvider,
							awsClientConfiguration);
				}
			}
			if (regionId != null) {
				sqsClient.setEndpoint(String.format("sqs.%s.amazonaws.com",
						regionId));
			}
			if (queueName != null) {
				createQueueIfNotExists();
			}

			addPermissions();
		}
	}

	private void createQueueIfNotExists() {
		for (String qUrl : sqsClient.listQueues().getQueueUrls()) {
			if (qUrl.contains(queueName)) {
				queueUrl = qUrl;
				break;
			}
		}
		if (queueUrl == null) {
			CreateQueueRequest request = new CreateQueueRequest(queueName);
			Map<String, String> queueAttributes = new HashMap<String, String>();
			queueAttributes.put("ReceiveMessageWaitTimeSeconds", Integer
					.valueOf(receiveMessageWaitTimeout).toString());
			if (messageDelay != null) {
				queueAttributes.put("DelaySeconds", messageDelay.toString());
			}
			if (maximumMessageSize != null) {
				queueAttributes.put("MaximumMessageSize",
						maximumMessageSize.toString());
			}
			if (messageRetentionPeriod != null) {
				queueAttributes.put("MessageRetentionPeriod",
						messageRetentionPeriod.toString());
			}
			if (visibilityTimeout != null) {
				queueAttributes.put("VisibilityTimeout",
						visibilityTimeout.toString());
			}
			request.setAttributes(queueAttributes);
			CreateQueueResult result = sqsClient.createQueue(request);
			queueUrl = result.getQueueUrl();
			log.debug("New queue available at: " + queueUrl);
		} else {
			log.debug("Queue already exists: " + queueUrl);
		}

		resolveQueueArn();
	}

	private void resolveQueueArn() {
		GetQueueAttributesRequest request = new GetQueueAttributesRequest(
				queueUrl);
		GetQueueAttributesResult result = sqsClient.getQueueAttributes(request
				.withAttributeNames(Collections.singletonList(QUEUE_ARN_KEY)));
		queueArn = result.getAttributes().get(QUEUE_ARN_KEY);
	}

	private void addPermissions() {
		if (permissions != null && permissions.isEmpty() == false) {
			GetQueueAttributesResult result = sqsClient
					.getQueueAttributes(new GetQueueAttributesRequest(queueUrl,
							Arrays.asList("Policy")));

			AwsUtil.addPermissions(result.getAttributes(), permissions,
					new AwsUtil.AddPermissionHandler() {

						@Override
						public void execute(Permission p) {
							sqsClient.addPermission(new AddPermissionRequest()
									.withQueueUrl(queueUrl)
									.withLabel(p.getLabel())
									.withAWSAccountIds(p.getAwsAccountIds())
									.withActions(p.getActions()));
						}
					});
		}
	}

	/**
	 * Executes the outbound Sqs Operation.
	 * 
	 */
	public Object executeOutboundOperation(final Message<?> message) {

		try {
			String serializedMessage = messageMarshaller.serialize(message);
			if (queue == null) {
				SendMessageRequest request = new SendMessageRequest(queueUrl,
						serializedMessage);
				SendMessageResult result = sqsClient.sendMessage(request);
				log.debug("Message sent, Id:" + result.getMessageId());
			} else {
				queue.add(serializedMessage);
			}
		} catch (MessageMarshallerException e) {
			log.error(e.getMessage(), e);
			throw new MessagingException(e.getMessage(), e.getCause());
		}

		return message.getPayload();
	}

	/**
	 * Execute the Sqs operation. Delegates to {@link SqsExecutor#poll(Message)}
	 * .
	 */
	public Message<?> poll() {
		return poll(0);
	}

	/**
	 * Execute a retrieving (polling) Sqs operation.
	 * 
	 * @param timeout
	 *            time to wait for a message to return.
	 * 
	 * @return The payload object, which may be null.
	 */
	public Message<?> poll(long timeout) {

		Message<?> message = null;
		String payloadJSON = null;
		com.amazonaws.services.sqs.model.Message qMessage = null;
		int timeoutSeconds = (timeout > 0 ? ((int) (timeout / 1000))
				: receiveMessageWaitTimeout);
		destroyWaitTime = timeoutSeconds;
		try {
			if (queue == null) {
				if (prefetchQueue.isEmpty()) {
					ReceiveMessageRequest request = new ReceiveMessageRequest(
							queueUrl).withWaitTimeSeconds(timeoutSeconds)
							.withMaxNumberOfMessages(prefetchCount)
							.withAttributeNames("All");

					ReceiveMessageResult result = sqsClient
							.receiveMessage(request);
					for (com.amazonaws.services.sqs.model.Message sqsMessage : result
							.getMessages()) {
						prefetchQueue.offer(sqsMessage);
					}
					qMessage = prefetchQueue.poll();
				} else {
					qMessage = prefetchQueue.remove();
				}
				if (qMessage != null) {
					payloadJSON = qMessage.getBody();
					// MD5 verification
					try {
						byte[] computedHash = Md5Utils
								.computeMD5Hash(payloadJSON.getBytes("UTF-8"));
						String hexDigest = new String(
								Hex.encodeHex(computedHash));
						if (!hexDigest.equals(qMessage.getMD5OfBody())) {
							payloadJSON = null; // ignore this message
							log.warn("Dropped message due to MD5 checksum failure");
						}
					} catch (Exception e) {
						log.warn(
								"Failed to verify MD5 checksum: "
										+ e.getMessage(), e);
					}
				}
			} else {
				try {
					payloadJSON = queue.poll(timeoutSeconds, TimeUnit.SECONDS);
				} catch (InterruptedException e) {
					log.warn(e.getMessage(), e);
				}
			}
			if (payloadJSON != null) {
				JSONObject qMessageJSON = new JSONObject(payloadJSON);
				if (qMessageJSON.has(SNS_MESSAGE_KEY)) { // posted from SNS
					payloadJSON = qMessageJSON.getString(SNS_MESSAGE_KEY);
					// XXX: other SNS attributes?
				}
				Message<?> packet = null;
				try {
					packet = messageMarshaller.deserialize(payloadJSON);
				} catch (MessageMarshallerException marshallingException) {
					throw new MessagingException(
							marshallingException.getMessage(),
							marshallingException.getCause());
				}
				MessageBuilder<?> builder = MessageBuilder.fromMessage(packet);
				if (qMessage != null) {
					builder.setHeader(SqsHeaders.MSG_RECEIPT_HANDLE,
							qMessage.getReceiptHandle());
					builder.setHeader(SqsHeaders.AWS_MESSAGE_ID,
							qMessage.getMessageId());
					for (Map.Entry<String, String> e : qMessage.getAttributes()
							.entrySet()) {
						if (e.getKey().equals("ApproximateReceiveCount")) {
							builder.setHeader(SqsHeaders.RECEIVE_COUNT,
									Integer.valueOf(e.getValue()));
						} else if (e.getKey().equals("SentTimestamp")) {
							builder.setHeader(SqsHeaders.SENT_AT,
									new Date(Long.valueOf(e.getValue())));
						} else if (e.getKey().equals(
								"ApproximateFirstReceiveTimestamp")) {
							builder.setHeader(SqsHeaders.FIRST_RECEIVED_AT,
									new Date(Long.valueOf(e.getValue())));
						} else if (e.getKey().equals("SenderId")) {
							builder.setHeader(SqsHeaders.SENDER_AWS_ID,
									e.getValue());
						} else {
							builder.setHeader(e.getKey(), e.getValue());
						}
					}
				} else {
					builder.setHeader(SqsHeaders.MSG_RECEIPT_HANDLE, "");
					// to satisfy test conditions
				}

				message = builder.build();
			}

		} catch (JSONException e) {
			log.warn(e.getMessage(), e);

		} finally {
			destroyWaitTime = 0;
		}

		return message;
	}

	public String acknowlegdeReceipt(Message<?> message) {
		String receiptHandle = (String) message.getHeaders().get(
				SqsHeaders.MSG_RECEIPT_HANDLE);
		if (sqsClient != null && receiptHandle != null
				&& !receiptHandle.isEmpty()) {
			sqsClient.deleteMessage(new DeleteMessageRequest(queueUrl,
					receiptHandle));
		}

		return receiptHandle;
	}

	public String getQueueArn() {
		if (queueArn == null) {
			resolveQueueArn();
		}
		return queueArn;
	}

	public String getQueueUrl() {
		return queueUrl;
	}

	/**
	 * Example property to illustrate usage of properties in Spring Integration
	 * components. Replace with your own logic.
	 * 
	 * @param queueName
	 *            Must not be null
	 */
	public void setQueueName(String queueName) {
		Assert.hasText(queueName, "queueName must be neither null nor empty");
		this.queueName = queueName;
	}

	/**
	 * Set the queue implementation. Useful for testing the queue without
	 * actually invoking AWS.
	 * 
	 * @param queue
	 */
	public void setQueue(BlockingQueue<String> queue) {
		this.queue = queue;
	}

	/**
	 * Sets the AWS client configuration.
	 * 
	 * @param awsClientConfiguration
	 */
	public void setAwsClientConfiguration(
			ClientConfiguration awsClientConfiguration) {
		this.awsClientConfiguration = awsClientConfiguration;
	}

	/**
	 * Sets the AWS credentials provider.
	 * 
	 * @param awsCredentialsProvider
	 */
	public void setAwsCredentialsProvider(
			AWSCredentialsProvider awsCredentialsProvider) {
		this.awsCredentialsProvider = awsCredentialsProvider;
	}

	public int getReceiveMessageWaitTimeout() {
		return receiveMessageWaitTimeout;
	}

	/**
	 * Sets the timeout (in seconds) for a receive message operation, defaults
	 * to {@value #DEFAULT_RECV_MESG_WAIT} seconds.
	 * 
	 * @param receiveMessageWaitTimeout
	 */
	public void setReceiveMessageWaitTimeout(int receiveMessageWaitTimeout) {
		Assert.isTrue(receiveMessageWaitTimeout >= 0
				&& receiveMessageWaitTimeout <= 20,
				"'receiveMessageWaitTimeout' must be an integer from 0 to 20 (seconds).");
		this.receiveMessageWaitTimeout = receiveMessageWaitTimeout;
	}

	/**
	 * Sets the AWS region ID, defaults to us-east.
	 * 
	 * @param regionId
	 */
	public void setRegionId(String regionId) {
		this.regionId = regionId;
	}

	/**
	 * Sets the number of messages to prefetch, defaults to
	 * {@value #DEFAULT_MESSAGE_PREFETCH_COUNT}.
	 * 
	 * @param prefetchCount
	 */
	public void setPrefetchCount(int prefetchCount) {
		Assert.isTrue(prefetchCount >= 0 && prefetchCount <= 10,
				"'prefetchCount' must be an integer from 0 to 10.");
		this.prefetchCount = prefetchCount;
	}

	/**
	 * Sets the message delivery delay from SQS. By default there is no delay.
	 * 
	 * @param messageDelay
	 */
	public void setMessageDelay(Integer messageDelay) {
		Assert.isTrue(messageDelay >= 0 && messageDelay <= 900,
				"'messageDelay' must be an integer from 0 to 900 (15 minutes).");
		this.messageDelay = messageDelay;
	}

	/**
	 * Sets the maximum message size.
	 * 
	 * @param maximumMessageSize
	 */
	public void setMaximumMessageSize(Integer maximumMessageSize) {
		Assert.isTrue(
				maximumMessageSize >= 1024 && maximumMessageSize <= 65536,
				"'maximumMessageSize' must be an integer from 1024 bytes (1 KiB) up to 65536 bytes (64 KiB).");
		this.maximumMessageSize = maximumMessageSize;
	}

	/**
	 * Sets the message retention period at SQS. Messages older than this will
	 * be automatically be dropped by SQS.
	 * 
	 * @param messageRetentionPeriod
	 */
	public void setMessageRetentionPeriod(Integer messageRetentionPeriod) {
		Assert.isTrue(
				messageRetentionPeriod >= 60
						&& messageRetentionPeriod <= 1209600,
				"'messageRetentionPeriod' must be an integer representing seconds, from 60 (1 minute) to 1209600 (14 days)");
		this.messageRetentionPeriod = messageRetentionPeriod;
	}

	/**
	 * Sets the visibility timeout in seconds. SQS must receive an
	 * acknowledgment before this timeout occurs or else the message is
	 * re-delivered.
	 * 
	 * @param visibilityTimeout
	 */
	public void setVisibilityTimeout(Integer visibilityTimeout) {
		Assert.isTrue(
				visibilityTimeout >= 0 && visibilityTimeout <= 43200,
				"'visibilityTimeout' must be an integer representing seconds, from 0 to 43200 (12 hours)");
		this.visibilityTimeout = visibilityTimeout;
	}

	@Override
	public void destroy() throws Exception {
		if (sqsClient != null) {
			if (destroyWaitTime > 0) {
				Thread.sleep(destroyWaitTime * 1000);
			}
			sqsClient.shutdown();
		}

	}

	public void addSnsPublishPolicy(String topicName, String topicArn) {

		if (queueArn == null) {
			resolveQueueArn();
		}

		String publishPolicyKey = String.format("SNS-%s-SQS-%s", topicName,
				queueName);
		String policyId = null;
		GetQueueAttributesRequest getAttrRequest = new GetQueueAttributesRequest(
				queueUrl);
		getAttrRequest.setAttributeNames(Collections.singletonList("Policy"));
		GetQueueAttributesResult result = sqsClient
				.getQueueAttributes(getAttrRequest);
		Map<String, String> attributes = result.getAttributes();
		String policyStr = attributes.get("Policy");
		log.debug("Policy:" + policyStr);
		if (policyStr != null) {
			try {
				JSONObject policyJSON = new JSONObject(policyStr);
				policyId = policyJSON.getString("Id");
			} catch (JSONException e) {
				log.error(e.getMessage(), e);
			}
		}

		if (policyId == null || !policyId.equals(publishPolicyKey)) {
			Statement statement = new Statement(Effect.Allow);
			statement
					.withActions(SQSActions.SendMessage)
					.withPrincipals(Principal.AllUsers)
					.withResources(new Resource(queueArn))
					.withConditions(
							new ArnCondition(ArnComparisonType.ArnEquals,
									ConditionFactory.SOURCE_ARN_CONDITION_KEY,
									topicArn));
			Policy policy = new Policy();
			policy.setId(publishPolicyKey);
			policy.setStatements(Collections.singletonList(statement));

			SetQueueAttributesRequest request = new SetQueueAttributesRequest();
			request.setQueueUrl(queueUrl);
			String policyJSON = (new JsonPolicyWriter())
					.writePolicyToString(policy);
			log.debug(policyJSON);
			request.setAttributes(Collections
					.singletonMap("Policy", policyJSON));

			sqsClient.setQueueAttributes(request);
		}
	}

	public void setSqsClient(AmazonSQS sqsClient) {
		this.sqsClient = sqsClient;
	}

	public void setMessageMarshaller(MessageMarshaller messageMarshaller) {
		this.messageMarshaller = messageMarshaller;
	}

	public void setPermissions(Set<Permission> permissions) {
		this.permissions = permissions;
	}

	public void setQueueUrl(String queueUrl) {
		this.queueUrl = queueUrl;
	}

}