/**
 * Copyright (C) 2016-2019 Expedia, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hotels.road.client.partitioning;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.stream.Collectors.toList;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.function.BiConsumer;
import java.util.function.Function;

import lombok.extern.slf4j.Slf4j;

/**
 * A Function that, given a delegate that takes collections of messages and returns a collection of responses, will
 * take individual messages and asynchronously return responses as {@link CompletableFuture}s. Passed messages will be
 * passed to the delegate in batches containing all the messages that arrived whilst the last batch was being executed.
 * If the processing time per message is on aggregate faster in larger batches then this function will improve
 * efficiency. There is a maximum batch size to prevent the back pressure causing the batch size to grow
 * uncontrollably. What happens when the internal buffer is full is decided by the supplied {@link EnqueueBehaviour}.
 * <p>
 * Batches are executed on a single dedicated worker thread that is started by the constructor. {@link #close()} asks
 * that thread to stop, waits for it to finish, and any message that was already queued is either handed to the
 * delegate in a final flush or has its future completed exceptionally, so that no accepted future is left pending.
 *
 * @param <MESSAGE> The Function's from type
 * @param <RESPONSE> The Function's to type
 */
@Slf4j
class MessageBatcher<MESSAGE, RESPONSE> implements CloseableFunction<MESSAGE, CompletableFuture<RESPONSE>> {
  /** How long the worker waits for a first message before re-checking the shutdown flag. */
  private static final long POLL_TIMEOUT_MS = 100;

  private final BlockingQueue<MessageWithCallback<MESSAGE, RESPONSE>> queue;
  private final int maxBatchSize;
  private final Function<List<MESSAGE>, List<RESPONSE>> batchHandler;
  private final EnqueueBehaviour enqueueBehaviour;

  // Written by close() on the caller's thread and read by the worker thread; volatile guarantees the worker
  // actually observes the shutdown request.
  private volatile boolean shutdownFlag = false;
  private final Thread thread;

  /**
   * Creates a batcher whose worker thread is named after the batch handler.
   *
   * @param bufferSize capacity of the internal queue; must be at least {@code maxBatchSize}
   * @param maxBatchSize largest number of messages passed to {@code batchHandler} in one call
   * @param enqueueBehaviour strategy used to place a message on the internal queue
   * @param batchHandler delegate that processes a batch and returns one response per message, in order
   * @throws IllegalArgumentException if {@code bufferSize} is smaller than {@code maxBatchSize}
   */
  public MessageBatcher(
      int bufferSize,
      int maxBatchSize,
      EnqueueBehaviour enqueueBehaviour,
      Function<List<MESSAGE>, List<RESPONSE>> batchHandler) {
    this(bufferSize, maxBatchSize, enqueueBehaviour, batchHandler,
        r -> new Thread(r, "batcher-" + batchHandler.toString()));
  }

  /**
   * Creates a batcher whose worker thread is obtained from {@code threadFactory}. The worker thread is started before
   * this constructor returns.
   *
   * @param bufferSize capacity of the internal queue; must be at least {@code maxBatchSize}
   * @param maxBatchSize largest number of messages passed to {@code batchHandler} in one call
   * @param enqueueBehaviour strategy used to place a message on the internal queue
   * @param batchHandler delegate that processes a batch and returns one response per message, in order
   * @param threadFactory factory used to create the worker thread
   * @throws IllegalArgumentException if {@code bufferSize} is smaller than {@code maxBatchSize}
   */
  MessageBatcher(
      int bufferSize,
      int maxBatchSize,
      EnqueueBehaviour enqueueBehaviour,
      Function<List<MESSAGE>, List<RESPONSE>> batchHandler,
      ThreadFactory threadFactory) {
    if (bufferSize < maxBatchSize) {
      throw new IllegalArgumentException("maxBatchSize must be less than or equal to bufferSize");
    }
    this.queue = new LinkedBlockingQueue<>(bufferSize);
    this.maxBatchSize = maxBatchSize;
    this.batchHandler = batchHandler;
    this.enqueueBehaviour = enqueueBehaviour;

    // All fields the worker reads are assigned above; Thread.start() publishes them to the new thread.
    thread = threadFactory.newThread(this::processMessages);
    thread.start();
  }

  /**
   * Queues a message for the next batch.
   *
   * @param t the message to send
   * @return a future completed with the delegate's response for this message, or completed exceptionally if the
   *         message could not be queued, the batcher has been closed, or the delegate failed
   */
  @Override
  public CompletableFuture<RESPONSE> apply(MESSAGE t) {
    try {
      if (shutdownFlag) {
        // Best-effort guard: the worker is stopping, so a newly queued message might never be picked up.
        throw new IllegalStateException("MessageBatcher has been closed");
      }
      return enqueueBehaviour.enqueueMessage(queue, t);
    } catch (Exception e) {
      if (e instanceof InterruptedException) {
        // Do not swallow the interrupt; the caller's thread should still see it.
        Thread.currentThread().interrupt();
      }
      CompletableFuture<RESPONSE> callback = new CompletableFuture<>();
      callback.completeExceptionally(e);
      return callback;
    }
  }

  /**
   * Requests the worker thread to stop and waits for it to terminate. Messages that are already queued are flushed to
   * the delegate before the worker exits.
   *
   * @throws InterruptedException if the calling thread is interrupted while waiting for the worker
   */
  @Override
  public void close() throws InterruptedException {
    shutdownFlag = true;
    thread.join();
  }

  /**
   * Worker loop: waits for a message, gathers up to {@code maxBatchSize} messages and hands them to the delegate,
   * until shutdown is requested or the thread is interrupted.
   */
  private void processMessages() {
    try {
      while (!shutdownFlag) {
        MessageWithCallback<MESSAGE, RESPONSE> first = queue.poll(POLL_TIMEOUT_MS, MILLISECONDS);
        if (first != null) {
          List<MessageWithCallback<MESSAGE, RESPONSE>> batch = new ArrayList<>();
          batch.add(first);
          queue.drainTo(batch, maxBatchSize - 1);
          handleBatch(batch);
        }
      }
      // Orderly shutdown: deliver what was accepted before close() was called.
      flushQueue();
    } catch (InterruptedException e) {
      // Restore the flag for whoever owns this thread and stop quietly instead of dying with a wrapped exception.
      Thread.currentThread().interrupt();
      log.warn("Message sending thread interrupted. Thread shutting down", e);
    } catch (RuntimeException e) {
      log.warn("Unhandled exception in message sending thread. Thread shutting down", e);
      throw e;
    } finally {
      // Whatever is still queued will never be processed by this thread; fail it rather than leave futures pending.
      failQueued();
    }
  }

  /**
   * Takes a snapshot of everything currently queued and passes it to the delegate in batches of at most
   * {@code maxBatchSize}. Working from a snapshot keeps the flush bounded even if producers are still enqueuing.
   */
  private void flushQueue() {
    List<MessageWithCallback<MESSAGE, RESPONSE>> remaining = new ArrayList<>();
    queue.drainTo(remaining);
    int step = Math.max(maxBatchSize, 1);
    int from = 0;
    while (from < remaining.size()) {
      int to = from + Math.min(step, remaining.size() - from);
      handleBatch(remaining.subList(from, to));
      from = to;
    }
  }

  /** Completes exceptionally every message that is still sitting in the queue. */
  private void failQueued() {
    List<MessageWithCallback<MESSAGE, RESPONSE>> abandoned = new ArrayList<>();
    queue.drainTo(abandoned);
    if (!abandoned.isEmpty()) {
      failBatch(abandoned, new IllegalStateException("MessageBatcher shut down before the message was processed"));
    }
  }

  /**
   * Passes one batch to the delegate and completes each message's callback with the response at the same position.
   * If the delegate throws, every callback in the batch is completed exceptionally; if it returns fewer responses
   * than messages, the callbacks without a response are completed exceptionally.
   */
  private void handleBatch(List<MessageWithCallback<MESSAGE, RESPONSE>> batch) {
    try {
      List<MESSAGE> messages = batch.stream().map(MessageWithCallback::getMessage).collect(toList());
      List<RESPONSE> responses = batchHandler.apply(messages);
      zipConsume(batch, responses, (message, response) -> message.getCallback().complete(response));
      if (responses.size() < batch.size()) {
        // zipConsume stops at the shorter side; without this the trailing futures would never complete.
        failBatch(
            batch.subList(responses.size(), batch.size()),
            new IllegalStateException("Batch handler returned " + responses.size() + " responses for "
                + batch.size() + " messages"));
      }
    } catch (Exception e) {
      failBatch(batch, e);
    }
  }

  /** Walks both iterables in lock-step, invoking {@code consumer} on each pair until either side is exhausted. */
  private <A, B> void zipConsume(Iterable<A> a, Iterable<B> b, BiConsumer<A, B> consumer) {
    Iterator<A> iterA = a.iterator();
    Iterator<B> iterB = b.iterator();
    while (iterA.hasNext() && iterB.hasNext()) {
      consumer.accept(iterA.next(), iterB.next());
    }
  }

  /** Completes the callback of every message in {@code batch} exceptionally with {@code e}. */
  private void failBatch(List<MessageWithCallback<MESSAGE, RESPONSE>> batch, Exception e) {
    for (MessageWithCallback<MESSAGE, RESPONSE> message : batch) {
      message.getCallback().completeExceptionally(e);
    }
  }
}