package com.amazonaws.services.sqs;

import static com.amazonaws.services.sqs.util.SQSQueueUtils.booleanMessageAttributeValue;
import static com.amazonaws.services.sqs.util.SQSQueueUtils.getBooleanMessageAttributeValue;
import static com.amazonaws.services.sqs.util.SQSQueueUtils.getStringMessageAttributeValue;
import static com.amazonaws.services.sqs.util.SQSQueueUtils.stringMessageAttributeValue;
import static com.amazonaws.util.StringUtils.UTF8;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.AbstractExecutorService;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

import com.amazonaws.services.sqs.executors.Base64Serializer;
import com.amazonaws.services.sqs.executors.CompletedFutureSerializer;
import com.amazonaws.services.sqs.executors.Deduplicated;
import com.amazonaws.services.sqs.executors.DefaultSerializer;
import com.amazonaws.services.sqs.executors.InvertibleFunction;
import com.amazonaws.services.sqs.executors.SerializableRunnable;
import com.amazonaws.services.sqs.model.Message;
import com.amazonaws.services.sqs.model.SendMessageRequest;
import com.amazonaws.services.sqs.util.SQSMessageConsumer;
import com.amazonaws.services.sqs.util.SQSMessageConsumerBuilder;
import com.amazonaws.services.sqs.util.SQSQueueUtils;
import com.amazonaws.util.BinaryUtils;
import com.amazonaws.util.Md5Utils;

// TODO-RS: Factor out deduplication implementation?
class SQSExecutorService extends AbstractExecutorService {

    // TODO-RS: Configuration
    private static final int MAX_WAIT_TIME_SECONDS = 60;

    private static final long DEDUPLICATION_WINDOW_MILLIS = TimeUnit.MILLISECONDS.convert(20, TimeUnit.MINUTES);

    protected final InvertibleFunction<Object, String> serializer;

    protected final AmazonSQS sqs;
    protected final AmazonSQSRequester sqsRequester;
    protected final AmazonSQSResponder sqsResponder;
    protected final String queueUrl;
    private final SQSMessageConsumer messageConsumer;

    private final ScheduledExecutorService dedupedResultPoller = Executors.newScheduledThreadPool(1);

    private final AtomicBoolean shuttingDown = new AtomicBoolean(false);

    public SQSExecutorService(AmazonSQSRequester sqsRequester, AmazonSQSResponder sqsResponder, String queueUrl, Consumer<Exception> exceptionHandler) {
        this(sqsRequester, sqsResponder, queueUrl, DefaultSerializer.INSTANCE.andThen(Base64Serializer.INSTANCE), exceptionHandler);
    }

    public SQSExecutorService(AmazonSQSRequester sqsRequester, AmazonSQSResponder sqsResponder, String queueUrl, InvertibleFunction<Object, String> serializer, Consumer<Exception> exceptionHandler) {
        this.sqs = sqsRequester.getAmazonSQS();
        this.sqsRequester = sqsRequester;
        this.sqsResponder = sqsResponder;
        this.queueUrl = queueUrl;
        this.messageConsumer = SQSMessageConsumerBuilder.standard()
                                                        .withAmazonSQS(this.sqs)
                                                        .withQueueUrl(queueUrl)
                                                        .withConsumer(this::accept)
                                                        .withExceptionHandler(exceptionHandler)
                                                        .build();
        this.messageConsumer.start();
        this.serializer = serializer;
    }

    public String getQueueUrl() {
        return queueUrl;
    }
    
    @Override
    public void execute(Runnable runnable) {
        if (isShutdown()) {
            throw new RejectedExecutionException("Task " + runnable.toString() +
                    " rejected from " + this.toString());
        }
        convert(runnable).send();
    }

    public void execute(SerializableRunnable runnable) {
        execute((Runnable)runnable);
    }

    // TODO-RS: Local repeating task to remove these when expired?
    private static class Metadata {
        private final Optional<String> deduplicationID;
        private String uuid;
        private final long expiry;
        private Optional<String> serializedResult;

        public Metadata(Optional<String> deduplicationID, String uuid) {
            // TODO-RS: Clock drift!
            this(deduplicationID, uuid, System.currentTimeMillis() + DEDUPLICATION_WINDOW_MILLIS, Optional.empty());
        }

        private Metadata(Optional<String> deduplicationID, String uuid, long expiry, Optional<String> serializedResult) {
            this.deduplicationID = deduplicationID;
            this.uuid = uuid;
            this.expiry = expiry;
            this.serializedResult = serializedResult;
        }

        public static Metadata fromTag(String serialized) {
            String[] parts = serialized.split(":");
            // TODO-RS: Clean up serialization. Tags can't have arbitrary characters and can't be too large!
            return new Metadata(Optional.of(parts[0]).filter(s -> !"null".equals(s)), 
                                parts[1], 
                                Long.parseLong(parts[2]),
                                Optional.of(parts[3]).filter(s -> !"null".equals(s)));
        }

        public static Metadata fromMessageContent(MessageContent messageContent) {
            String uuid = getStringMessageAttributeValue(messageContent.getMessageAttributes(), SQSFutureTask.UUID_ATTRIBUTE_NAME).get();
            Optional<String> deduplicationID = getStringMessageAttributeValue(messageContent.getMessageAttributes(), SQSFutureTask.DEDUPLICATION_ID_ATTRIBUTE_NAME);
            return new Metadata(deduplicationID, uuid);
        }

        public boolean shouldNotRun(SQSFutureTask<?> task) {
            if (isExpired()) {
                return false;
            } else if (isDuplicate(task)) {
                return true;
            } else if (serializedResult.isPresent()) {
                return true;
            } else {
                return false;
            }
        }

        public boolean isDuplicate(SQSFutureTask<?> task) {
            return deduplicationID.isPresent() && !uuid.equals(task.metadata.uuid);
        }

        public boolean isExpired() {
            // TODO-RS: Leverage SentTimestamp and ApproximateFirstReceiveTimestamp
            // to fight clock drift.
            return System.currentTimeMillis() > expiry;
        }

        public void saveToTag(AmazonSQS sqs, String queueUrl) {
            String key = deduplicationID.orElse(uuid);
            sqs.tagQueue(queueUrl, Collections.singletonMap(key, toString()));
        }

        @Override
        public String toString() {
            // TODO-RS: clean up serialization
            StringBuilder builder = new StringBuilder();
            builder.append(deduplicationID.orElse("null"));
            builder.append(':');
            builder.append(uuid);
            builder.append(':');
            builder.append(expiry);
            builder.append(':');
            builder.append(serializedResult.orElse("null"));
            return builder.toString();
        }
    }

    @Override
    protected <T> RunnableFuture<T> newTaskFor(Callable<T> callable) {
        MessageContent messageContent = toMessageContent(callable);
        addDeduplicationAttributes(messageContent, callable);
        return new SQSFutureTask<>(callable, messageContent, true);
    }

    @Override
    protected <T> RunnableFuture<T> newTaskFor(Runnable runnable, T value) {
        return newTaskFor(runnable, value, true);
    }

    private <T> SQSFutureTask<T> newTaskFor(Runnable runnable, T value, boolean withResponse) {
        MessageContent messageContent = toMessageContent(runnable);
        addDeduplicationAttributes(messageContent, runnable);
        return new SQSFutureTask<>(Executors.callable(runnable, value), messageContent, withResponse);
    }

    protected void accept(Message message) {
        deserializeTask(message).run();
    }

    protected SQSFutureTask<?> deserializeTask(Message message) {
        return new SQSFutureTask<>(message);
    }

    public SQSFutureTask<?> convert(Runnable runnable) {
        if (runnable instanceof SQSFutureTask<?>) {
            return (SQSFutureTask<?>)runnable;
        } else {
            return newTaskFor(runnable, null, false);
        }
    }

    protected MessageContent toMessageContent(Runnable runnable) {
        Objects.requireNonNull(runnable);

        MessageContent messageContent = new MessageContent(serializer.apply(runnable));
        addDeduplicationAttributes(messageContent, runnable);
        return messageContent;
    }

    protected MessageContent toMessageContent(Callable<?> callable) {
        Objects.requireNonNull(callable);

        MessageContent messageContent = new MessageContent(serializer.apply(callable));
        messageContent.setMessageAttributesEntry(SQSFutureTask.IS_CALLABLE_ATTRIBUTE_NAME, booleanMessageAttributeValue(true));
        addDeduplicationAttributes(messageContent, callable);
        return messageContent;
    }

    private void addDeduplicationAttributes(MessageContent messageContent, Object task) {
        if (task instanceof Deduplicated) {
            String deduplicationID = ((Deduplicated)task).deduplicationID();
            if (deduplicationID == null) {
                String body = messageContent.getMessageBody();
                deduplicationID = BinaryUtils.toHex(Md5Utils.computeMD5Hash(body.getBytes(UTF8)));
            }
            messageContent.setMessageAttributesEntry(SQSFutureTask.DEDUPLICATION_ID_ATTRIBUTE_NAME,
                    stringMessageAttributeValue(deduplicationID));
        }
        messageContent.setMessageAttributesEntry(SQSFutureTask.UUID_ATTRIBUTE_NAME,
                stringMessageAttributeValue(UUID.randomUUID().toString()));
    }

    @SuppressWarnings("unchecked")
    private <T> Callable<T> callableFromMessage(Message message) {
        Object deserialized = serializer.unapply(message.getBody());
        boolean isCallable = getBooleanMessageAttributeValue(message.getMessageAttributes(), SQSFutureTask.IS_CALLABLE_ATTRIBUTE_NAME);
        if (isCallable) {
            return (Callable<T>)deserialized;
        } else {
            return Executors.callable((Runnable)deserialized, null);
        }
    }

    // TODO-RS: Make this static?
    protected class SQSFutureTask<T> extends FutureTask<T> {

        private static final String DEDUPLICATION_ID_ATTRIBUTE_NAME = "DeduplicationID";
        private static final String UUID_ATTRIBUTE_NAME = "UUID";
        private static final String IS_CALLABLE_ATTRIBUTE_NAME = "IsCallable";

        private final Metadata metadata;

        private final boolean withResponse;
        protected final InvertibleFunction<Future<T>, String> futureSerializer = new CompletedFutureSerializer<>(serializer);

        protected final MessageContent messageContent;

        // TODO-RS: The result will come either from a response message or
        // polling the deduplication metadata on the tags.
        // Is there a good way to have the same thread pool do one or the other?
        private Optional<Future<?>> resultFuture;

        public SQSFutureTask(Callable<T> callable, MessageContent messageContent, boolean withResponse) {
            super(callable);
            this.messageContent = messageContent;
            this.withResponse = withResponse;
            this.metadata = Metadata.fromMessageContent(messageContent);
            this.resultFuture = Optional.empty();
        }

        public SQSFutureTask(Message message) {
            super(callableFromMessage(message));
            this.messageContent = MessageContent.fromMessage(message);
            this.withResponse = false;
            this.metadata = Metadata.fromMessageContent(messageContent);
            this.resultFuture = Optional.empty();
        }

        private Optional<Metadata> getMetadataFromTags() {
            Map<String, String> tags = sqs.listQueueTags(queueUrl).getTags();
            return Optional.ofNullable(tags.get(metadata.deduplicationID.orElse(metadata.uuid)))
                    .map(Metadata::fromTag);
        }

        protected void send() {
            if (getMetadataFromTags().filter(existingMetadata -> existingMetadata.shouldNotRun(this)).isPresent()) {
                if (withResponse) {
                    // This will immediately complete the future and cancel itself if the metadata
                    // already has the result set.
                    resultFuture = Optional.of(dedupedResultPoller.scheduleWithFixedDelay(
                            this::pollForResultFromMetadata, 0, 2, TimeUnit.SECONDS));
                }
                return;
            }

            SendMessageRequest request = toSendMessageRequest();

            if (withResponse) {
                CompletableFuture<Message> responseFuture = sqsRequester.sendMessageAndGetResponseAsync(
                        request, MAX_WAIT_TIME_SECONDS, TimeUnit.SECONDS);
                responseFuture.whenComplete((result, exception) -> {
                    if (exception != null) {
                        setException(exception);
                    } else {
                        setFromResponse(result.getBody());
                    }
                });
                this.resultFuture = Optional.of(responseFuture);
            } else {
                sqs.sendMessage(request);
            }

            // Tag afterwards, so that the race condition will result in duplicate receives rather than
            // potentially deduping all copies.
            if (metadata.deduplicationID.isPresent()) {
                metadata.saveToTag(sqs, queueUrl);
            }
        }

        public SendMessageRequest toSendMessageRequest() {
            return messageContent.toSendMessageRequest().withQueueUrl(queueUrl);
        }

        private void pollForResultFromMetadata() {
            Optional<Metadata> tagMetadata = getMetadataFromTags();
            if (tagMetadata.isPresent()) {
                tagMetadata.get().serializedResult.ifPresent(this::setFromResponse);
            } else { 
                setException(new TimeoutException());
            }
        }

        private void setFromResponse(String serializedFuture) {
            Future<T> future = futureSerializer.unapply(serializedFuture);
            try {
                set(future.get());
            } catch (InterruptedException e) {
                // Shouldn't happen
                throw new IllegalStateException(e);
            } catch (CancellationException e) {
                cancel(false);
            } catch (ExecutionException e) {
                setException(e.getCause());
            }
        }

        @Override
        public void run() {
            Optional<Metadata> maybeMetadata = getMetadataFromTags();
            if (maybeMetadata.filter(existingMetadata -> existingMetadata.shouldNotRun(this)).isPresent()) {
                maybeMetadata.get().serializedResult.ifPresent(this::setFromResponse);
                return;
            }

            super.run();
        }

        @Override
        protected void done() {
            resultFuture.ifPresent(f -> f.cancel(false));

            String response = futureSerializer.apply(this);

            if (sqsResponder.isResponseMessageRequested(messageContent)) {
                sqsResponder.sendResponseMessage(messageContent, new MessageContent(response));
            }

            if (metadata.deduplicationID.isPresent() || isCancelled()) {
                metadata.serializedResult = Optional.of(response);
                metadata.saveToTag(sqs, queueUrl);
            }
        }
    }

    @Override
    public void shutdown() {
        messageConsumer.shutdown();
        shuttingDown.set(true);
    }

    @Override
    public List<Runnable> shutdownNow() {
        shutdown();
        return Collections.emptyList();
    }

    @Override
    public boolean isShutdown() {
        return shuttingDown.get();
    }

    @Override
    public boolean isTerminated() {
        return isShutdown() && SQSQueueUtils.isQueueEmpty(sqs, queueUrl);
    }

    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return SQSQueueUtils.awaitEmptyQueue(sqs, queueUrl, timeout, unit);
    }
}