package org.springframework.integration.aws.sqs.core; import java.util.Arrays; import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.Map; import java.util.Set; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import org.apache.commons.codec.binary.Hex; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.json.JSONException; import org.json.JSONObject; import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; import org.springframework.integration.Message; import org.springframework.integration.MessagingException; import org.springframework.integration.aws.AwsUtil; import org.springframework.integration.aws.JsonMessageMarshaller; import org.springframework.integration.aws.MessageMarshaller; import org.springframework.integration.aws.MessageMarshallerException; import org.springframework.integration.aws.Permission; import org.springframework.integration.aws.sqs.SqsHeaders; import org.springframework.integration.support.MessageBuilder; import org.springframework.util.Assert; import com.amazonaws.ClientConfiguration; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.policy.Policy; import com.amazonaws.auth.policy.Principal; import com.amazonaws.auth.policy.Resource; import com.amazonaws.auth.policy.Statement; import com.amazonaws.auth.policy.Statement.Effect; import com.amazonaws.auth.policy.actions.SQSActions; import com.amazonaws.auth.policy.conditions.ArnCondition; import com.amazonaws.auth.policy.conditions.ArnCondition.ArnComparisonType; import com.amazonaws.auth.policy.conditions.ConditionFactory; import com.amazonaws.auth.policy.internal.JsonPolicyWriter; import com.amazonaws.services.sqs.AmazonSQS; import com.amazonaws.services.sqs.AmazonSQSClient; import com.amazonaws.services.sqs.model.AddPermissionRequest; import com.amazonaws.services.sqs.model.CreateQueueRequest; import com.amazonaws.services.sqs.model.CreateQueueResult; import com.amazonaws.services.sqs.model.DeleteMessageRequest; import com.amazonaws.services.sqs.model.GetQueueAttributesRequest; import com.amazonaws.services.sqs.model.GetQueueAttributesResult; import com.amazonaws.services.sqs.model.ReceiveMessageRequest; import com.amazonaws.services.sqs.model.ReceiveMessageResult; import com.amazonaws.services.sqs.model.SendMessageRequest; import com.amazonaws.services.sqs.model.SendMessageResult; import com.amazonaws.services.sqs.model.SetQueueAttributesRequest; import com.amazonaws.util.Md5Utils; /** * Bundles common core logic for the Sqs components. * * @author Sayantam Dey * @since 1.0 * */ public class SqsExecutor implements InitializingBean, DisposableBean { private static final String SNS_MESSAGE_KEY = "Message"; private static final int DEFAULT_RECV_MESG_WAIT = 20; // seconds private static final String QUEUE_ARN_KEY = "QueueArn"; private static final int DEFAULT_MESSAGE_PREFETCH_COUNT = 10; private final Log log = LogFactory.getLog(SqsExecutor.class); private String queueName; private BlockingQueue<String> queue; private AWSCredentialsProvider awsCredentialsProvider; private AmazonSQS sqsClient; private String queueUrl; private String queueArn; private String regionId; private int receiveMessageWaitTimeout; private int prefetchCount; private final BlockingQueue<com.amazonaws.services.sqs.model.Message> prefetchQueue; private Integer messageDelay; private Integer maximumMessageSize; private Integer messageRetentionPeriod; private Integer visibilityTimeout; private ClientConfiguration awsClientConfiguration; private MessageMarshaller messageMarshaller; private volatile int destroyWaitTime; private Set<Permission> permissions; /** * Constructor. */ public SqsExecutor() { this.receiveMessageWaitTimeout = DEFAULT_RECV_MESG_WAIT; this.destroyWaitTime = 0; this.prefetchCount = DEFAULT_MESSAGE_PREFETCH_COUNT; this.prefetchQueue = new LinkedBlockingQueue<com.amazonaws.services.sqs.model.Message>( prefetchCount); } /** * Verifies and sets the parameters. E.g. initializes the to be used */ @Override public void afterPropertiesSet() { Assert.isTrue(this.queueName != null || this.queueUrl != null, "Either queueName or queueUrl must not be empty."); Assert.isTrue(queue != null || awsCredentialsProvider != null, "Either queue or awsCredentialsProvider needs to be provided"); if (messageMarshaller == null) { messageMarshaller = new JsonMessageMarshaller(); } if (queue == null) { if (sqsClient == null) { if (awsClientConfiguration == null) { sqsClient = new AmazonSQSClient(awsCredentialsProvider); } else { sqsClient = new AmazonSQSClient(awsCredentialsProvider, awsClientConfiguration); } } if (regionId != null) { sqsClient.setEndpoint(String.format("sqs.%s.amazonaws.com", regionId)); } if (queueName != null) { createQueueIfNotExists(); } addPermissions(); } } private void createQueueIfNotExists() { for (String qUrl : sqsClient.listQueues().getQueueUrls()) { if (qUrl.contains(queueName)) { queueUrl = qUrl; break; } } if (queueUrl == null) { CreateQueueRequest request = new CreateQueueRequest(queueName); Map<String, String> queueAttributes = new HashMap<String, String>(); queueAttributes.put("ReceiveMessageWaitTimeSeconds", Integer .valueOf(receiveMessageWaitTimeout).toString()); if (messageDelay != null) { queueAttributes.put("DelaySeconds", messageDelay.toString()); } if (maximumMessageSize != null) { queueAttributes.put("MaximumMessageSize", maximumMessageSize.toString()); } if (messageRetentionPeriod != null) { queueAttributes.put("MessageRetentionPeriod", messageRetentionPeriod.toString()); } if (visibilityTimeout != null) { queueAttributes.put("VisibilityTimeout", visibilityTimeout.toString()); } request.setAttributes(queueAttributes); CreateQueueResult result = sqsClient.createQueue(request); queueUrl = result.getQueueUrl(); log.debug("New queue available at: " + queueUrl); } else { log.debug("Queue already exists: " + queueUrl); } resolveQueueArn(); } private void resolveQueueArn() { GetQueueAttributesRequest request = new GetQueueAttributesRequest( queueUrl); GetQueueAttributesResult result = sqsClient.getQueueAttributes(request .withAttributeNames(Collections.singletonList(QUEUE_ARN_KEY))); queueArn = result.getAttributes().get(QUEUE_ARN_KEY); } private void addPermissions() { if (permissions != null && permissions.isEmpty() == false) { GetQueueAttributesResult result = sqsClient .getQueueAttributes(new GetQueueAttributesRequest(queueUrl, Arrays.asList("Policy"))); AwsUtil.addPermissions(result.getAttributes(), permissions, new AwsUtil.AddPermissionHandler() { @Override public void execute(Permission p) { sqsClient.addPermission(new AddPermissionRequest() .withQueueUrl(queueUrl) .withLabel(p.getLabel()) .withAWSAccountIds(p.getAwsAccountIds()) .withActions(p.getActions())); } }); } } /** * Executes the outbound Sqs Operation. * */ public Object executeOutboundOperation(final Message<?> message) { try { String serializedMessage = messageMarshaller.serialize(message); if (queue == null) { SendMessageRequest request = new SendMessageRequest(queueUrl, serializedMessage); SendMessageResult result = sqsClient.sendMessage(request); log.debug("Message sent, Id:" + result.getMessageId()); } else { queue.add(serializedMessage); } } catch (MessageMarshallerException e) { log.error(e.getMessage(), e); throw new MessagingException(e.getMessage(), e.getCause()); } return message.getPayload(); } /** * Execute the Sqs operation. Delegates to {@link SqsExecutor#poll(Message)} * . */ public Message<?> poll() { return poll(0); } /** * Execute a retrieving (polling) Sqs operation. * * @param timeout * time to wait for a message to return. * * @return The payload object, which may be null. */ public Message<?> poll(long timeout) { Message<?> message = null; String payloadJSON = null; com.amazonaws.services.sqs.model.Message qMessage = null; int timeoutSeconds = (timeout > 0 ? ((int) (timeout / 1000)) : receiveMessageWaitTimeout); destroyWaitTime = timeoutSeconds; try { if (queue == null) { if (prefetchQueue.isEmpty()) { ReceiveMessageRequest request = new ReceiveMessageRequest( queueUrl).withWaitTimeSeconds(timeoutSeconds) .withMaxNumberOfMessages(prefetchCount) .withAttributeNames("All"); ReceiveMessageResult result = sqsClient .receiveMessage(request); for (com.amazonaws.services.sqs.model.Message sqsMessage : result .getMessages()) { prefetchQueue.offer(sqsMessage); } qMessage = prefetchQueue.poll(); } else { qMessage = prefetchQueue.remove(); } if (qMessage != null) { payloadJSON = qMessage.getBody(); // MD5 verification try { byte[] computedHash = Md5Utils .computeMD5Hash(payloadJSON.getBytes("UTF-8")); String hexDigest = new String( Hex.encodeHex(computedHash)); if (!hexDigest.equals(qMessage.getMD5OfBody())) { payloadJSON = null; // ignore this message log.warn("Dropped message due to MD5 checksum failure"); } } catch (Exception e) { log.warn( "Failed to verify MD5 checksum: " + e.getMessage(), e); } } } else { try { payloadJSON = queue.poll(timeoutSeconds, TimeUnit.SECONDS); } catch (InterruptedException e) { log.warn(e.getMessage(), e); } } if (payloadJSON != null) { JSONObject qMessageJSON = new JSONObject(payloadJSON); if (qMessageJSON.has(SNS_MESSAGE_KEY)) { // posted from SNS payloadJSON = qMessageJSON.getString(SNS_MESSAGE_KEY); // XXX: other SNS attributes? } Message<?> packet = null; try { packet = messageMarshaller.deserialize(payloadJSON); } catch (MessageMarshallerException marshallingException) { throw new MessagingException( marshallingException.getMessage(), marshallingException.getCause()); } MessageBuilder<?> builder = MessageBuilder.fromMessage(packet); if (qMessage != null) { builder.setHeader(SqsHeaders.MSG_RECEIPT_HANDLE, qMessage.getReceiptHandle()); builder.setHeader(SqsHeaders.AWS_MESSAGE_ID, qMessage.getMessageId()); for (Map.Entry<String, String> e : qMessage.getAttributes() .entrySet()) { if (e.getKey().equals("ApproximateReceiveCount")) { builder.setHeader(SqsHeaders.RECEIVE_COUNT, Integer.valueOf(e.getValue())); } else if (e.getKey().equals("SentTimestamp")) { builder.setHeader(SqsHeaders.SENT_AT, new Date(Long.valueOf(e.getValue()))); } else if (e.getKey().equals( "ApproximateFirstReceiveTimestamp")) { builder.setHeader(SqsHeaders.FIRST_RECEIVED_AT, new Date(Long.valueOf(e.getValue()))); } else if (e.getKey().equals("SenderId")) { builder.setHeader(SqsHeaders.SENDER_AWS_ID, e.getValue()); } else { builder.setHeader(e.getKey(), e.getValue()); } } } else { builder.setHeader(SqsHeaders.MSG_RECEIPT_HANDLE, ""); // to satisfy test conditions } message = builder.build(); } } catch (JSONException e) { log.warn(e.getMessage(), e); } finally { destroyWaitTime = 0; } return message; } public String acknowlegdeReceipt(Message<?> message) { String receiptHandle = (String) message.getHeaders().get( SqsHeaders.MSG_RECEIPT_HANDLE); if (sqsClient != null && receiptHandle != null && !receiptHandle.isEmpty()) { sqsClient.deleteMessage(new DeleteMessageRequest(queueUrl, receiptHandle)); } return receiptHandle; } public String getQueueArn() { if (queueArn == null) { resolveQueueArn(); } return queueArn; } public String getQueueUrl() { return queueUrl; } /** * Example property to illustrate usage of properties in Spring Integration * components. Replace with your own logic. * * @param queueName * Must not be null */ public void setQueueName(String queueName) { Assert.hasText(queueName, "queueName must be neither null nor empty"); this.queueName = queueName; } /** * Set the queue implementation. Useful for testing the queue without * actually invoking AWS. * * @param queue */ public void setQueue(BlockingQueue<String> queue) { this.queue = queue; } /** * Sets the AWS client configuration. * * @param awsClientConfiguration */ public void setAwsClientConfiguration( ClientConfiguration awsClientConfiguration) { this.awsClientConfiguration = awsClientConfiguration; } /** * Sets the AWS credentials provider. * * @param awsCredentialsProvider */ public void setAwsCredentialsProvider( AWSCredentialsProvider awsCredentialsProvider) { this.awsCredentialsProvider = awsCredentialsProvider; } public int getReceiveMessageWaitTimeout() { return receiveMessageWaitTimeout; } /** * Sets the timeout (in seconds) for a receive message operation, defaults * to {@value #DEFAULT_RECV_MESG_WAIT} seconds. * * @param receiveMessageWaitTimeout */ public void setReceiveMessageWaitTimeout(int receiveMessageWaitTimeout) { Assert.isTrue(receiveMessageWaitTimeout >= 0 && receiveMessageWaitTimeout <= 20, "'receiveMessageWaitTimeout' must be an integer from 0 to 20 (seconds)."); this.receiveMessageWaitTimeout = receiveMessageWaitTimeout; } /** * Sets the AWS region ID, defaults to us-east. * * @param regionId */ public void setRegionId(String regionId) { this.regionId = regionId; } /** * Sets the number of messages to prefetch, defaults to * {@value #DEFAULT_MESSAGE_PREFETCH_COUNT}. * * @param prefetchCount */ public void setPrefetchCount(int prefetchCount) { Assert.isTrue(prefetchCount >= 0 && prefetchCount <= 10, "'prefetchCount' must be an integer from 0 to 10."); this.prefetchCount = prefetchCount; } /** * Sets the message delivery delay from SQS. By default there is no delay. * * @param messageDelay */ public void setMessageDelay(Integer messageDelay) { Assert.isTrue(messageDelay >= 0 && messageDelay <= 900, "'messageDelay' must be an integer from 0 to 900 (15 minutes)."); this.messageDelay = messageDelay; } /** * Sets the maximum message size. * * @param maximumMessageSize */ public void setMaximumMessageSize(Integer maximumMessageSize) { Assert.isTrue( maximumMessageSize >= 1024 && maximumMessageSize <= 65536, "'maximumMessageSize' must be an integer from 1024 bytes (1 KiB) up to 65536 bytes (64 KiB)."); this.maximumMessageSize = maximumMessageSize; } /** * Sets the message retention period at SQS. Messages older than this will * be automatically be dropped by SQS. * * @param messageRetentionPeriod */ public void setMessageRetentionPeriod(Integer messageRetentionPeriod) { Assert.isTrue( messageRetentionPeriod >= 60 && messageRetentionPeriod <= 1209600, "'messageRetentionPeriod' must be an integer representing seconds, from 60 (1 minute) to 1209600 (14 days)"); this.messageRetentionPeriod = messageRetentionPeriod; } /** * Sets the visibility timeout in seconds. SQS must receive an * acknowledgment before this timeout occurs or else the message is * re-delivered. * * @param visibilityTimeout */ public void setVisibilityTimeout(Integer visibilityTimeout) { Assert.isTrue( visibilityTimeout >= 0 && visibilityTimeout <= 43200, "'visibilityTimeout' must be an integer representing seconds, from 0 to 43200 (12 hours)"); this.visibilityTimeout = visibilityTimeout; } @Override public void destroy() throws Exception { if (sqsClient != null) { if (destroyWaitTime > 0) { Thread.sleep(destroyWaitTime * 1000); } sqsClient.shutdown(); } } public void addSnsPublishPolicy(String topicName, String topicArn) { if (queueArn == null) { resolveQueueArn(); } String publishPolicyKey = String.format("SNS-%s-SQS-%s", topicName, queueName); String policyId = null; GetQueueAttributesRequest getAttrRequest = new GetQueueAttributesRequest( queueUrl); getAttrRequest.setAttributeNames(Collections.singletonList("Policy")); GetQueueAttributesResult result = sqsClient .getQueueAttributes(getAttrRequest); Map<String, String> attributes = result.getAttributes(); String policyStr = attributes.get("Policy"); log.debug("Policy:" + policyStr); if (policyStr != null) { try { JSONObject policyJSON = new JSONObject(policyStr); policyId = policyJSON.getString("Id"); } catch (JSONException e) { log.error(e.getMessage(), e); } } if (policyId == null || !policyId.equals(publishPolicyKey)) { Statement statement = new Statement(Effect.Allow); statement .withActions(SQSActions.SendMessage) .withPrincipals(Principal.AllUsers) .withResources(new Resource(queueArn)) .withConditions( new ArnCondition(ArnComparisonType.ArnEquals, ConditionFactory.SOURCE_ARN_CONDITION_KEY, topicArn)); Policy policy = new Policy(); policy.setId(publishPolicyKey); policy.setStatements(Collections.singletonList(statement)); SetQueueAttributesRequest request = new SetQueueAttributesRequest(); request.setQueueUrl(queueUrl); String policyJSON = (new JsonPolicyWriter()) .writePolicyToString(policy); log.debug(policyJSON); request.setAttributes(Collections .singletonMap("Policy", policyJSON)); sqsClient.setQueueAttributes(request); } } public void setSqsClient(AmazonSQS sqsClient) { this.sqsClient = sqsClient; } public void setMessageMarshaller(MessageMarshaller messageMarshaller) { this.messageMarshaller = messageMarshaller; } public void setPermissions(Set<Permission> permissions) { this.permissions = permissions; } public void setQueueUrl(String queueUrl) { this.queueUrl = queueUrl; } }