/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.dag.app.rm;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.hadoop.service.AbstractService;
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
import org.apache.hadoop.yarn.api.records.ApplicationAccessType;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.client.api.AMRMClient;
import org.apache.hadoop.yarn.client.api.impl.AMRMClientImpl;
import org.apache.hadoop.yarn.event.Event;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.app.rm.YarnTaskSchedulerService.CookieContainerRequest;
import org.apache.tez.dag.app.rm.TaskSchedulerService.TaskSchedulerAppCallback;
import org.apache.tez.dag.app.rm.container.ContainerSignatureMatcher;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;

class TestTaskSchedulerHelpers {

  /**
   * Test subclass of {@link AMRMClientImpl} used to make use of its
   * getMatchingRequest bookkeeping. The service start/stop hooks are
   * overridden as no-ops and do not call the parent implementations.
   */
  static class AMRMClientForTest extends AMRMClientImpl<CookieContainerRequest> {

    /** Intentionally empty; the parent's start logic is skipped. */
    @Override
    protected void serviceStart() {
      // no-op
    }

    /** Intentionally empty; the parent's stop logic is skipped. */
    @Override
    protected void serviceStop() {
      // no-op
    }
  }


  /**
   * Stub subclass of {@link TezAMRMClientAsync} used to make use of
   * getMatchingRequest. Registration returns a Mockito mock response,
   * unregistration is ignored, and the service start/stop hooks are no-ops that
   * skip the parent implementations.
   */
  static class AMRMClientAsyncForTest extends
      TezAMRMClientAsync<CookieContainerRequest> {

    /**
     * @param client synchronous client handed to the parent
     * @param intervalMs interval in milliseconds, passed straight to the parent
     */
    public AMRMClientAsyncForTest(
        AMRMClient<CookieContainerRequest> client,
        int intervalMs) {
      // No CallbackHandler is supplied: the test drives the callbacks itself.
      super(client, intervalMs, null);
    }

    /**
     * Ignores its arguments and returns a mocked response whose maximum
     * resource capability and application ACL map are themselves mocks.
     */
    @SuppressWarnings("unchecked")
    @Override
    public RegisterApplicationMasterResponse registerApplicationMaster(
        String appHostName, int appHostPort, String appTrackingUrl) {
      Map<ApplicationAccessType, String> acls = mock(Map.class);
      Resource maxCapability = mock(Resource.class);
      RegisterApplicationMasterResponse response =
          mock(RegisterApplicationMasterResponse.class);
      when(response.getApplicationACLs()).thenReturn(acls);
      when(response.getMaximumResourceCapability()).thenReturn(maxCapability);
      return response;
    }

    /** Ignores the unregistration request entirely. */
    @Override
    public void unregisterApplicationMaster(FinalApplicationStatus appStatus,
        String appMessage, String appTrackingUrl) {
      // no-op
    }

    @Override
    protected void serviceStart() {
      // no-op: the parent's start logic is skipped
    }

    @Override
    protected void serviceStop() {
      // no-op: the parent's stop logic is skipped
    }
  }
  
  /**
   * {@link TaskSchedulerEventHandler} whose start/stop are overridden so it is
   * controlled without the extra event-handling thread. On start it creates a
   * {@link TaskSchedulerWithDrainableAppCallback}, inits and starts it, and
   * stores a Mockito spy of it as the active task scheduler.
   */
  static class TaskSchedulerEventHandlerForTest extends
      TaskSchedulerEventHandler {

    private final TezAMRMClientAsync<CookieContainerRequest> amrmClientAsync;
    private final ContainerSignatureMatcher containerSignatureMatcher;

    @SuppressWarnings("rawtypes")
    public TaskSchedulerEventHandlerForTest(AppContext appContext,
        EventHandler eventHandler,
        TezAMRMClientAsync<CookieContainerRequest> amrmClientAsync,
        ContainerSignatureMatcher containerSignatureMatcher) {
      super(appContext, null, eventHandler, containerSignatureMatcher);
      this.amrmClientAsync = amrmClientAsync;
      this.containerSignatureMatcher = containerSignatureMatcher;
    }

    /** Builds a drainable scheduler backed by the injected AMRM client. */
    @Override
    public TaskSchedulerService createTaskScheduler(String host, int port,
        String trackingUrl, AppContext appContext) {
      TaskSchedulerService scheduler = new TaskSchedulerWithDrainableAppCallback(
          this, containerSignatureMatcher, host, port, trackingUrl,
          amrmClientAsync, appContext);
      return scheduler;
    }

    /** @return the spy installed by {@link #serviceStart()} (null before start) */
    public TaskSchedulerService getSpyTaskScheduler() {
      return taskScheduler;
    }

    @Override
    public void serviceStart() {
      TaskSchedulerService realScheduler =
          createTaskScheduler("host", 0, "", appContext);
      AbstractService schedulerService = (AbstractService) realScheduler;
      // init() before start() so the reuse configuration is picked up.
      schedulerService.init(getConfig());
      schedulerService.start();
      taskScheduler = spy(realScheduler);
    }

    @Override
    public void serviceStop() {
      // no-op: nothing was started by this handler itself
    }
  }

  @SuppressWarnings("rawtypes")
  static class CapturingEventHandler implements EventHandler {

    private List<Event> events = new LinkedList<Event>();


    public void handle(Event event) {
      events.add(event);
    }

    public void reset() {
      events.clear();
    }

    public void verifyNoInvocations(Class<? extends Event> eventClass) {
      for (Event e : events) {
        assertFalse(e.getClass().getName().equals(eventClass.getName()));
      }
    }

    public Event verifyInvocation(Class<? extends Event> eventClass) {
      for (Event e : events) {
        if (e.getClass().getName().equals(eventClass.getName())) {
          return e;
        }
      }
      fail("Expected Event: " + eventClass.getName() + " not sent");
      return null;
    }
  }

  /**
   * {@link YarnTaskSchedulerService} whose app-callback executor is wrapped in
   * a {@link CountingExecutorService} and whose app-callback delegate is a
   * {@link TaskSchedulerAppCallbackDrainable}. Tests can then call
   * {@code getDrainableAppCallback().drain()} to wait for submitted callbacks.
   * Both constructors set {@code shouldUnregister} to true.
   */
  static class TaskSchedulerWithDrainableAppCallback extends YarnTaskSchedulerService {

    // NOTE(review): assigned from createAppCallbackDelegate, which may be
    // invoked during superclass construction; keep this field without an
    // initializer so such an assignment is not overwritten.
    private TaskSchedulerAppCallbackDrainable drainableAppCallback;

    public TaskSchedulerWithDrainableAppCallback(
        TaskSchedulerAppCallback appClient,
        ContainerSignatureMatcher containerSignatureMatcher,
        String appHostName, int appHostPort, String appTrackingUrl,
        AppContext appContext) {
      super(appClient, containerSignatureMatcher, appHostName, appHostPort,
          appTrackingUrl, appContext);
      shouldUnregister.set(true);
    }

    public TaskSchedulerWithDrainableAppCallback(
        TaskSchedulerAppCallback appClient,
        ContainerSignatureMatcher containerSignatureMatcher,
        String appHostName, int appHostPort, String appTrackingUrl,
        TezAMRMClientAsync<CookieContainerRequest> client,
        AppContext appContext) {
      super(appClient, containerSignatureMatcher, appHostName, appHostPort,
          appTrackingUrl, client, appContext);
      shouldUnregister.set(true);
    }

    /** Wraps the real callback so every invocation is counted for draining. */
    @Override
    TaskSchedulerAppCallback createAppCallbackDelegate(
        TaskSchedulerAppCallback realAppClient) {
      TaskSchedulerAppCallbackWrapper wrapper =
          new TaskSchedulerAppCallbackWrapper(realAppClient, appCallbackExecutor);
      drainableAppCallback = new TaskSchedulerAppCallbackDrainable(wrapper);
      return drainableAppCallback;
    }

    /** Decorates the parent's executor so completed tasks can be polled. */
    @Override
    ExecutorService createAppCallbackExecutorService() {
      return new CountingExecutorService(super.createAppCallbackExecutorService());
    }

    public TaskSchedulerAppCallbackDrainable getDrainableAppCallback() {
      return drainableAppCallback;
    }
  }

  /**
   * {@link TaskSchedulerAppCallback} decorator that counts every callback it
   * forwards to the wrapped {@link TaskSchedulerAppCallbackWrapper}. It lets
   * {@link #drain()} wait until the same number of tasks have completed on the
   * wrapper's {@link CountingExecutorService}.
   *
   * <p>Callbacks may be invoked on threads other than the one calling
   * {@link #drain()}; the {@link AtomicInteger} {@code count} exists for that
   * reason. {@code invocations} is therefore volatile, so {@code drain()} sees
   * every update. Increments go through {@link #recordInvocation()}, because
   * {@code ++} on a field is not atomic.
   */
  @SuppressWarnings("rawtypes")
  static class TaskSchedulerAppCallbackDrainable implements TaskSchedulerAppCallback {
    /** Completed tasks consumed by {@link #drain()}; only updated there. */
    int completedEvents;
    /** Callbacks forwarded so far; only incremented via recordInvocation(). */
    volatile int invocations;
    private final Object invocationLock = new Object();
    private final TaskSchedulerAppCallback real;
    private final CountingExecutorService countingExecutorService;
    /** Number of taskAllocated callbacks received. */
    final AtomicInteger count = new AtomicInteger(0);

    /**
     * @param real wrapper to forward to; its executor must be a
     *     {@link CountingExecutorService}, otherwise a ClassCastException is thrown
     */
    public TaskSchedulerAppCallbackDrainable(TaskSchedulerAppCallbackWrapper real) {
      countingExecutorService = (CountingExecutorService) real.executorService;
      this.real = real;
    }

    /** Atomically increments {@link #invocations}. */
    private void recordInvocation() {
      synchronized (invocationLock) {
        invocations++;
      }
    }

    @Override
    public void taskAllocated(Object task, Object appCookie, Container container) {
      count.incrementAndGet();
      recordInvocation();
      real.taskAllocated(task, appCookie, container);
    }

    @Override
    public void containerCompleted(Object taskLastAllocated,
        ContainerStatus containerStatus) {
      recordInvocation();
      real.containerCompleted(taskLastAllocated, containerStatus);
    }

    @Override
    public void containerBeingReleased(ContainerId containerId) {
      recordInvocation();
      real.containerBeingReleased(containerId);
    }

    @Override
    public void nodesUpdated(List<NodeReport> updatedNodes) {
      recordInvocation();
      real.nodesUpdated(updatedNodes);
    }

    @Override
    public void appShutdownRequested() {
      recordInvocation();
      real.appShutdownRequested();
    }

    @Override
    public void setApplicationRegistrationData(Resource maxContainerCapability,
        Map<ApplicationAccessType, String> appAcls, ByteBuffer key) {
      recordInvocation();
      real.setApplicationRegistrationData(maxContainerCapability, appAcls, key);
    }

    @Override
    public void onError(Throwable t) {
      recordInvocation();
      real.onError(t);
    }

    @Override
    public float getProgress() {
      recordInvocation();
      return real.getProgress();
    }

    @Override
    public AppFinalStatus getFinalAppStatus() {
      recordInvocation();
      return real.getFinalAppStatus();
    }

    @Override
    public void preemptContainer(ContainerId cId) {
      recordInvocation();
      real.preemptContainer(cId);
    }

    /**
     * Polls completed futures from the counting executor's completion service
     * until their number catches up with the recorded invocations. The test
     * fails if no future completes within 5 seconds. The futures' results are
     * not inspected.
     */
    public void drain() throws InterruptedException, ExecutionException {
      while (completedEvents < invocations) {
        Future f = countingExecutorService.completionService.poll(5000L, TimeUnit.MILLISECONDS);
        if (f != null) {
          completedEvents++;
        } else {
          fail("Timed out while trying to drain queue");
        }
      }
    }
  }

  /**
   * {@link ContainerSignatureMatcher} that reports every pair of signatures as
   * matching. {@code isSuperSet} still rejects null arguments with a
   * NullPointerException; {@code isExactMatch} accepts anything. No additional
   * resources are ever reported.
   */
  static class AlwaysMatchesContainerMatcher implements ContainerSignatureMatcher {

    @Override
    public boolean isSuperSet(Object cs1, Object cs2) {
      requireBothNonNull(cs1, cs2);
      return true;
    }

    @Override
    public boolean isExactMatch(Object cs1, Object cs2) {
      // Unlike isSuperSet, nulls are not rejected here.
      return true;
    }

    /** @return a new, empty, mutable map regardless of the inputs */
    @Override
    public Map<String, LocalResource> getAdditionalResources(Map<String, LocalResource> lr1,
        Map<String, LocalResource> lr2) {
      Map<String, LocalResource> nothingExtra = Maps.newHashMap();
      return nothingExtra;
    }

    /** Throws NullPointerException if either signature is null. */
    private static void requireBothNonNull(Object cs1, Object cs2) {
      Preconditions.checkNotNull(cs1, "Arguments cannot be null");
      Preconditions.checkNotNull(cs2, "Arguments cannot be null");
    }
  }
  
  /**
   * {@link ContainerSignatureMatcher} where any non-null pair is a superset
   * match. An exact match requires both arguments to be the same non-null
   * object (reference identity). No additional resources are ever reported.
   */
  static class PreemptionMatcher implements ContainerSignatureMatcher {
    @Override
    public boolean isSuperSet(Object cs1, Object cs2) {
      requireBothNonNull(cs1, cs2);
      return true;
    }

    /** True only when {@code cs1} is non-null and is the very same object as {@code cs2}. */
    @Override
    public boolean isExactMatch(Object cs1, Object cs2) {
      return cs1 != null && cs1 == cs2;
    }

    /** @return a new, empty, mutable map regardless of the inputs */
    @Override
    public Map<String, LocalResource> getAdditionalResources(Map<String, LocalResource> lr1,
        Map<String, LocalResource> lr2) {
      Map<String, LocalResource> nothingExtra = Maps.newHashMap();
      return nothingExtra;
    }

    /** Throws NullPointerException if either signature is null. */
    private static void requireBothNonNull(Object cs1, Object cs2) {
      Preconditions.checkNotNull(cs1, "Arguments cannot be null");
      Preconditions.checkNotNull(cs2, "Arguments cannot be null");
    }
  }
  

  /**
   * Blocks until {@code drainNotifier} is true. The caller that sets the flag
   * is expected to call notify/notifyAll on the same object while holding its
   * monitor. Returns immediately if the flag is already set.
   *
   * @throws InterruptedException if interrupted while waiting
   */
  static void waitForDelayedDrainNotify(AtomicBoolean drainNotifier)
      throws InterruptedException {
    synchronized (drainNotifier) {
      // Re-check after every wake-up: wait() may return spuriously.
      for (;;) {
        if (drainNotifier.get()) {
          return;
        }
        drainNotifier.wait();
      }
    }
  }
  
  /**
   * {@link ExecutorService} decorator. It routes {@code submit(Callable)} and
   * {@code submit(Runnable, T)} through an {@link ExecutorCompletionService}
   * over the wrapped executor, so completed tasks can later be polled from
   * {@code completionService}. Lifecycle methods delegate to the wrapped
   * executor. Every other submission path throws
   * {@link UnsupportedOperationException}.
   */
  @SuppressWarnings({"rawtypes", "unchecked"})
  private static class CountingExecutorService implements ExecutorService {

    final ExecutorService real;
    final CompletionService completionService;

    CountingExecutorService(ExecutorService real) {
      this.real = real;
      this.completionService = new ExecutorCompletionService(real);
    }

    /** Builds the exception thrown by every entry point tests never use. */
    private static UnsupportedOperationException notExpected() {
      return new UnsupportedOperationException("Not expected to be used");
    }

    // --- Submissions tracked by the completion service ---

    @Override
    public <T> Future<T> submit(Callable<T> task) {
      return completionService.submit(task);
    }

    @Override
    public <T> Future<T> submit(Runnable task, T result) {
      return completionService.submit(task, result);
    }

    // --- Lifecycle: delegated to the wrapped executor ---

    @Override
    public void shutdown() {
      real.shutdown();
    }

    @Override
    public List<Runnable> shutdownNow() {
      return real.shutdownNow();
    }

    @Override
    public boolean isShutdown() {
      return real.isShutdown();
    }

    @Override
    public boolean isTerminated() {
      return real.isTerminated();
    }

    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
      return real.awaitTermination(timeout, unit);
    }

    // --- Entry points that are deliberately unsupported ---

    @Override
    public void execute(Runnable command) {
      throw notExpected();
    }

    @Override
    public Future<?> submit(Runnable task) {
      throw notExpected();
    }

    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)
        throws InterruptedException {
      throw notExpected();
    }

    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout,
        TimeUnit unit) throws InterruptedException {
      throw notExpected();
    }

    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException,
        ExecutionException {
      throw notExpected();
    }

    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
        throws InterruptedException, ExecutionException, TimeoutException {
      throw notExpected();
    }
  }

}