/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.dag.app.rm;

import com.google.common.annotations.VisibleForTesting;
import org.apache.tez.common.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.commons.lang.mutable.MutableInt;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time;
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.ResourceRequest;
import org.apache.hadoop.yarn.client.api.AMRMClient;
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
import org.apache.hadoop.yarn.client.api.async.impl.AMRMClientAsyncImpl;
import org.apache.hadoop.yarn.client.api.impl.AMRMClientImpl;
import org.apache.hadoop.yarn.proto.YarnServiceProtos.SchedulerResourceTypes;
import org.apache.hadoop.yarn.util.RackResolver;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.apache.tez.common.ContainerSignatureMatcher;
import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.app.dag.TaskAttempt;
import org.apache.tez.serviceplugins.api.DagInfo;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.apache.tez.serviceplugins.api.TaskScheduler;
import org.apache.tez.serviceplugins.api.TaskSchedulerContext;
import org.apache.tez.serviceplugins.api.TaskSchedulerContext.AMState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.NavigableMap;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * A YARN task scheduler that is aware of the dependencies between vertices
 * in the DAG and takes them into account when deciding how to schedule
 * and preempt tasks.
 *
 * <p>This scheduler assumes that vertex IDs start at 0 and are densely
 * allocated (i.e., there are no "gaps" in the vertex ID space).
 */
public class DagAwareYarnTaskScheduler extends TaskScheduler
    implements AMRMClientAsync.CallbackHandler {
  private static final Logger LOG = LoggerFactory.getLogger(DagAwareYarnTaskScheduler.class);
  private static final Comparator<HeldContainer> PREEMPT_ORDER_COMPARATOR = new PreemptOrderComparator();

  private final RandomDataGenerator random = new RandomDataGenerator();
  /** RM client wrapper; assigned in {@code initialize(AMRMClientAsyncWrapper)}. */
  private AMRMClientAsyncWrapper client;
  /** Executor created in {@code start()} only when container reuse is enabled; otherwise stays null. */
  private ScheduledExecutorService reuseExecutor;
  /** Selected in {@code start()} depending on whether the RM reports CPU as a scheduler resource type. */
  private ResourceCalculator resourceCalculator;
  /** Incremented on every {@code getProgress()} call that is not short-circuited by a stop request. */
  private int numHeartbeats = 0;
  /** Captured from the available resources the first time {@code getProgress()} sees a zero-memory total. */
  private Resource totalResources = Resource.newInstance(0, 0);
  /** Grows by each newly allocated container's resource in {@code assignNewContainers}. */
  @GuardedBy("this")
  private Resource allocatedResources = Resource.newInstance(0, 0);
  /** Nodes added via {@code blacklistNode} and removed via {@code unblacklistNode}; backed by a concurrent map. */
  private final Set<NodeId> blacklistedNodes = Collections.newSetFromMap(new ConcurrentHashMap<NodeId, Boolean>());
  /** Obtained from the scheduler context in the constructor. */
  private final ContainerSignatureMatcher signatureMatcher;
  @GuardedBy("this")
  private final RequestTracker requestTracker = new RequestTracker();
  /** Containers currently held by this scheduler, keyed by container ID. */
  @GuardedBy("this")
  private final Map<ContainerId, HeldContainer> heldContainers = new HashMap<>();
  @GuardedBy("this")
  private final IdleContainerTracker idleTracker = new IdleContainerTracker();
  // presumably maps each assigned task to the container running it -- confirm against assignContainer
  @GuardedBy("this")
  private final Map<Object, HeldContainer> taskAssignments = new HashMap<>();

  /** A mapping from the vertex ID to the set of containers assigned to tasks for that vertex */
  @GuardedBy("this")
  private final Map<Integer, Set<HeldContainer>> vertexAssignments = new HashMap<>();

  /** If vertex N has at least one task assigned to a container then the corresponding bit at index N is set */
  @GuardedBy("this")
  private final BitSet assignedVertices = new BitSet();

  /**
   * Tracks assigned tasks for released containers so the app can be notified properly when the
   * container completion event finally arrives.
   */
  @GuardedBy("this")
  private final Map<ContainerId, Object> releasedContainers = new HashMap<>();

  /** Filled by {@code computeSessionContainers()} and emptied by {@code activateSessionContainers()}. */
  @GuardedBy("this")
  private final Set<HeldContainer> sessionContainers = new HashSet<>();

  /**
   * Tracks the set of descendant vertices in the DAG for each vertex.  The BitSet of descendants for vertex N
   * is at array index N.  If a bit is set at index X in the descendants BitSet then vertex X is a descendant
   * of vertex N in the DAG.
   */
  @GuardedBy("this")
  private ArrayList<BitSet> vertexDescendants = null;

  /** Set to true by {@code initiateStop()}; callbacks check it to ignore events once stopping. */
  private volatile boolean stopRequested = false;
  /** {@code shutdown()} unregisters from the RM only when this is true and {@link #hasUnregistered} is false. */
  private volatile boolean shouldUnregister = false;
  /** Set to true by {@code shutdown()} after unregistering from the RM. */
  private volatile boolean hasUnregistered = false;

  // cached configuration parameters (all read and validated in initialize(AMRMClientAsyncWrapper))
  private boolean shouldReuseContainers;
  private boolean reuseRackLocal;
  private boolean reuseNonLocal;
  private boolean reuseNewContainers;
  private long localitySchedulingDelay;
  private long idleContainerTimeoutMin;
  private long idleContainerTimeoutMax;
  private int sessionNumMinHeldContainers;
  private int preemptionPercentage;
  private int numHeartbeatsBetweenPreemptions;
  /** Value of {@link #numHeartbeats} at the most recent heartbeat in which containers were preempted. */
  private int lastPreemptionHeartbeat = 0;
  private long preemptionMaxWaitTime;

  /**
   * Creates the scheduler and caches the container signature matcher provided by the context.
   *
   * @param taskSchedulerContext the context passed to the {@link TaskScheduler} superclass
   */
  public DagAwareYarnTaskScheduler(TaskSchedulerContext taskSchedulerContext) {
    super(taskSchedulerContext);
    signatureMatcher = taskSchedulerContext.getContainerSignatureMatcher();
  }

  /**
   * Initializes the scheduler with a freshly created RM client wrapper that uses this
   * scheduler as its callback handler.
   *
   * @throws Exception if the delegated initialization fails
   */
  @Override
  public void initialize() throws Exception {
    // 1000 is the interval argument handed to the wrapper; the heartbeat interval is
    // overridden from configuration in initialize(AMRMClientAsyncWrapper)
    initialize(new AMRMClientAsyncWrapper(new AMRMClientImpl<TaskRequest>(), 1000, this));
  }

  /**
   * Initializes the scheduler with the supplied RM client and caches the scheduling
   * parameters found in the configuration carried by the initial user payload.
   * Each parameter is validated right after it is read, so the first invalid
   * setting encountered aborts initialization.
   *
   * @param client the RM client wrapper this scheduler will use from now on
   * @throws Exception if superclass initialization or payload decoding fails
   */
  void initialize(AMRMClientAsyncWrapper client) throws Exception {
    super.initialize();
    this.client = client;
    final Configuration config =
        TezUtils.createConfFromUserPayload(getContext().getInitialUserPayload());
    client.init(config);

    // RM heartbeat
    final int maxHeartbeatIntervalMs = config.getInt(
        TezConfiguration.TEZ_AM_RM_HEARTBEAT_INTERVAL_MS_MAX,
        TezConfiguration.TEZ_AM_RM_HEARTBEAT_INTERVAL_MS_MAX_DEFAULT);
    client.setHeartbeatInterval(maxHeartbeatIntervalMs);

    // container reuse switches
    shouldReuseContainers = config.getBoolean(
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_ENABLED,
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_ENABLED_DEFAULT);
    reuseRackLocal = config.getBoolean(
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_RACK_FALLBACK_ENABLED,
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_RACK_FALLBACK_ENABLED_DEFAULT);
    reuseNonLocal = config.getBoolean(
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_NON_LOCAL_FALLBACK_ENABLED,
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_NON_LOCAL_FALLBACK_ENABLED_DEFAULT);
    // non-local reuse is only permitted when rack-local reuse is also on
    Preconditions.checkArgument(
        reuseRackLocal || !reuseNonLocal,
        "Re-use Rack-Local cannot be disabled if Re-use Non-Local has been"
        + " enabled");

    // reuse of brand new containers only makes sense if reuse is enabled at all
    reuseNewContainers = shouldReuseContainers && config.getBoolean(
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_NEW_CONTAINERS_ENABLED,
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_NEW_CONTAINERS_ENABLED_DEFAULT);

    localitySchedulingDelay = config.getLong(
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_LOCALITY_DELAY_ALLOCATION_MILLIS,
        TezConfiguration.TEZ_AM_CONTAINER_REUSE_LOCALITY_DELAY_ALLOCATION_MILLIS_DEFAULT);
    Preconditions.checkArgument(localitySchedulingDelay >= 0,
        "Locality Scheduling delay should be >=0");

    // idle container release window
    idleContainerTimeoutMin = config.getLong(
        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MIN_MILLIS,
        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MIN_MILLIS_DEFAULT);
    Preconditions.checkArgument(idleContainerTimeoutMin >= 0 || idleContainerTimeoutMin == -1,
        "Idle container release min timeout should be either -1 or >=0");

    idleContainerTimeoutMax = config.getLong(
        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MAX_MILLIS,
        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MAX_MILLIS_DEFAULT);
    Preconditions.checkArgument(
        idleContainerTimeoutMax >= 0 && idleContainerTimeoutMax >= idleContainerTimeoutMin,
        "Idle container release max timeout should be >=0 and >= " +
        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MIN_MILLIS);

    sessionNumMinHeldContainers = config.getInt(
        TezConfiguration.TEZ_AM_SESSION_MIN_HELD_CONTAINERS,
        TezConfiguration.TEZ_AM_SESSION_MIN_HELD_CONTAINERS_DEFAULT);
    Preconditions.checkArgument(sessionNumMinHeldContainers >= 0,
        "Session minimum held containers should be >=0");

    // preemption tuning
    preemptionPercentage = config.getInt(
        TezConfiguration.TEZ_AM_PREEMPTION_PERCENTAGE,
        TezConfiguration.TEZ_AM_PREEMPTION_PERCENTAGE_DEFAULT);
    Preconditions.checkArgument(preemptionPercentage >= 0 && preemptionPercentage <= 100,
        "Preemption percentage should be between 0-100");

    numHeartbeatsBetweenPreemptions = config.getInt(
        TezConfiguration.TEZ_AM_PREEMPTION_HEARTBEATS_BETWEEN_PREEMPTIONS,
        TezConfiguration.TEZ_AM_PREEMPTION_HEARTBEATS_BETWEEN_PREEMPTIONS_DEFAULT);
    Preconditions.checkArgument(numHeartbeatsBetweenPreemptions >= 1,
        "Heartbeats between preemptions should be >=1");

    preemptionMaxWaitTime = config.getInt(
        TezConfiguration.TEZ_AM_PREEMPTION_MAX_WAIT_TIME_MS,
        TezConfiguration.TEZ_AM_PREEMPTION_MAX_WAIT_TIME_MS_DEFAULT);
    Preconditions.checkArgument(preemptionMaxWaitTime >=0, "Preemption max wait time must be >=0");

    LOG.info("scheduler initialized with maxRMHeartbeatInterval:" + maxHeartbeatIntervalMs +
            " reuseEnabled:" + shouldReuseContainers +
            " reuseRack:" + reuseRackLocal +
            " reuseAny:" + reuseNonLocal +
            " localityDelay:" + localitySchedulingDelay +
            " preemptPercentage:" + preemptionPercentage +
            " preemptMaxWaitTime:" + preemptionMaxWaitTime +
            " numHeartbeatsBetweenPreemptions:" + numHeartbeatsBetweenPreemptions +
            " idleContainerMinTimeout:" + idleContainerTimeoutMin +
            " idleContainerMaxTimeout:" + idleContainerTimeoutMax +
            " sessionMinHeldContainers:" + sessionNumMinHeldContainers);
  }

  /**
   * Starts the RM client, creates the reuse executor when container reuse is enabled,
   * registers the application master with the RM and publishes the registration data
   * to the context. The resource calculator is chosen according to whether the RM
   * lists CPU among its scheduler resource types.
   *
   * @throws Exception if starting the superclass or registering with the RM fails
   */
  @Override
  public void start() throws Exception {
    super.start();
    client.start();
    if (shouldReuseContainers) {
      reuseExecutor = createExecutor();
    }

    final TaskSchedulerContext context = getContext();
    final RegisterApplicationMasterResponse registration = client.registerApplicationMaster(
        context.getAppHostName(), context.getAppClientPort(), context.getAppTrackingUrl());
    context.setApplicationRegistrationData(
        registration.getMaximumResourceCapability(),
        registration.getApplicationACLs(),
        registration.getClientToAMTokenMasterKey(),
        registration.getQueue());

    final boolean cpuScheduling =
        registration.getSchedulerResourceTypes().contains(SchedulerResourceTypes.CPU);
    if (cpuScheduling) {
      resourceCalculator = new MemCpuResourceCalculator();
    } else {
      resourceCalculator = new MemResourceCalculator();
    }
  }

  /**
   * Creates the executor stored in {@code reuseExecutor} by {@code start()}.
   * Protected so that subclasses can supply a different executor.
   *
   * @return a new {@code ReuseContainerExecutor}
   */
  protected ScheduledExecutorService createExecutor() {
    return new ReuseContainerExecutor();
  }

  /**
   * Returns the current time as reported by {@link Time#monotonicNow()}.
   * Protected so that subclasses can substitute a different clock.
   *
   * @return the current monotonic timestamp
   */
  protected long now() {
    return Time.monotonicNow();
  }

  /**
   * Begins stopping the scheduler: flags the stop request, releases every held
   * container, drops all outstanding task requests and finally tells the app about
   * each container for which {@code releaseContainer} returned true.
   */
  @Override
  public void initiateStop() {
    super.initiateStop();
    LOG.debug("Initiating stop of task scheduler");
    stopRequested = true;

    final List<ContainerId> toReport;
    synchronized (this) {
      toReport = new ArrayList<>(heldContainers.size());
      // iterate over a snapshot rather than the live map view
      // (presumably releaseContainer mutates heldContainers -- confirm)
      for (HeldContainer held : new ArrayList<>(heldContainers.values())) {
        boolean needsAppCallback = releaseContainer(held);
        if (needsAppCallback) {
          toReport.add(held.getId());
        }
      }

      for (Object pendingTask : requestTracker.getTasks()) {
        removeTaskRequest(pendingTask);
      }
    }

    // perform app callback outside of locks
    for (ContainerId containerId : toReport) {
      getContext().containerBeingReleased(containerId);
    }
  }

  /**
   * Shuts the scheduler down: stops the reuse executor (waiting up to two seconds
   * for it to terminate), unregisters from the RM if that was requested and has not
   * happened yet, and stops the RM client.
   *
   * @throws Exception if the superclass shutdown, the wait for executor termination
   *         or the RM unregistration fails
   */
  @Override
  public void shutdown() throws Exception {
    super.shutdown();

    final ScheduledExecutorService executor = reuseExecutor;
    if (executor != null) {
      executor.shutdown();
      executor.awaitTermination(2, TimeUnit.SECONDS);
    }

    synchronized (this) {
      final boolean unregisterNow = shouldUnregister && !hasUnregistered;
      if (unregisterNow) {
        TaskSchedulerContext.AppFinalStatus finalStatus = getContext().getFinalAppStatus();
        LOG.info("Unregistering from RM, exitStatus={} exitMessage={} trackingURL={}",
            finalStatus.exitStatus, finalStatus.exitMessage, finalStatus.postCompletionTrackingUrl);
        client.unregisterApplicationMaster(
            finalStatus.exitStatus,
            finalStatus.exitMessage,
            finalStatus.postCompletionTrackingUrl);
        hasUnregistered = true;
      }
    }

    client.stop();
  }

  /**
   * RM callback invoked with newly allocated containers. While the scheduler is
   * stopping or the app has completed, the containers are handed straight back to
   * the RM; otherwise they are matched against pending requests and the app is told
   * about every resulting assignment.
   *
   * @param containers the containers just allocated by the RM
   */
  @Override
  public void onContainersAllocated(List<Container> containers) {
    final boolean terminating = stopRequested || getContext().getAMState() == AMState.COMPLETED;
    if (!terminating) {
      // the app state is sampled again for the actual assignment pass
      informAppAboutAssignments(
          assignNewContainers(containers, getContext().getAMState(), getContext().isSession()));
      return;
    }

    LOG.info("Ignoring {} allocations since app is terminating", containers.size());
    for (Container unwanted : containers) {
      client.releaseAssignedContainer(unwanted.getId());
    }
  }

  /**
   * Registers newly allocated containers and matches them to pending requests in
   * three passes of decreasing locality (node, rack, any). Containers that remain
   * unmatched are either offered for reuse or released, depending on whether reuse
   * of new containers is enabled.
   *
   * @param newContainers containers just received from the RM
   * @param appState the app state used when offering leftover containers for reuse
   * @param isSession whether the app is running in session mode
   * @return the assignments made, for the caller to report outside of the lock
   */
  private synchronized List<Assignment> assignNewContainers(List<Container> newContainers,
      AMState appState, boolean isSession) {
    final int incoming = newContainers.size();
    List<Assignment> assignments = new ArrayList<>(incoming);

    // pass 1: start tracking each container and look for a node-local match
    List<HeldContainer> notNodeLocal = new ArrayList<>(incoming);
    for (Container container : newContainers) {
      HeldContainer held = new HeldContainer(container);
      heldContainers.put(held.getId(), held);
      Resources.addTo(allocatedResources, container.getResource());
      tryAssignNewContainer(held, held.getHost(), assignments, notNodeLocal);
    }

    // pass 2: rack-local match for whatever is left
    List<HeldContainer> notRackLocal = new ArrayList<>(notNodeLocal.size());
    for (HeldContainer held : notNodeLocal) {
      tryAssignNewContainer(held, held.getRack(), assignments, notRackLocal);
    }

    // pass 3: match without any locality constraint
    List<HeldContainer> leftovers = new ArrayList<>(notRackLocal.size());
    for (HeldContainer held : notRackLocal) {
      tryAssignNewContainer(held, ResourceRequest.ANY, assignments, leftovers);
    }

    // leftovers: hand to the reuse logic, or give back when new-container reuse is off
    for (HeldContainer held : leftovers) {
      if (!reuseNewContainers) {
        releaseContainer(held);
        continue;
      }
      idleTracker.add(held);
      TaskRequest reused = tryAssignReuseContainer(held, appState, isSession);
      if (reused != null) {
        assignments.add(new Assignment(reused, held.getContainer()));
      }
    }

    return assignments;
  }

  /**
   * Attempts to match a freshly acquired container with a pending request of the
   * same priority at the given locality. Only the first request of each matching
   * group is considered; a group is passed over when {@code maybeChangeNode}
   * returns true for that request.
   *
   * @param hc the container to assign
   * @param location the locality to consider for assignment
   * @param assignments receives the new assignment when the container is matched
   * @param unassigned receives the container when no request could be matched
   */
  @GuardedBy("this")
  private void tryAssignNewContainer(HeldContainer hc, String location,
      List<Assignment> assignments, List<HeldContainer> unassigned) {
    List<? extends Collection<TaskRequest>> matches =
        client.getMatchingRequests(hc.getPriority(), location, hc.getCapability());

    for (Collection<TaskRequest> group : matches) {
      if (group.isEmpty()) {
        continue;
      }
      TaskRequest candidate = group.iterator().next();
      if (maybeChangeNode(candidate, hc.getContainer().getNodeId())) {
        // this request is not taken on this node; look at the next group
        continue;
      }
      assignContainer(candidate, hc, location);
      assignments.add(new Assignment(candidate, hc.getContainer()));
      return;
    }

    unassigned.add(hc);
  }

  /**
   * Decides what to do with a container that is available for reuse, based on the
   * app state: hold or expire it while idle, try to match it to a pending request
   * while running, or release it once the app has completed.
   *
   * @param hc the reusable container
   * @param appState the current app state
   * @param isSession whether the app is running in session mode
   * @return the request the container was assigned to, or null if none was assigned
   * @throws IllegalStateException if the app state is not one of the handled states
   */
  @GuardedBy("this")
  @Nullable
  private TaskRequest tryAssignReuseContainer(HeldContainer hc,
      AMState appState, boolean isSession) {
    if (stopRequested) {
      return null;
    }

    switch (appState) {
    case IDLE:
      handleReuseContainerWhenIdle(hc, isSession);
      return null;
    case RUNNING_APP: {
      if (requestTracker.isEmpty()) {
        // treat no requests as if app is idle
        handleReuseContainerWhenIdle(hc, isSession);
        return null;
      }
      TaskRequest matched = tryAssignReuseContainerAppRunning(hc);
      if (matched != null) {
        return matched;
      }
      // nothing matched: either give up on the container or retry later
      if (hc.atMaxMatchLevel()) {
        LOG.info("Releasing idle container {} due to pending requests", hc.getId());
        releaseContainer(hc);
      } else {
        hc.scheduleForReuse(localitySchedulingDelay);
      }
      return null;
    }
    case COMPLETED:
      LOG.info("Releasing container {} because app has completed", hc.getId());
      releaseContainer(hc);
      return null;
    default:
      throw new IllegalStateException("Unexpected app state " + appState);
    }
  }

  /**
   * Handles a reusable container while there is no work for it. In session mode the
   * set of session containers is computed on demand (when it is empty and a minimum
   * number of held containers is configured). A session container is retained with
   * its matching level reset; any other container is released once its idle
   * expiration has passed, or rescheduled to be revisited at that expiration.
   *
   * @param hc the idle container
   * @param isSession whether the app is running in session mode
   */
  @GuardedBy("this")
  private void handleReuseContainerWhenIdle(HeldContainer hc, boolean isSession) {
    if (isSession && sessionContainers.isEmpty() && sessionNumMinHeldContainers > 0) {
      computeSessionContainers();
    }

    if (sessionContainers.contains(hc)) {
      // the container id must be passed explicitly to fill the {} placeholder
      LOG.info("Retaining container {} since it is a session container", hc.getId());
      hc.resetMatchingLevel();
      return;
    }

    long now = now();
    long expiration = hc.getIdleExpirationTimestamp(now);
    if (now >= expiration) {
      LOG.info("Releasing expired idle container {}", hc.getId());
      releaseContainer(hc);
    } else {
      // check again when the container is due to expire
      hc.scheduleForReuse(expiration - now);
    }
  }

  /**
   * Tries to match a reusable container to a pending request while the app is
   * running. Requests with affinity to this container are tried first; after that
   * the request priorities reported by the request tracker are walked in iteration
   * order, skipping priorities whose requesting vertices are all blocked.
   *
   * @param hc the reusable container
   * @return the request that was assigned to the container, or null if none matched
   */
  @GuardedBy("this")
  @Nullable
  private TaskRequest tryAssignReuseContainerAppRunning(HeldContainer hc) {
    if (!hc.isAssignable()) {
      LOG.debug("Skipping scheduling of container {} because it state is {}", hc.getId(), hc.getState());
      return null;
    }

    TaskRequest affinityMatch = tryAssignReuseContainerForAffinity(hc);
    if (affinityMatch != null) {
      return affinityMatch;
    }

    for (Entry<Priority,RequestPriorityStats> statsEntry : requestTracker.getStatsEntries()) {
      final Priority priority = statsEntry.getKey();
      final RequestPriorityStats stats = statsEntry.getValue();

      final boolean anyVertexAllowed = stats.allowedVertices.intersects(stats.vertices);
      if (!anyVertexAllowed) {
        LOG.debug("Skipping requests at priority {} because all requesting vertices are blocked by higher priority requests",
            priority);
        continue;
      }

      String location = hc.getMatchingLocation();
      if (stats.localityCount <= 0) {
        LOG.debug("Overriding locality match of container {} to ANY since there are no locality requests at priority {}",
            hc.getId(), priority);
        location = ResourceRequest.ANY;
      }

      TaskRequest match =
          tryAssignReuseContainerForPriority(hc, location, priority, stats.allowedVertices);
      if (match != null) {
        return match;
      }
    }
    return null;
  }

  /**
   * Tries to assign the container to one of the requests that declared affinity to
   * it. A request is skipped when the request tracker reports it as blocked or when
   * {@code maybeChangeNode} returns true for the container's node.
   *
   * @param hc the reusable container
   * @return the affinity request that was assigned, or null if none could be
   */
  @GuardedBy("this")
  @Nullable
  private TaskRequest tryAssignReuseContainerForAffinity(HeldContainer hc) {
    Collection<TaskRequest> affinities = hc.getAffinities();
    if (affinities == null) {
      return null;
    }

    for (TaskRequest request : affinities) {
      if (requestTracker.isRequestBlocked(request)) {
        LOG.debug("Cannot assign task {} to container {} since vertex {} is a descendant of pending tasks",
            request.getTask(), hc.getId(), request.getVertexIndex());
        continue;
      }

      NodeId nodeId = hc.getContainer().getNodeId();
      if (maybeChangeNode(request, nodeId)) {
        // log the node itself; the message previously received the vertex index here
        LOG.debug("Cannot assign task {} to container {} since node {} is running sibling attempts",
            request.getTask(), hc.getId(), nodeId);
        continue;
      }

      assignContainer(request, hc, hc.getId());
      return request;
    }
    return null;
  }

  /**
   * Tries to assign the container to a request of the given priority that matches
   * the given location. A request qualifies when its vertex is in the allowed set,
   * the container has no signature or its signature is a superset of the request's,
   * and {@code maybeChangeNode} returns false for the container's node.
   *
   * @param hc the reusable container
   * @param matchLocation the locality used to look up matching requests
   * @param priority the request priority to search
   * @param allowedVertices bit set of vertex indices allowed to be scheduled
   * @return the request that was assigned, or null if none qualified
   */
  @GuardedBy("this")
  @Nullable
  private TaskRequest tryAssignReuseContainerForPriority(HeldContainer hc, String matchLocation,
      Priority priority, BitSet allowedVertices) {
    List<? extends Collection<TaskRequest>> matches =
        client.getMatchingRequests(priority, matchLocation, hc.getCapability());
    if (matches.isEmpty()) {
      return null;
    }

    for (Collection<TaskRequest> group : matches) {
      for (TaskRequest candidate : group) {
        if (!allowedVertices.get(candidate.getVertexIndex())) {
          LOG.debug("Not assigning task {} since it is a descendant of a pending vertex", candidate.getTask());
          continue;
        }

        Object containerSignature = hc.getSignature();
        boolean signatureCompatible = containerSignature == null
            || signatureMatcher.isSuperSet(containerSignature, candidate.getContainerSignature());
        if (signatureCompatible && !maybeChangeNode(candidate, hc.getContainer().getNodeId())) {
          assignContainer(candidate, hc, matchLocation);
          return candidate;
        }
      }
    }
    return null;
  }

  /**
   * Reports each assignment in the list to the app, in list order.
   *
   * @param assignments the assignments to report; may be empty
   */
  private void informAppAboutAssignments(List<Assignment> assignments) {
    // iterating an empty list is a no-op, so no separate emptiness check is needed
    for (Assignment assignment : assignments) {
      informAppAboutAssignment(assignment.request, assignment.container);
    }
  }

  /**
   * Inform the app about a task assignment.  This should not be called with
   * any locks held.
   *
   * If the container sits on a blacklisted node the assignment is not reported;
   * instead the container is deallocated and the task is requested again with the
   * same parameters.
   *
   * @param request the corresponding task request
   * @param container the container assigned to the task
   */
  private void informAppAboutAssignment(TaskRequest request, Container container) {
    if (!blacklistedNodes.contains(container.getNodeId())) {
      getContext().taskAllocated(request.getTask(), request.getCookie(), container);
      return;
    }

    Object task = request.getTask();
    // argument order follows the placeholders: container, task, node
    LOG.info("Container {} allocated for task {} on blacklisted node {}",
        container.getId(), task, container.getNodeId());
    deallocateContainer(container.getId());
    // its ok to submit the same request again because the RM will not give us
    // the bad/unhealthy nodes again. The nodes may become healthy/unblacklisted
    // and so its better to give the RM the full information.
    allocateTask(task, request.getCapability(),
        (request.getNodes() == null ? null :
            request.getNodes().toArray(new String[request.getNodes().size()])),
        (request.getRacks() == null ? null :
            request.getRacks().toArray(new String[request.getRacks().size()])),
        request.getPriority(),
        request.getContainerSignature(),
        request.getCookie());
  }

  /**
   * Selects up to {@code sessionNumMinHeldContainers} held containers to retain as
   * session containers and adds them to {@code sessionContainers}.
   *
   * Only containers with a non-null signature are considered. The selection first
   * computes a per-rack quota by handing out one slot per rack in round-robin
   * fashion, then picks containers one per node in round-robin fashion while
   * respecting each rack's quota.
   */
  @GuardedBy("this")
  private void computeSessionContainers() {
    // number of candidate containers per rack, and the candidates grouped by host
    Map<String, MutableInt> rackHeldNumber = new HashMap<>();
    Map<String, List<HeldContainer>> nodeHeldContainers = new HashMap<>();
    for(HeldContainer heldContainer : heldContainers.values()) {
      if (heldContainer.getSignature() == null) {
        // skip containers that have not been launched as there is no process to reuse
        continue;
      }
      MutableInt count = rackHeldNumber.get(heldContainer.getRack());
      if (count == null) {
        count = new MutableInt(0);
        rackHeldNumber.put(heldContainer.getRack(), count);
      }
      count.increment();
      String host = heldContainer.getHost();
      List<HeldContainer> nodeContainers = nodeHeldContainers.get(host);
      if (nodeContainers == null) {
        nodeContainers = new LinkedList<>();
        nodeHeldContainers.put(host, nodeContainers);
      }
      nodeContainers.add(heldContainer);
    }

    // per-rack quota of containers to retain, starting at zero for every rack
    Map<String, MutableInt> rackToHoldNumber = new HashMap<>();
    for (String rack : rackHeldNumber.keySet()) {
      rackToHoldNumber.put(rack, new MutableInt(0));
    }

    // distribute evenly across racks
    // the loop assigns 1 container per rack over all racks
    int containerCount = 0;
    while (containerCount < sessionNumMinHeldContainers && !rackHeldNumber.isEmpty()) {
      Iterator<Entry<String, MutableInt>> iter = rackHeldNumber.entrySet().iterator();
      while (containerCount < sessionNumMinHeldContainers && iter.hasNext()) {
        Entry<String, MutableInt> entry = iter.next();
        MutableInt rackCount = entry.getValue();
        rackCount.decrement();
        if (rackCount.intValue() >=0) {
          // rack still had a candidate left: grant it one more quota slot
          containerCount++;
          rackToHoldNumber.get(entry.getKey()).increment();
        } else {
          // rack has no candidates left; stop considering it
          iter.remove();
        }
      }
    }

    // distribute containers evenly across nodes while not exceeding rack limit
    // the loop assigns 1 container per node over all nodes
    containerCount = 0;
    while (containerCount < sessionNumMinHeldContainers && !nodeHeldContainers.isEmpty()) {
      Iterator<Entry<String, List<HeldContainer>>> iter = nodeHeldContainers.entrySet().iterator();
      while (containerCount < sessionNumMinHeldContainers && iter.hasNext()) {
        List<HeldContainer> nodeContainers = iter.next().getValue();
        if (nodeContainers.isEmpty()) {
          // node is empty. remove it.
          iter.remove();
          continue;
        }
        // take the most recently added candidate for this node
        HeldContainer heldContainer = nodeContainers.remove(nodeContainers.size() - 1);
        MutableInt holdCount = rackToHoldNumber.get(heldContainer.getRack());
        holdCount.decrement();
        if (holdCount.intValue() >= 0) {
          // rack can hold a container
          containerCount++;
          sessionContainers.add(heldContainer);
        } else {
          // rack limit reached. remove node.
          iter.remove();
        }
      }
    }

    LOG.info("Identified {} session containers out of {} total containers",
        sessionContainers.size(), heldContainers.size());
  }

  /**
   * Schedules every assignable session container for reuse after the locality
   * scheduling delay and then empties the session container set.
   */
  @GuardedBy("this")
  private void activateSessionContainers() {
    // looping over and clearing an empty set are both no-ops, so no guard is needed
    for (HeldContainer sessionContainer : sessionContainers) {
      if (sessionContainer.isAssignable()) {
        sessionContainer.scheduleForReuse(localitySchedulingDelay);
      }
    }
    sessionContainers.clear();
  }

  /**
   * RM callback invoked with the statuses of completed containers. For each status
   * the task associated with the container is looked up, first among the released
   * containers and then among the held containers, and the app is notified of every
   * task found. Nothing is done once a stop has been requested.
   *
   * @param statuses the completion statuses reported by the RM
   */
  @Override
  public void onContainersCompleted(List<ContainerStatus> statuses) {
    if (stopRequested) {
      return;
    }

    final List<TaskStatus> completedTasks = new ArrayList<>(statuses.size());
    synchronized (this) {
      for (ContainerStatus status : statuses) {
        final ContainerId containerId = status.getContainerId();
        LOG.info("Container {} completed with status {}", containerId, status);

        Object task = releasedContainers.remove(containerId);
        if (task == null) {
          // not a previously released container; see whether it is still held
          HeldContainer held = heldContainers.get(containerId);
          if (held != null) {
            task = containerCompleted(held);
          }
        }

        if (task == null) {
          continue;
        }
        completedTasks.add(new TaskStatus(task, status));
      }
    }

    // perform app callback outside of locks
    for (TaskStatus completed : completedTasks) {
      getContext().containerCompleted(completed.task, completed.status);
    }
  }

  /**
   * RM callback invoked with node reports; forwards them to the context unless a
   * stop has been requested.
   *
   * @param updatedNodes the node reports received from the RM
   */
  @Override
  public void onNodesUpdated(List<NodeReport> updatedNodes) {
    if (stopRequested) {
      return;
    }
    getContext().nodesUpdated(updatedNodes);
  }

  /**
   * Heartbeat callback. Besides returning the progress reported by the context, each
   * call counts as one heartbeat: it captures the total resources on the first call,
   * periodically logs scheduler state and gives the preemption logic a chance to
   * select containers, which are then preempted through the context.
   *
   * @return 1 once a stop has been requested, otherwise the context's progress
   */
  @Override
  public float getProgress() {
    if (stopRequested) {
      return 1;
    }

    final Collection<ContainerId> toPreempt;
    synchronized (this) {
      final Resource available = getAvailableResources();
      if (totalResources.getMemory() == 0) {
        // assume this is the first allocate callback. nothing is allocated.
        // available resource = totalResource
        // TODO this will not handle dynamic changes in resources
        totalResources = Resources.clone(available);
        LOG.info("App total resource memory: {} cpu: {} activeAssignments: {}",
            totalResources.getMemory(), totalResources.getVirtualCores(), taskAssignments.size());
      }

      numHeartbeats++;
      // always log when debug is enabled, otherwise once every 50 heartbeats
      if (LOG.isDebugEnabled() || numHeartbeats % 50 == 1) {
        LOG.info(constructPeriodicLog(available));
      }

      toPreempt = maybePreempt(available);
      if (toPreempt != null && !toPreempt.isEmpty()) {
        lastPreemptionHeartbeat = numHeartbeats;
      }
    }

    // perform app callback outside of locks
    if (toPreempt != null) {
      for (ContainerId containerId : toPreempt) {
        LOG.info("Preempting container {} currently allocated to a task", containerId);
        getContext().preemptContainer(containerId);
      }
    }

    return getContext().getProgress();
  }

  /**
   * RM callback requesting shutdown; relays the request to the context unless a
   * stop has already been requested.
   */
  @Override
  public void onShutdownRequest() {
    if (stopRequested) {
      return;
    }
    getContext().appShutdownRequested();
  }

  /**
   * RM client callback invoked on error. The error is always logged; it is reported
   * to the context as a resource manager error only if no stop has been requested.
   *
   * @param e the error raised by the RM client
   */
  @Override
  public void onError(Throwable e) {
    // "AMRMClient" matches the actual client class name (was misspelled "ARMRMClient")
    LOG.error("Error from AMRMClient", e);
    if (stopRequested) {
      return;
    }
    getContext().reportError(YarnTaskSchedulerServiceError.RESOURCEMANAGER_ERROR,
        StringUtils.stringifyException(e), null);
  }

  /**
   * Returns the available resources as reported by the RM client.
   *
   * @return the value of the client's {@code getAvailableResources()}
   */
  @Override
  public Resource getAvailableResources() {
    return client.getAvailableResources();
  }

  /**
   * Returns the total resources recorded by this scheduler. The value starts as an
   * empty resource and is replaced in {@code getProgress()} the first time its memory
   * is found to be zero.
   *
   * @return the cached total resources
   */
  @Override
  public Resource getTotalResources() {
    return totalResources;
  }

  /**
   * Returns the cluster node count as reported by the RM client.
   *
   * @return the value of the client's {@code getClusterNodeCount()}
   */
  @Override
  public int getClusterNodeCount() {
    return client.getClusterNodeCount();
  }

  /**
   * Records the node as blacklisted and adds its host to the RM client's blacklist.
   *
   * @param nodeId the node to blacklist
   */
  @Override
  public synchronized void blacklistNode(NodeId nodeId) {
    LOG.info("Blacklisting node: {}", nodeId);
    blacklistedNodes.add(nodeId);
    final List<String> additions = Collections.singletonList(nodeId.getHost());
    client.updateBlacklist(additions, null);
  }

  /**
   * Removes the node from the local blacklist and, only if it was actually listed,
   * removes its host from the RM client's blacklist.
   *
   * @param nodeId the node to remove from the blacklist
   */
  @Override
  public synchronized void unblacklistNode(NodeId nodeId) {
    final boolean wasBlacklisted = blacklistedNodes.remove(nodeId);
    if (!wasBlacklisted) {
      return;
    }
    LOG.info("Removing blacklist for node: {}", nodeId);
    final List<String> removals = Collections.singletonList(nodeId.getHost());
    client.updateBlacklist(null, removals);
  }

  /**
   * Queues a request for a container to run the given task, optionally constrained
   * to preferred hosts and racks. The task's vertex index is obtained from the
   * context and carried in the request.
   *
   * @param task the task needing a container
   * @param capability the resources the task requires
   * @param hosts preferred hosts
   * @param racks preferred racks
   * @param priority the request priority
   * @param containerSignature the container signature associated with the task
   * @param clientCookie opaque value stored with the request
   */
  @Override
  public void allocateTask(Object task, Resource capability, String[] hosts, String[] racks,
      Priority priority, Object containerSignature, Object clientCookie) {
    addTaskRequest(new TaskRequest(
        task,
        getContext().getVertexIndexForTask(task),
        capability,
        hosts,
        racks,
        priority,
        containerSignature,
        clientCookie));
  }

  /**
   * Queues a request for a container to run the given task with affinity to a
   * specific container. If that container is currently held and large enough, the
   * request targets the container's host and records the affinity; otherwise the
   * affinity is dropped and the request is queued without any locality.
   *
   * @param task the task needing a container
   * @param capability the resources the task requires
   * @param containerId the container the task would like to run in
   * @param priority the request priority
   * @param containerSignature the container signature associated with the task
   * @param clientCookie opaque value stored with the request
   */
  @Override
  public void allocateTask(Object task, Resource capability, ContainerId containerId,
      Priority priority, Object containerSignature, Object clientCookie) {
    String[] preferredHosts = null;
    ContainerId affinity = containerId;
    synchronized (this) {
      final HeldContainer target = heldContainers.get(containerId);
      if (target == null) {
        LOG.info("Ignoring match request to unknown container {}", containerId);
        affinity = null;
      } else if (!target.canFit(capability)) {
        LOG.warn("Match request to container {} but {} does not fit in {}",
            containerId, capability, target.getCapability());
        affinity = null;
      } else {
        preferredHosts = new String[]{target.getHost()};
      }
    }

    final int vertexIndex = getContext().getVertexIndexForTask(task);
    addTaskRequest(new TaskRequest(task, vertexIndex, capability, preferredHosts, null,
        priority, containerSignature, clientCookie, affinity));
  }

  /**
   * Deallocates a task.
   *
   * <p>If the task still has a pending request, the request is removed and
   * {@code false} is returned. Otherwise the task's container assignment is
   * removed; the container is either offered for reuse (when the task
   * succeeded and container reuse is enabled) or released. Callbacks to the
   * application are made after the scheduler lock has been dropped.
   *
   * @param task the task being deallocated
   * @param taskSucceeded whether the task succeeded
   * @param endReason not used by this method
   * @param diagnostics not used by this method
   * @return {@code false} if the task was still pending or had no container
   *         assigned, {@code true} otherwise
   */
  @Override
  public boolean deallocateTask(Object task, boolean taskSucceeded,
      TaskAttemptEndReason endReason, String diagnostics) {
    AMState appState = getContext().getAMState();
    boolean isSession = getContext().isSession();
    ContainerId releasedId = null;
    TaskRequest reassigned = null;
    HeldContainer held;
    synchronized (this) {
      if (removeTaskRequest(task) != null) {
        LOG.debug("Deallocating task {} before it was allocated", task);
        return false;
      }

      held = removeTaskAssignment(task);
      if (held != null) {
        boolean reuse = taskSucceeded && shouldReuseContainers;
        if (!reuse) {
          if (releaseContainer(held)) {
            releasedId = held.getId();
          }
        } else {
          idleTracker.add(held);
          reassigned = tryAssignReuseContainer(held, appState, isSession);
          if (reassigned == null && held.isReleasedAndUsed()) {
            releasedId = held.getId();
          }
        }
      }
    }

    // application callbacks happen only after the lock has been released
    if (reassigned != null) {
      informAppAboutAssignment(reassigned, held.getContainer());
      return true;
    }
    if (releasedId != null) {
      getContext().containerBeingReleased(releasedId);
      return true;
    }
    return held != null;
  }

  /**
   * Stops tracking a container and releases it.
   *
   * <p>Unknown containers are logged and ignored. The application is told that
   * the container is being released, outside of the scheduler lock, when
   * {@code releaseContainer} reports {@code true}.
   *
   * @param containerId the container to deallocate
   * @return the task that was assigned to the container, or {@code null} if
   *         there was none or the container is unknown
   */
  @Override
  public Object deallocateContainer(ContainerId containerId) {
    Object assignedTask = null;
    ContainerId releasedId = null;
    synchronized (this) {
      HeldContainer held = heldContainers.remove(containerId);
      if (held == null) {
        LOG.info("Ignoring deallocation of unknown container {}", containerId);
      } else {
        assignedTask = held.getAssignedTask();
        if (assignedTask != null) {
          LOG.info("Deallocated container {} from task {}", containerId, assignedTask);
        }
        if (releaseContainer(held)) {
          releasedId = held.getId();
        }
      }
    }

    // application callback happens only after the lock has been released
    if (releasedId != null) {
      getContext().containerBeingReleased(releasedId);
    }
    return assignedTask;
  }

  /**
   * Binds a task request to a held container: the pending request is removed,
   * the assignment is recorded and the container is taken out of the idle
   * tracker.
   *
   * @param request the request being satisfied
   * @param hc the container chosen for the request
   * @param match what the container was matched on; only used for logging
   */
  @GuardedBy("this")
  private void assignContainer(TaskRequest request, HeldContainer hc, Object match) {
    Object task = request.getTask();
    LOG.info("Assigning container {} to task {} host={} priority={} capability={} match={} lastTask={}",
        hc.getId(), task, hc.getHost(), hc.getPriority(), hc.getCapability(), match, hc.getLastTask());
    removeTaskRequest(task);
    addTaskAssignment(request, hc);
    idleTracker.remove(hc);
  }

  /**
   * Marks a container as completed and releases it through the client.
   *
   * @param hc the container to release
   * @return {@code true} if the container had a last task, in which case the
   *         container id and that task are remembered in
   *         {@code releasedContainers}; {@code false} otherwise
   */
  private synchronized boolean releaseContainer(HeldContainer hc) {
    Object lastTask = containerCompleted(hc);
    ContainerId id = hc.getId();
    client.releaseAssignedContainer(id);
    if (lastTask == null) {
      return false;
    }
    releasedContainers.put(id, lastTask);
    return true;
  }

  /**
   * Records that a task has been assigned to a container, both in the
   * per-task map and in the per-vertex bookkeeping, and marks the container
   * as assigned.
   *
   * @param request the request whose task is being assigned
   * @param hc the container receiving the task
   */
  @GuardedBy("this")
  private void addTaskAssignment(TaskRequest request, HeldContainer hc) {
    Object task = request.getTask();
    HeldContainer previous = taskAssignments.put(task, hc);
    if (previous != null) {
      LOG.error("Task {} being assigned to container {} but was already assigned to container {}",
          task, hc.getId(), previous.getId());
    }

    Integer vertex = request.vertexIndex;
    Set<HeldContainer> vertexContainers = vertexAssignments.get(vertex);
    if (vertexContainers == null) {
      // first assignment for this vertex
      vertexContainers = new HashSet<>();
      vertexAssignments.put(vertex, vertexContainers);
      assignedVertices.set(vertex);
    }
    vertexContainers.add(hc);
    hc.assignTask(request);
  }

  /**
   * Removes the assignment of a task from its container and updates the
   * per-vertex bookkeeping.
   *
   * @param task the task whose assignment should be removed
   * @return the container the task was assigned to, or {@code null} if the
   *         task had no assignment
   */
  @GuardedBy("this")
  private HeldContainer removeTaskAssignment(Object task) {
    HeldContainer hc = taskAssignments.remove(task);
    if (hc == null) {
      return null;
    }

    TaskRequest request = hc.removeAssignment();
    if (request == null) {
      LOG.error("Container {} had assigned task {} but no request?!?", hc.getId(), task);
      return hc;
    }

    Integer vertex = request.vertexIndex;
    Set<HeldContainer> vertexContainers = vertexAssignments.get(vertex);
    if (vertexContainers != null && vertexContainers.remove(hc) && vertexContainers.isEmpty()) {
      // last container for this vertex is gone
      vertexAssignments.remove(vertex);
      assignedVertices.clear(vertex);
    }
    return hc;
  }

  /**
   * Removes all scheduler tracking for a container and marks it released.
   *
   * @param hc the container that has completed
   * @return the last task that was assigned to the container, or
   *         {@code null} if it never had one
   */
  @GuardedBy("this")
  @Nullable
  private Object containerCompleted(HeldContainer hc) {
    idleTracker.remove(hc);
    heldContainers.remove(hc.getId());
    Resources.subtractFrom(allocatedResources, hc.getCapability());

    // the assignment has to be dropped before the container is marked released
    Object assignedTask = hc.getAssignedTask();
    removeTaskAssignment(assignedTask);
    hc.released();

    Object lastTask = hc.getLastTask();
    return lastTask;
  }

  /**
   * Populates {@code vertexDescendants} from the current DAG info if it has
   * not been populated yet.
   *
   * @throws IllegalStateException if the context has no current DAG info
   */
  @GuardedBy("this")
  private void ensureVertexDescendants() {
    if (vertexDescendants != null) {
      return;
    }

    DagInfo dagInfo = getContext().getCurrentDagInfo();
    if (dagInfo == null) {
      throw new IllegalStateException("Scheduling tasks but no current DAG info?");
    }

    int total = dagInfo.getTotalVertices();
    ArrayList<BitSet> perVertex = new ArrayList<>(total);
    for (int vertex = 0; vertex < total; ++vertex) {
      perVertex.add(dagInfo.getVertexDescendants(vertex));
    }
    vertexDescendants = perVertex;
  }

  /**
   * Registers a task request with the scheduler.
   *
   * <p>When container reuse is enabled, no stop has been requested and the AM
   * is not in the COMPLETED state, the request is first offered to the idle
   * containers. If no container is assigned, the request is recorded in the
   * request tracker, submitted to the client and, if it names an affinity
   * container that is currently held, registered as an affinity on that
   * container.
   *
   * <p>If the tracker already held a request for the same task, that
   * superseded request is withdrawn from the client and from its affinity
   * container so that it does not linger alongside the new one.
   *
   * <p>The application is informed about an immediate assignment only after
   * the scheduler lock has been released.
   *
   * @param request the request to add
   */
  private void addTaskRequest(TaskRequest request) {
    Container assignedContainer = null;
    synchronized (this) {
      if (shouldReuseContainers && !stopRequested && getContext().getAMState() != AMState.COMPLETED) {
        ensureVertexDescendants();
        activateSessionContainers();
        HeldContainer hc = tryAssignTaskToIdleContainer(request);
        if (hc != null) {
          assignedContainer = hc.getContainer();
        }
      }

      if (assignedContainer == null) {
        ensureVertexDescendants();
        TaskRequest old = requestTracker.add(request);
        if (old != null) {
          // Withdraw the request that was replaced. The new request has not been
          // submitted to the client yet, so removing it here would leave the old
          // one registered with the client and its affinity container.
          removeTaskRequestByRequest(old);
        }
        client.addContainerRequest(request);

        HeldContainer hc = heldContainers.get(request.getAffinity());
        if (hc != null) {
          hc.addAffinity(request);
        }
      }
    }

    // perform app callback outside of locks
    if (assignedContainer != null) {
      informAppAboutAssignment(request, assignedContainer);
    }
  }

  /**
   * Removes the pending request for a task, if any, from the request tracker
   * and withdraws it from the client and its affinity container.
   *
   * @param task the task whose request should be removed
   * @return the removed request, or {@code null} if the task had none pending
   */
  @Nullable
  private synchronized TaskRequest removeTaskRequest(Object task) {
    TaskRequest pending = requestTracker.remove(task);
    if (pending == null) {
      return null;
    }
    removeTaskRequestByRequest(pending);
    return pending;
  }

  /**
   * Withdraws a request from the client and, if the request's affinity
   * container is currently held, removes the request from that container's
   * affinities.
   *
   * @param request the request to withdraw
   */
  @GuardedBy("this")
  private void removeTaskRequestByRequest(TaskRequest request) {
    client.removeContainerRequest(request);
    HeldContainer affinityContainer = heldContainers.get(request.getAffinity());
    if (affinityContainer == null) {
      return;
    }
    affinityContainer.removeAffinity(request);
  }

  /**
   * Tries to satisfy a request immediately with a container that is already
   * held.
   *
   * <p>Nothing is assigned when the request tracker reports the request as
   * blocked. Otherwise the request's affinity container is tried first, and
   * then the idle containers are searched by the request's nodes, its racks
   * and finally {@link ResourceRequest#ANY}. A request without locality is
   * matched against {@link ResourceRequest#ANY} only.
   *
   * @param request the request to satisfy
   * @return the container that was assigned, or {@code null} if none was
   */
  @GuardedBy("this")
  @Nullable
  private HeldContainer tryAssignTaskToIdleContainer(TaskRequest request) {
    if (requestTracker.isRequestBlocked(request)) {
      LOG.debug("Cannot assign task {} to an idle container since vertex {} is a descendant of pending tasks",
          request.getTask(), request.getVertexIndex());
      return null;
    }

    // an affinity container that is assignable right now wins over everything else
    ContainerId affinity = request.getAffinity();
    if (affinity != null) {
      HeldContainer preferred = heldContainers.get(affinity);
      if (preferred != null && preferred.isAssignable()
          && !maybeChangeNode(request, preferred.getContainer().getNodeId())) {
        assignContainer(request, preferred, affinity);
        return preferred;
      }
    }

    if (!request.hasLocality()) {
      return tryAssignTaskToIdleContainer(request, ResourceRequest.ANY, HeldContainerState.MATCHES_LOCAL_STATES);
    }

    // walk the locality levels from best to worst and stop at the first match
    HeldContainer found =
        tryAssignTaskToIdleContainer(request, request.getNodes(), HeldContainerState.MATCHES_LOCAL_STATES);
    if (found != null) {
      return found;
    }
    found = tryAssignTaskToIdleContainer(request, request.getRacks(), HeldContainerState.MATCHES_RACK_STATES);
    if (found != null) {
      return found;
    }
    return tryAssignTaskToIdleContainer(request, ResourceRequest.ANY, HeldContainerState.MATCHES_ANY_STATES);
  }

  /**
   * Tries each location in order and returns the first container that could
   * be assigned to the request.
   *
   * @param request the request to satisfy
   * @param locations the locations to try; may be {@code null} or empty
   * @param eligibleStates the container states that are allowed to match
   * @return the container that was assigned, or {@code null} if none was
   */
  @GuardedBy("this")
  @Nullable
  private HeldContainer tryAssignTaskToIdleContainer(TaskRequest request,
      List<String> locations, EnumSet<HeldContainerState> eligibleStates) {
    if (locations == null || locations.isEmpty()) {
      return null;
    }
    for (String candidateLocation : locations) {
      HeldContainer assigned = tryAssignTaskToIdleContainer(request, candidateLocation, eligibleStates);
      if (assigned != null) {
        return assigned;
      }
    }
    return null;
  }

  /**
   * Searches the idle containers tracked for a location and assigns the best
   * match to the request.
   *
   * <p>A container is a candidate when its state is one of the eligible
   * states, its signature is absent or a superset of the request's signature,
   * and {@code maybeChangeNode} does not ask for a different node. Among the
   * candidates the one with the fewest affinities is chosen; a candidate
   * without any affinities ends the search immediately.
   *
   * @param request the request to satisfy
   * @param location the location whose idle containers are searched
   * @param eligibleStates the container states that are allowed to match
   * @return the container that was assigned, or {@code null} if none was
   */
  @GuardedBy("this")
  @Nullable
  private HeldContainer tryAssignTaskToIdleContainer(TaskRequest request,
      String location, EnumSet<HeldContainerState> eligibleStates) {
    HeldContainer best = null;
    Set<HeldContainer> idle = idleTracker.getByLocation(location);
    if (idle != null && !idle.isEmpty()) {
      for (HeldContainer candidate : idle) {
        if (!eligibleStates.contains(candidate.getState())) {
          continue;
        }
        Object csig = candidate.getSignature();
        if (csig != null && !signatureMatcher.isSuperSet(csig, request.getContainerSignature())) {
          LOG.debug("Unable to assign task {} to container {} due to signature mismatch", request.getTask(), candidate.getId());
          continue;
        }
        if (maybeChangeNode(request, candidate.getContainer().getNodeId())) {
          continue;
        }
        int numAffinities = candidate.getNumAffinities();
        if (numAffinities == 0) {
          // cannot do better than a container nobody else is waiting for
          best = candidate;
          break;
        }
        if (best == null || numAffinities < best.getNumAffinities()) {
          best = candidate;
        }
      }
    }

    if (best != null) {
      assignContainer(request, best, location);
    }
    return best;
  }

  /**
   * Checks whether a request should avoid the given node.
   *
   * @param request the request being matched
   * @param nodeId the node of the container under consideration
   * @return {@code true} if the request's task is a {@code TaskAttempt} whose
   *         task reports the node among its nodes with running attempts;
   *         {@code false} otherwise
   */
  private boolean maybeChangeNode(TaskRequest request, NodeId nodeId) {
    Object task = request.getTask();
    if (!(task instanceof TaskAttempt)) {
      return false;
    }
    Set<NodeId> busyNodes = ((TaskAttempt) task).getTask().getNodesWithRunningAttempts();
    return busyNodes != null && busyNodes.contains(nodeId);
  }

  /** Sets the {@code shouldUnregister} flag of this scheduler. */
  @Override
  public void setShouldUnregister() {
    this.shouldUnregister = true;
  }

  /**
   * Returns the current value of the {@code hasUnregistered} flag.
   *
   * @return the {@code hasUnregistered} field
   */
  @Override
  public boolean hasUnregistered() {
    return this.hasUnregistered;
  }

  /**
   * Resets per-DAG state: the cached vertex descendants are discarded and the
   * matching level of every session container is reset.
   */
  @Override
  public synchronized void dagComplete() {
    vertexDescendants = null;
    for (HeldContainer sessionContainer : sessionContainers) {
      sessionContainer.resetMatchingLevel();
    }
  }

  /**
   * Decides whether containers should be preempted and, if so, which ones.
   *
   * <p>No preemption happens when preemption is disabled, when too few
   * heartbeats have passed since the last preemption, when the highest
   * priority request fits in the free resources and the preemption deadline
   * has not expired, when idle containers exist, when none of the blocked
   * vertices has assigned containers, or when the amount to preempt is empty.
   * Otherwise containers assigned to blocked vertices are selected in
   * {@code PREEMPT_ORDER_COMPARATOR} order until the amount to preempt has
   * been covered.
   *
   * @param freeResources the resources currently free
   * @return the ids of the containers to preempt, or {@code null} if
   *         preemption was not attempted
   */
  @GuardedBy("this")
  @Nullable
  private Collection<ContainerId> maybePreempt(Resource freeResources) {
    if (preemptionPercentage == 0) {
      return null;
    }
    if (numHeartbeats - lastPreemptionHeartbeat < numHeartbeatsBetweenPreemptions) {
      return null;
    }

    // the explanatory messages below are only logged on every 50th heartbeat
    final boolean logThisHeartbeat = numHeartbeats % 50 == 1;

    if (!requestTracker.isPreemptionDeadlineExpired() && requestTracker.fitsHighestPriorityRequest(freeResources)) {
      if (logThisHeartbeat) {
        LOG.info("Highest priority request fits in free resources {}", freeResources);
      }
      return null;
    }

    int numIdleContainers = idleTracker.getNumContainers();
    if (numIdleContainers > 0) {
      if (logThisHeartbeat) {
        LOG.info("Avoiding preemption since there are {} idle containers", numIdleContainers);
      }
      return null;
    }

    BitSet blocked = requestTracker.createVertexBlockedSet();
    if (!blocked.intersects(assignedVertices)) {
      if (logThisHeartbeat) {
        LOG.info("Avoiding preemption since there are no descendants of the highest priority requests running");
      }
      return null;
    }

    Resource preemptLeft = requestTracker.getAmountToPreempt(preemptionPercentage);
    if (!resourceCalculator.anyAvailable(preemptLeft)) {
      if (logThisHeartbeat) {
        LOG.info("Avoiding preemption since amount to preempt is {}", preemptLeft);
      }
      return null;
    }

    // gather the containers assigned to vertices that are both blocked and assigned
    PriorityQueue<HeldContainer> candidates = new PriorityQueue<>(11, PREEMPT_ORDER_COMPARATOR);
    blocked.and(assignedVertices);
    int vertex = blocked.nextSetBit(0);
    while (vertex >= 0) {
      Collection<HeldContainer> assigned = vertexAssignments.get(vertex);
      if (assigned == null) {
        LOG.error("Vertex {} in assignedVertices but no assignments?", vertex);
      } else {
        candidates.addAll(assigned);
      }
      vertex = blocked.nextSetBit(vertex + 1);
    }

    // take candidates in queue order until the amount to preempt is used up
    ArrayList<ContainerId> preemptedContainers = new ArrayList<>();
    for (HeldContainer victim = candidates.poll(); victim != null; victim = candidates.poll()) {
      LOG.info("Preempting container {} currently allocated to task {}", victim.getId(), victim.getAssignedTask());
      preemptedContainers.add(victim.getId());
      resourceCalculator.deductFrom(preemptLeft, victim.getCapability());
      if (!resourceCalculator.anyAvailable(preemptLeft)) {
        break;
      }
    }

    return preemptedContainers;
  }

  /**
   * Builds a one-line summary of the scheduler state for logging.
   *
   * <p>The wait timestamp and priority of the highest priority request are
   * only included when the request tracker reports a highest priority.
   *
   * @param freeResource the free resources to include in the summary
   * @return the summary text
   */
  @GuardedBy("this")
  private String constructPeriodicLog(Resource freeResource) {
    Priority highestPriority = requestTracker.getHighestPriority();
    StringBuilder summary = new StringBuilder();
    summary.append("Allocated: ").append(allocatedResources);
    summary.append(" Free: ").append(freeResource);
    summary.append(" pendingRequests: ").append(requestTracker.getNumRequests());
    summary.append(" heldContainers: ").append(heldContainers.size());
    summary.append(" heartbeats: ").append(numHeartbeats);
    summary.append(" lastPreemptionHeartbeat: ").append(lastPreemptionHeartbeat);
    if (highestPriority != null) {
      summary.append(" highestWaitingRequestWaitStartTime: ")
          .append(requestTracker.getHighestPriorityWaitTimestamp());
      summary.append(" highestWaitingRequestPriority: ").append(highestPriority);
    }
    return summary.toString();
  }

  /**
   * Returns how many nodes are currently recorded as blacklisted.
   *
   * @return the size of {@code blacklistedNodes}
   */
  @VisibleForTesting
  int getNumBlacklistedNodes() {
    int count = blacklistedNodes.size();
    return count;
  }

  /**
   * Returns the scheduler's session container collection itself, not a copy.
   *
   * @return the {@code sessionContainers} field
   */
  @VisibleForTesting
  Collection<HeldContainer> getSessionContainers() {
    return this.sessionContainers;
  }

  /**
   * Wrapper class to work around lack of blacklisting APIs in async client.
   * This can be removed once Tez requires YARN >= 2.7.0
   */
  static class AMRMClientAsyncWrapper extends AMRMClientAsyncImpl<TaskRequest> {
    /**
     * Creates the wrapper by passing all arguments through to the superclass
     * constructor.
     */
    AMRMClientAsyncWrapper(AMRMClient<TaskRequest> syncClient, int intervalMs, CallbackHandler handler) {
      super(syncClient, intervalMs, handler);
    }

    /**
     * Forwards a blacklist update to the wrapped client.
     *
     * @param additions passed through as the blacklist additions
     * @param removals passed through as the blacklist removals
     */
    public void updateBlacklist(List<String> additions, List<String> removals) {
      this.client.updateBlacklist(additions, removals);
    }
  }

  /**
   * A utility class to track a task allocation. It extends the client's
   * container request with the task, its vertex index, the container
   * signature, the client cookie and an optional affinity container.
   */
  static class TaskRequest extends AMRMClient.ContainerRequest {
    /** The task this request was made for. */
    final Object task;
    /** The vertex index supplied for the task when the request was created. */
    final int vertexIndex;
    /** The container signature supplied with the request. */
    final Object signature;
    /** The opaque cookie supplied with the request. */
    final Object cookie;
    /** The container this request would like to be matched with; may be null. */
    final ContainerId affinityContainerId;

    /** Creates a request without a container affinity. */
    TaskRequest(Object task, int vertexIndex, Resource capability, String[] hosts, String[] racks,
        Priority priority, Object signature, Object cookie) {
      this(task, vertexIndex, capability, hosts, racks, priority, signature, cookie, null);
    }

    /** Creates a request, optionally with a container affinity. */
    TaskRequest(Object task, int vertexIndex, Resource capability, String[] hosts, String[] racks,
        Priority priority, Object signature, Object cookie, ContainerId affinityContainerId) {
      super(capability, hosts, racks, priority);
      this.task = task;
      this.vertexIndex = vertexIndex;
      this.signature = signature;
      this.cookie = cookie;
      this.affinityContainerId = affinityContainerId;
    }

    Object getTask() {
      return this.task;
    }

    int getVertexIndex() {
      return this.vertexIndex;
    }

    Object getContainerSignature() {
      return this.signature;
    }

    Object getCookie() {
      return this.cookie;
    }

    @Nullable
    ContainerId getAffinity() {
      return this.affinityContainerId;
    }

    /**
     * @return {@code true} if the request has at least one node or at least
     *         one rack
     */
    boolean hasLocality() {
      List<String> nodes = getNodes();
      if (nodes != null && !nodes.isEmpty()) {
        return true;
      }
      List<String> racks = getRacks();
      return racks != null && !racks.isEmpty();
    }
  }

  /**
   * The states a held container can be in. Each state records whether a
   * container in that state is assignable.
   */
  private enum HeldContainerState {
    MATCHING_LOCAL(true),
    MATCHING_RACK(true),
    MATCHING_ANY(true),
    ASSIGNED(false),
    RELEASED(false);

    /** MATCHING_LOCAL, MATCHING_RACK and MATCHING_ANY. */
    private static final EnumSet<HeldContainerState> MATCHES_LOCAL_STATES =
        EnumSet.range(MATCHING_LOCAL, MATCHING_ANY);
    /** MATCHING_RACK and MATCHING_ANY. */
    private static final EnumSet<HeldContainerState> MATCHES_RACK_STATES =
        EnumSet.range(MATCHING_RACK, MATCHING_ANY);
    /** MATCHING_ANY only. */
    private static final EnumSet<HeldContainerState> MATCHES_ANY_STATES =
        EnumSet.of(MATCHING_ANY);

    private final boolean assignable;

    HeldContainerState(boolean assignable) {
      this.assignable = assignable;
    }

    /** @return whether a container in this state is assignable */
    boolean isAssignable() {
      return this.assignable;
    }
  }

  /**
   * Tracking for an allocated container.
   *
   * <p>An instance wraps the allocated {@link Container} together with its
   * rack, its current {@link HeldContainerState}, the request currently and
   * last assigned to it, the requests that named it as an affinity, and the
   * pending reuse callback. The mutable fields are annotated as guarded by
   * the enclosing scheduler's lock.
   */
  @VisibleForTesting
  class HeldContainer implements Callable<Void> {
    final Container container;
    final String rack;

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    HeldContainerState state = HeldContainerState.MATCHING_LOCAL;

    /** The Future received when scheduling an idle container for re-allocation at a later time. */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    Future<Void> future;

    /** The collection of task requests that have specified this container as a scheduling affinity. */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    Collection<TaskRequest> affinities;

    /**
     * The task request corresponding to the currently assigned task to this container.
     * This field is null when the container is not currently assigned.
     */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    TaskRequest assignedRequest;

    /** The task request corresponding to the last task that was assigned to this container. */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    TaskRequest lastRequest;

    /** The timestamp when the idle container will expire. 0 if the container is not idle. */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    long idleExpirationTimestamp;

    /** The timestamp when this container was assigned. 0 if the container is not assigned. */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    long assignmentTimestamp;

    /**
     * Wraps a container and resolves the rack of the container's host.
     *
     * @param container the allocated container to track
     */
    HeldContainer(Container container) {
      this.container = container;
      String host = container.getNodeId().getHost();
      this.rack = RackResolver.resolve(host).getNetworkLocation();
    }

    HeldContainerState getState() {
      return this.state;
    }

    /** @return whether the current state is an assignable one */
    boolean isAssignable() {
      return this.state.isAssignable();
    }

    /** @return whether the container is in the RELEASED state and has had a task assigned */
    boolean isReleasedAndUsed() {
      if (state != HeldContainerState.RELEASED) {
        return false;
      }
      return getLastTask() != null;
    }

    Container getContainer() {
      return this.container;
    }

    ContainerId getId() {
      return this.container.getId();
    }

    String getHost() {
      return this.container.getNodeId().getHost();
    }

    String getRack() {
      return this.rack;
    }

    Priority getPriority() {
      return this.container.getPriority();
    }

    Resource getCapability() {
      return this.container.getResource();
    }

    /** @return the task of the currently assigned request, or {@code null} if there is none */
    @Nullable
    Object getAssignedTask() {
      if (assignedRequest == null) {
        return null;
      }
      return assignedRequest.getTask();
    }

    /**
     * Assigns a request to this container: the request becomes both the
     * assigned and the last request, the state becomes ASSIGNED, the idle
     * expiration is cleared, the assignment time is recorded and any pending
     * reuse callback is cancelled.
     *
     * @param request the request being assigned
     */
    void assignTask(TaskRequest request) {
      assert !(state == HeldContainerState.ASSIGNED || state == HeldContainerState.RELEASED);
      if (assignedRequest != null) {
        LOG.error("Container {} assigned task {} but already running task {}",
            getId(), request.getTask(), assignedRequest.getTask());
      }
      assignedRequest = request;
      lastRequest = request;
      state = HeldContainerState.ASSIGNED;
      idleExpirationTimestamp = 0;
      assignmentTimestamp = now();
      if (future != null) {
        future.cancel(false);
        future = null;
      }
    }

    /**
     * Clears the current assignment and moves the container back to the
     * MATCHING_LOCAL state.
     *
     * @return the request that was assigned
     */
    TaskRequest removeAssignment() {
      assert state == HeldContainerState.ASSIGNED;
      TaskRequest previous = assignedRequest;
      assignedRequest = null;
      assignmentTimestamp = 0;
      state = HeldContainerState.MATCHING_LOCAL;
      return previous;
    }

    /** Adds a request to this container's affinities, creating the collection on first use. */
    void addAffinity(TaskRequest request) {
      if (affinities == null) {
        affinities = new HashSet<>();
      }
      affinities.add(request);
    }

    /** Removes a request from the affinities and drops the collection once it becomes empty. */
    void removeAffinity(TaskRequest request) {
      if (affinities == null) {
        return;
      }
      if (affinities.remove(request) && affinities.isEmpty()) {
        affinities = null;
      }
    }

    int getNumAffinities() {
      if (affinities == null) {
        return 0;
      }
      return affinities.size();
    }

    @Nullable
    Collection<TaskRequest> getAffinities() {
      return this.affinities;
    }

    /**
     * Schedules this container's {@link #call()} on the reuse executor after
     * the given delay, cancelling any callback that was scheduled earlier. A
     * rejected submission is logged unless a stop has been requested.
     *
     * @param delayMillis the delay in milliseconds
     */
    void scheduleForReuse(long delayMillis) {
      assert !(state == HeldContainerState.ASSIGNED || state == HeldContainerState.RELEASED);
      if (future != null) {
        future.cancel(false);
      }
      try {
        future = reuseExecutor.schedule(this, delayMillis, TimeUnit.MILLISECONDS);
      } catch (RejectedExecutionException e) {
        if (!stopRequested) {
          LOG.error("Container {} could not be scheduled for reuse!", getId(), e);
        }
      }
    }

    /** @return the container signature of the last request, or {@code null} if there was none */
    @Nullable
    Object getSignature() {
      if (lastRequest == null) {
        return null;
      }
      return lastRequest.getContainerSignature();
    }

    /** @return the task of the last request, or {@code null} if there was none */
    @Nullable
    Object getLastTask() {
      if (lastRequest == null) {
        return null;
      }
      return lastRequest.getTask();
    }

    /**
     * @return the host, the rack or {@link ResourceRequest#ANY} depending on
     *         the current matching state
     * @throws IllegalStateException if the container is not in a matching state
     */
    String getMatchingLocation() {
      if (state == HeldContainerState.MATCHING_LOCAL) {
        return getHost();
      }
      if (state == HeldContainerState.MATCHING_RACK) {
        return getRack();
      }
      if (state == HeldContainerState.MATCHING_ANY) {
        return ResourceRequest.ANY;
      }
      throw new IllegalStateException("Container " + getId() + " trying to match in state " + state);
    }

    /**
     * Advances MATCHING_LOCAL to MATCHING_RACK when rack-local reuse is
     * enabled, and MATCHING_RACK to MATCHING_ANY when non-local reuse is
     * enabled. MATCHING_ANY is left unchanged.
     *
     * @throws IllegalStateException if the container is not in a matching state
     */
    void moveToNextMatchingLevel() {
      switch (state) {
      case MATCHING_LOCAL:
        if (reuseRackLocal) {
          state = HeldContainerState.MATCHING_RACK;
        }
        break;
      case MATCHING_RACK:
        if (reuseNonLocal) {
          state = HeldContainerState.MATCHING_ANY;
        }
        break;
      case MATCHING_ANY:
        // already at the widest level
        break;
      default:
        throw new IllegalStateException("Container " + getId() + " trying to match in state " + state);
      }
    }

    /**
     * @return whether {@link #moveToNextMatchingLevel()} would leave the state unchanged
     * @throws IllegalStateException if the container is not in a matching state
     */
    boolean atMaxMatchLevel() {
      if (state == HeldContainerState.MATCHING_LOCAL) {
        return !reuseRackLocal;
      }
      if (state == HeldContainerState.MATCHING_RACK) {
        return !reuseNonLocal;
      }
      if (state == HeldContainerState.MATCHING_ANY) {
        return true;
      }
      throw new IllegalStateException("Container " + getId() + " trying to match in state " + state);
    }

    /** Returns an assignable container to the MATCHING_LOCAL state; other containers are untouched. */
    void resetMatchingLevel() {
      if (!isAssignable()) {
        return;
      }
      state = HeldContainerState.MATCHING_LOCAL;
    }

    /**
     * Returns the idle expiration timestamp, computing and caching it on the
     * first call after it was cleared. With a positive minimum timeout the
     * expiration is {@code now} plus either the fixed timeout (minimum equals
     * maximum) or a value drawn from {@code random} between the two;
     * otherwise the expiration is {@link Long#MAX_VALUE}.
     *
     * @param now the current time used as the base of a new expiration
     * @return the idle expiration timestamp
     */
    long getIdleExpirationTimestamp(long now) {
      if (idleExpirationTimestamp != 0) {
        return idleExpirationTimestamp;
      }
      if (idleContainerTimeoutMin <= 0) {
        idleExpirationTimestamp = Long.MAX_VALUE;
      } else if (idleContainerTimeoutMin == idleContainerTimeoutMax) {
        idleExpirationTimestamp = now + idleContainerTimeoutMin;
      } else {
        idleExpirationTimestamp = now + random.nextLong(idleContainerTimeoutMin, idleContainerTimeoutMax);
      }
      return idleExpirationTimestamp;
    }

    long getAssignmentTimestamp() {
      return this.assignmentTimestamp;
    }

    /** @return whether the container has at least the memory and virtual cores of the capability */
    boolean canFit(Resource capability) {
      Resource own = container.getResource();
      if (own.getMemory() < capability.getMemory()) {
        return false;
      }
      return own.getVirtualCores() >= capability.getVirtualCores();
    }

    /**
     * The scheduled reuse callback. If the container is still assignable it
     * is moved to the next matching level and offered for reuse; application
     * callbacks are made after the scheduler lock has been released.
     */
    @Override
    public Void call() throws Exception {
      AMState appState = getContext().getAMState();
      boolean isSession = getContext().isSession();
      TaskRequest newAssignment = null;
      ContainerId releasedId = null;
      synchronized (DagAwareYarnTaskScheduler.this) {
        future = null;
        if (isAssignable()) {
          moveToNextMatchingLevel();
          newAssignment = tryAssignReuseContainer(this, appState, isSession);
          if (newAssignment == null && isReleasedAndUsed()) {
            releasedId = getId();
          }
        }
      }
      if (newAssignment != null) {
        informAppAboutAssignment(newAssignment, container);
      }
      if (releasedId != null) {
        getContext().containerBeingReleased(releasedId);
      }
      return null;
    }

    /** Moves the container to the RELEASED state and cancels any pending reuse callback. */
    void released() {
      assert state != HeldContainerState.RELEASED;
      state = HeldContainerState.RELEASED;
      Future<Void> pending = future;
      if (pending != null) {
        pending.cancel(false);
      }
      future = null;
    }
  }

  /**
   * Utility comparator to order containers by assignment timestamp from
   * most recent to least recent. A timestamp of 0 is treated as
   * {@link Long#MAX_VALUE}, so such containers sort first.
   */
  private static class PreemptOrderComparator implements Comparator<HeldContainer> {
    @Override
    public int compare(HeldContainer o1, HeldContainer o2) {
      long raw1 = o1.getAssignmentTimestamp();
      long key1 = raw1 == 0 ? Long.MAX_VALUE : raw1;
      long raw2 = o2.getAssignmentTimestamp();
      long key2 = raw2 == 0 ? Long.MAX_VALUE : raw2;
      // arguments are swapped so that larger timestamps come first
      return Long.compare(key2, key1);
    }
  }

  /**
   * Utility class for a request, container pair
   */
  private static class Assignment {
    /** The request half of the pair. */
    final TaskRequest request;
    /** The container half of the pair. */
    final Container container;

    /**
     * @param request the request to pair
     * @param container the container to pair it with
     */
    Assignment(TaskRequest request, Container container) {
      this.container = container;
      this.request = request;
    }
  }

  /**
   * Utility class for a task, container exit status pair
   */
  private static class TaskStatus {
    /** The task half of the pair. */
    final Object task;
    /** The container status half of the pair. */
    final ContainerStatus status;

    /**
     * @param task the task to pair
     * @param status the container status to pair it with
     */
    TaskStatus(Object task, ContainerStatus status) {
      this.status = status;
      this.task = task;
    }
  }

  /**
   * The task allocation request tracker tracks task allocations
   * and keeps statistics on which priorities have requests and which vertices
   * should be blocked from container reuse due to DAG topology.
   */
  private class RequestTracker {
    private final Map<Object, TaskRequest> requests = new HashMap<>();
    /** request map ordered by priority with highest priority first */
    private final NavigableMap<Priority, RequestPriorityStats> priorityStats =
        new TreeMap<>(Collections.reverseOrder());
    private Priority highestPriority = null;
    private long highestPriorityWaitTimestamp = 0;

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    @Nullable
    TaskRequest add(TaskRequest request) {
      TaskRequest oldRequest = requests.put(request.getTask(), request);
      Priority priority = request.getPriority();
      RequestPriorityStats stats = priorityStats.get(priority);
      if (stats == null) {
        stats = addStatsForPriority(priority);
      }
      ++stats.requestCount;
      if (request.hasLocality()) {
        ++stats.localityCount;
      }
      incrVertexTaskCount(priority, stats, request.getVertexIndex());

      if (oldRequest != null) {
        updateStatsForRemoval(oldRequest);
      }
      return oldRequest;
    }

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    @Nullable
    TaskRequest remove(Object task) {
      TaskRequest request = requests.remove(task);
      if (request != null) {
        updateStatsForRemoval(request);
        return request;
      }
      return null;
    }

    private RequestPriorityStats addStatsForPriority(Priority priority) {
      BitSet allowedVerts = new BitSet(vertexDescendants.size());
      Entry<Priority,RequestPriorityStats> lowerEntry = priorityStats.lowerEntry(priority);
      if (lowerEntry != null) {
        // initialize the allowed vertices BitSet using the information derived
        // from the next higher priority entry
        RequestPriorityStats priorStats = lowerEntry.getValue();
        allowedVerts.or(priorStats.allowedVertices);
        allowedVerts.andNot(priorStats.descendants);
      } else {
        // no higher priority entry so this priority is currently the highest
        highestPriority = priority;
        highestPriorityWaitTimestamp = now();
        allowedVerts.set(0, vertexDescendants.size());
      }
      RequestPriorityStats stats = new RequestPriorityStats(vertexDescendants.size(), allowedVerts);
      priorityStats.put(priority, stats);
      return stats;
    }

    private void updateStatsForRemoval(TaskRequest request) {
      Priority priority = request.getPriority();
      RequestPriorityStats stats = priorityStats.get(priority);
      decrVertexTaskCount(priority, stats, request.getVertexIndex());
      --stats.requestCount;
      if (request.hasLocality()) {
        --stats.localityCount;
      }
      if (stats.requestCount == 0) {
        priorityStats.remove(priority);
        if (highestPriority.equals(priority)) {
          if (priorityStats.isEmpty()) {
            highestPriority = null;
            highestPriorityWaitTimestamp = 0;
          } else {
            highestPriority = priorityStats.firstKey();
            highestPriorityWaitTimestamp = now();
          }
        }
      }
    }

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    boolean isEmpty() {
      return requests.isEmpty();
    }

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    int getNumRequests() {
      return requests.size();
    }

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    List<Object> getTasks() {
      return new ArrayList<>(requests.keySet());
    }

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    Collection<Entry<Priority, RequestPriorityStats>> getStatsEntries() {
      return priorityStats.entrySet();
    }

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    @Nullable
    Priority getHighestPriority() {
      if (priorityStats.isEmpty()) {
        return null;
      }
      return priorityStats.firstKey();
    }

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    long getHighestPriorityWaitTimestamp() {
      return highestPriorityWaitTimestamp;
    }

    @GuardedBy("DagAwareYarnTaskScheduler.this")
    boolean isRequestBlocked(TaskRequest request) {
      Entry<Priority, RequestPriorityStats> entry = priorityStats.floorEntry(request.getPriority());
      if (entry != null) {
        RequestPriorityStats stats = entry.getValue();
        int vertexIndex = request.getVertexIndex();
        return !stats.allowedVertices.get(vertexIndex) || stats.descendants.get(vertexIndex);
      }
      return false;
    }

    private void incrVertexTaskCount(Priority priority, RequestPriorityStats stats, int vertexIndex) {
      Integer vertexIndexInt = vertexIndex;
      MutableInt taskCount = stats.vertexTaskCount.get(vertexIndexInt);
      if (taskCount != null) {
        taskCount.increment();
      } else {
        addVertexToRequestStats(priority, stats, vertexIndexInt);
      }
    }

    private void decrVertexTaskCount(Priority priority, RequestPriorityStats stats, int vertexIndex) {
      Integer vertexIndexInt = vertexIndex;
      MutableInt taskCount = stats.vertexTaskCount.get(vertexIndexInt);
      taskCount.decrement();
      if (taskCount.intValue() <= 0) {
        removeVertexFromRequestStats(priority, stats, vertexIndexInt);
      }
    }

    /**
     * Add a new vertex to a RequestPriorityStats.
     *
     * Adding a vertex to the request stats requires updating the stats descendants bitmask to include the descendants
     * of the new vertex and also updating the allowedVertices bitmask for all lower priority requests to prevent any
     * task request from a descendant vertex in the DAG from being allocated. This avoids assigning allocations to
     * lower priority requests when a higher priority request of an ancestor is still pending, but it allows lower
     * priority requests to be satisfied if higher priority requests are not ancestors. This is particularly useful
     * for DAGs that have independent trees of vertices or significant, parallel branches within a tree.
     *
     * Requests are blocked by taking the specified vertex's full descendant vertex bitmask in vertexDescendants and
     * clearing those bits for all lower priority requests. For the following example DAG where each vertex index
     * corresponds to its letter position (i.e.: A=0, B=1, C=2, etc.)
     *
     *       A
     *       |
     *   C---B----E
     *   |        |
     *   D        F
     *            |
     *          G---H
     *
     * Vertices F, G, and H are descendants of E but all other vertices are not. The vertexDescendants bitmask for
     * vertex E is therefore 11100000b or 0xE0. When the first vertex E task request arrives we need to disallow
     * requests for all descendants of E. That is accomplished by iterating through the request stats for all lower
     * priority requests and clearing the allowedVertex bits corresponding to the descendants,
     * i.e: allowedVertices = allowedVertices & ~descendants
     */
    private void addVertexToRequestStats(Priority priority, RequestPriorityStats stats, Integer vertexIndexInt) {
      // Register the vertex at this priority with an initial count of one request.
      stats.vertexTaskCount.put(vertexIndexInt, new MutableInt(1));
      int newVertex = vertexIndexInt;
      stats.vertices.set(newVertex);

      // Fold the new vertex's descendants into this priority's descendant mask.
      BitSet newDescendants = vertexDescendants.get(newVertex);
      stats.descendants.or(newDescendants);

      // Every entry after this priority in the map must stop allowing those descendants.
      Collection<RequestPriorityStats> lowerPriorityStats = priorityStats.tailMap(priority, false).values();
      for (RequestPriorityStats lower : lowerPriorityStats) {
        lower.allowedVertices.andNot(newDescendants);
      }
    }

    /**
     * Removes a vertex from a RequestPriorityStats.
     *
     * Removing a vertex is more expensive than adding a vertex. The stats contain bitmasks which only store on/off
     * values rather than reference counts. Therefore we must rebuild the descendants bitmasks from the remaining
     * vertices in the request stats. Once the new descendants mask is computed we then need to rebuild the
     * allowedVertices BitSet for all lower priority request stats in case the removal of this vertex unblocks lower
     * priority requests of a descendant vertex.
     *
     * Rebuilding allowedVertices for the lower priorities involves starting with the allowedVertices mask at the
     * current priority then masking off the descendants at each priority level encountered, accumulating the results.
     * Any descendants of a level will be blocked at all lower levels. See the addVertexToRequestStats documentation
     * for details on how vertices map to the descendants and allowedVertices bit masks.
     */
    private void removeVertexFromRequestStats(Priority priority, RequestPriorityStats stats, Integer vertexIndexInt) {
      stats.vertexTaskCount.remove(vertexIndexInt);
      int removedVertex = vertexIndexInt;
      stats.vertices.clear(removedVertex);

      // The descendants mask holds no reference counts, so recompute it from the vertices that
      // still have requests at this priority.
      stats.descendants.clear();
      for (Integer remainingVertex : stats.vertexTaskCount.keySet()) {
        stats.descendants.or(vertexDescendants.get(remainingVertex));
      }

      Collection<RequestPriorityStats> lowerPriorityStats = priorityStats.tailMap(priority, false).values();
      if (lowerPriorityStats.isEmpty()) {
        // No later entries whose allowed vertices need to be rebuilt.
        return;
      }

      // Start with what is allowed at this priority minus this priority's descendants, then walk
      // the later entries: each receives the running mask, after which its own descendants are
      // removed from the mask for the entries that follow.
      BitSet runningAllowed = new BitSet(vertexDescendants.size());
      runningAllowed.or(stats.allowedVertices);
      runningAllowed.andNot(stats.descendants);
      for (RequestPriorityStats lower : lowerPriorityStats) {
        lower.allowedVertices.clear();
        lower.allowedVertices.or(runningAllowed);
        runningAllowed.andNot(lower.descendants);
      }
    }

    /**
     * Checks whether the highest priority request has waited longer than the preemption limit.
     *
     * @return {@code false} when {@code highestPriorityWaitTimestamp} is zero; otherwise whether
     *         the time elapsed since that timestamp exceeds {@code preemptionMaxWaitTime}
     */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    boolean isPreemptionDeadlineExpired() {
      if (highestPriorityWaitTimestamp == 0) {
        // A zero timestamp never counts as expired.
        return false;
      }
      return now() - highestPriorityWaitTimestamp > preemptionMaxWaitTime;
    }

    /**
     * Checks whether any request at the first tracked priority matches the given free resources.
     *
     * @param freeResources the resources handed to the client as the capability to match against
     * @return {@code true} when no priorities are tracked, or when the client returns a non-empty
     *         list of matching requests for the first priority at location {@code ANY}
     */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    boolean fitsHighestPriorityRequest(Resource freeResources) {
      if (priorityStats.isEmpty()) {
        return true;
      }
      Priority priority = priorityStats.firstKey();
      // Parameterized instead of a raw Collection, matching the same lookup in getAmountToPreempt.
      List<? extends Collection<TaskRequest>> requestsList = client.getMatchingRequests(
          priority, ResourceRequest.ANY, freeResources);
      return !requestsList.isEmpty();
    }

    /**
     * Computes the resources to preempt on behalf of the first tracked priority.
     *
     * All requests matching that priority at location {@code ANY} are counted, the count is
     * scaled by {@code preemptionPercentage} percent and rounded up, and the capabilities of that
     * many requests (in iteration order) are summed.
     *
     * @param preemptionPercentage percentage of the matching requests to cover
     * @return {@code Resources.none()} when no priorities are tracked, otherwise a new Resource
     *         holding the summed capabilities
     */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    Resource getAmountToPreempt(int preemptionPercentage) {
      if (priorityStats.isEmpty()) {
        return Resources.none();
      }
      Priority topPriority = priorityStats.firstKey();
      List<? extends Collection<TaskRequest>> matching = client.getMatchingRequests(
          topPriority, ResourceRequest.ANY, Resources.unbounded());

      int totalRequests = 0;
      for (Collection<TaskRequest> group : matching) {
        totalRequests += group.size();
      }

      int remaining = (int) Math.ceil(totalRequests * (preemptionPercentage / 100.f));
      Resource toPreempt = Resource.newInstance(0, 0);
      if (remaining == 0) {
        return toPreempt;
      }

      // Accumulate capabilities until the scaled number of requests has been covered.
      for (Collection<TaskRequest> group : matching) {
        for (TaskRequest pending : group) {
          Resources.addTo(toPreempt, pending.getCapability());
          if (--remaining == 0) {
            return toPreempt;
          }
        }
      }
      return toPreempt;
    }

    /**
     * Creates a new BitSet of the vertices that should not be scheduled due to outstanding
     * requests from higher priority predecessor vertices.
     *
     * The set is derived from the last entry of {@code priorityStats}: the complement of its
     * allowedVertices combined with its descendants. The complement is taken over
     * {@code BitSet.size()} bits, so the flipped range is not limited to the vertex count.
     *
     * @return a new BitSet of blocked vertex indexes; empty when no priorities are tracked
     */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    BitSet createVertexBlockedSet() {
      BitSet blocked = new BitSet(vertexDescendants.size());
      Entry<Priority, RequestPriorityStats> last = priorityStats.lastEntry();
      if (last == null) {
        return blocked;
      }
      RequestPriorityStats lastStats = last.getValue();
      // blocked = NOT allowed ...
      blocked.or(lastStats.allowedVertices);
      blocked.flip(0, blocked.size());
      // ... plus the descendants of the vertices requesting at that last priority.
      blocked.or(lastStats.descendants);
      return blocked;
    }
  }

  /**
   * Per-priority bookkeeping for the vertices that currently have task requests at that priority.
   */
  private static class RequestPriorityStats {
    /** Number of task requests per vertex index at this priority. */
    final Map<Integer, MutableInt> vertexTaskCount = new HashMap<>();
    /** Vertices that have outstanding requests at this priority. */
    final BitSet vertices;
    /** Union of the descendants of the vertices that have requests at this priority. */
    final BitSet descendants;
    /**
     * Vertices that may be scheduled at this priority
     * (i.e.: no outstanding predecessors requesting at higher priorities).
     */
    final BitSet allowedVertices;
    int requestCount;
    int localityCount;

    /**
     * @param numTotalVertices initial bit capacity for the vertices and descendants sets
     * @param allowedVertices the allowed-vertex set to use; stored by reference, not copied
     */
    RequestPriorityStats(int numTotalVertices, BitSet allowedVertices) {
      this.allowedVertices = allowedVertices;
      this.descendants = new BitSet(numTotalVertices);
      this.vertices = new BitSet(numTotalVertices);
    }
  }

  /**
   * Index of idle containers keyed by location so that a task request with a desired location
   * can be matched against them quickly.
   */
  private static class IdleContainerTracker {
    /**
     * Idle containers grouped by location ID (a host, a rack, or ANY). A location with no
     * idle containers has no entry.
     */
    final Map<String, Set<HeldContainer>> containersByLocation = new HashMap<>();
    int numContainers = 0;

    /**
     * Registers a container under its host, its rack, and ANY, and increments the container count.
     */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    void add(HeldContainer hc) {
      add(hc, hc.getHost());
      add(hc, hc.getRack());
      add(hc, ResourceRequest.ANY);
      numContainers += 1;
    }

    /**
     * Unregisters a container from its host, its rack, and ANY, and decrements the container count.
     */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    void remove(HeldContainer hc) {
      remove(hc, hc.getHost());
      remove(hc, hc.getRack());
      remove(hc, ResourceRequest.ANY);
      numContainers -= 1;
    }

    /** Returns the current value of the container counter. */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    int getNumContainers() {
      return numContainers;
    }

    /** Adds the container to the set for one location, creating the set on first use. */
    private void add(HeldContainer hc, String location) {
      Set<HeldContainer> bucket = containersByLocation.get(location);
      if (bucket != null) {
        bucket.add(hc);
        return;
      }
      bucket = new HashSet<>();
      containersByLocation.put(location, bucket);
      bucket.add(hc);
    }

    /**
     * Removes the container from the set for one location; the location's entry is dropped when
     * that removal leaves its set empty.
     */
    private void remove(HeldContainer hc, String location) {
      Set<HeldContainer> bucket = containersByLocation.get(location);
      if (bucket == null) {
        return;
      }
      boolean wasPresent = bucket.remove(hc);
      if (wasPresent && bucket.isEmpty()) {
        containersByLocation.remove(location);
      }
    }

    /**
     * Looks up the idle containers registered for a location.
     *
     * @return the set stored for {@code location}, or {@code null} if there is none
     */
    @GuardedBy("DagAwareYarnTaskScheduler.this")
    @Nullable
    Set<HeldContainer> getByLocation(String location) {
      return containersByLocation.get(location);
    }
  }

  /**
   * Resource arithmetic over the dimensions an implementation chooses to track. The
   * implementations in this file track either memory only or memory plus virtual cores.
   */
  private interface ResourceCalculator {
    /**
     * Reports whether the resource has capacity left in a dimension tracked by this calculator.
     *
     * @param rsrc the resource to inspect
     * @return {@code true} if a tracked dimension of {@code rsrc} is greater than zero
     */
    boolean anyAvailable(Resource rsrc);

    /**
     * Subtracts {@code toSubtract} from {@code total} in place for the tracked dimensions.
     *
     * @param total the resource to reduce; it is modified
     * @param toSubtract the amount to take away
     */
    void deductFrom(Resource total, Resource toSubtract);
  }

  /**
   * ResourceCalculator that looks at memory only.
   */
  private static class MemResourceCalculator implements ResourceCalculator {

    /** Returns whether the resource's memory is greater than zero. */
    @Override
    public boolean anyAvailable(Resource rsrc) {
      return 0 < rsrc.getMemory();
    }

    /** Reduces the memory of {@code total} by the memory of {@code toSubtract}. */
    @Override
    public void deductFrom(Resource total, Resource toSubtract) {
      int remainingMemory = total.getMemory() - toSubtract.getMemory();
      total.setMemory(remainingMemory);
    }
  }

  /**
   * ResourceCalculator that looks at both memory and virtual cores.
   */
  private static class MemCpuResourceCalculator extends MemResourceCalculator {

    /** Returns whether the resource has either memory or virtual cores greater than zero. */
    @Override
    public boolean anyAvailable(Resource rsrc) {
      if (super.anyAvailable(rsrc)) {
        return true;
      }
      return 0 < rsrc.getVirtualCores();
    }

    /** Reduces both the memory and the virtual cores of {@code total} by those of {@code toSubtract}. */
    @Override
    public void deductFrom(Resource total, Resource toSubtract) {
      super.deductFrom(total, toSubtract);
      int remainingCores = total.getVirtualCores() - toSubtract.getVirtualCores();
      total.setVirtualCores(remainingCores);
    }
  }

  /**
   * Scheduled thread pool executor that logs any errors that escape the worker thread.
   * This can be replaced with HadoopThreadPoolExecutor once Tez requires Hadoop 2.8 or later.
   */
  static class ReuseContainerExecutor extends ScheduledThreadPoolExecutor {
    /**
     * Creates an executor with a core pool size of one whose threads are named
     * "ReuseContainerExecutor #N". Cancelled tasks are removed from the queue immediately and
     * delayed tasks still pending at shutdown are not run.
     */
    ReuseContainerExecutor() {
      super(1, new ThreadFactoryBuilder().setNameFormat("ReuseContainerExecutor #%d").build());
      setRemoveOnCancelPolicy(true);
      setExecuteExistingDelayedTasksAfterShutdownPolicy(false);
    }

    /**
     * Logs any failure produced by a task that has just run.
     *
     * Tasks handed to a scheduled executor are wrapped in a Future that captures exceptions, so
     * {@code t} is usually {@code null} even when the task failed; the failure is recovered by
     * asking the completed future for its result.
     *
     * @param r the runnable that was executed
     * @param t the exception that terminated execution, or {@code null}
     */
    @Override
    protected void afterExecute(Runnable r, Throwable t) {
      super.afterExecute(r, t);

      // Only query futures that are already done. A periodic task is not done after a successful
      // run, so calling get() on it would block this worker thread indefinitely (JDK-8071638).
      if (t == null && r instanceof Future<?> && ((Future<?>) r).isDone()) {
        try {
          ((Future<?>) r).get();
        } catch (ExecutionException ee) {
          LOG.warn("Execution exception when running task in {}",  Thread.currentThread().getName());
          t = ee.getCause();
        } catch (InterruptedException ie) {
          LOG.warn("Thread ({}) interrupted: ", Thread.currentThread(), ie);
          // Preserve the interrupt status for whoever owns this thread.
          Thread.currentThread().interrupt();
        } catch (Throwable throwable) {
          t = throwable;
        }
      }

      if (t != null) {
        LOG.warn("Caught exception in thread {}", Thread.currentThread().getName(), t);
      }
    }
  }
}