/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.twill.internal.appmaster;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Strings;
import com.google.common.base.Supplier;
import com.google.common.collect.DiscreteDomains;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ranges;
import com.google.common.collect.Sets;
import com.google.common.reflect.TypeToken;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.util.Records;
import org.apache.twill.api.Command;
import org.apache.twill.api.EventHandler;
import org.apache.twill.api.EventHandlerContext;
import org.apache.twill.api.EventHandlerSpecification;
import org.apache.twill.api.LocalFile;
import org.apache.twill.api.ResourceReport;
import org.apache.twill.api.ResourceSpecification;
import org.apache.twill.api.RunId;
import org.apache.twill.api.RuntimeSpecification;
import org.apache.twill.api.TwillRunResources;
import org.apache.twill.api.TwillSpecification;
import org.apache.twill.common.Threads;
import org.apache.twill.filesystem.Location;
import org.apache.twill.internal.Constants;
import org.apache.twill.internal.ContainerInfo;
import org.apache.twill.internal.DefaultTwillRunResources;
import org.apache.twill.internal.EnvKeys;
import org.apache.twill.internal.JvmOptions;
import org.apache.twill.internal.ProcessLauncher;
import org.apache.twill.internal.TwillContainerLauncher;
import org.apache.twill.internal.TwillRuntimeSpecification;
import org.apache.twill.internal.json.LocalFileCodec;
import org.apache.twill.internal.json.TwillRuntimeSpecificationAdapter;
import org.apache.twill.internal.state.Message;
import org.apache.twill.internal.state.SystemMessages;
import org.apache.twill.internal.utils.Instances;
import org.apache.twill.internal.utils.Resources;
import org.apache.twill.internal.yarn.AbstractYarnTwillService;
import org.apache.twill.internal.yarn.YarnAMClient;
import org.apache.twill.internal.yarn.YarnContainerInfo;
import org.apache.twill.internal.yarn.YarnContainerStatus;
import org.apache.twill.internal.yarn.YarnUtils;
import org.apache.twill.zookeeper.ZKClient;
import org.apache.twill.zookeeper.ZKClients;
import org.apache.twill.zookeeper.ZKOperations;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/**
 * The class that acts as {@code ApplicationMaster} for Twill applications.
 */
public final class ApplicationMasterService extends AbstractYarnTwillService implements Supplier<ResourceReport> {
  /**
   * Reason recorded for this service stopping on its own (as opposed to an explicit stop command).
   */
  private enum StopStatus {
    /** Every container has completed. */
    COMPLETED,
    /** Aborted because of a provisioning timeout. */
    ABORTED
  }

  private static final Logger LOG = LoggerFactory.getLogger(ApplicationMasterService.class);
  // Gson instance returned by getLiveNodeGson() and used by loadJvmOptions(); null fields are serialized and
  // LocalFile values go through LocalFileCodec.
  private static final Gson GSON = new GsonBuilder()
    .serializeNulls()
    .registerTypeAdapter(LocalFile.class, new LocalFileCodec())
    .create();

  // Copied from org.apache.hadoop.yarn.security.AMRMTokenIdentifier.KIND_NAME since it's missing in Hadoop-2.0
  private static final Text AMRM_TOKEN_KIND_NAME = new Text("YARN_AM_RM_TOKEN");

  private final RunId runId;
  private final ZKClient zkClient;
  private final TwillSpecification twillSpec;
  // Data published through getLiveNodeData(); its local files are also handed to every launched container.
  private final ApplicationMasterLiveNodeData amLiveNode;
  // Tracks the containers started for runnables; also given this AM's own resources in createRunningContainers().
  private final RunningContainers runningContainers;
  // Expected number of instances per runnable, compared against running counts when requesting containers.
  private final ExpectedContainers expectedContainers;
  private final YarnAMClient amClient;
  // JVM options read by loadJvmOptions() and passed to every TwillContainerLauncher.
  private final JvmOptions jvmOpts;
  // Built by createEventHandler() from the specification's event handler (if any).
  private final EventHandler eventHandler;
  private final Location applicationLocation;
  private final PlacementPolicyManager placementPolicyManager;
  // Extra environment variables for launched containers, keyed by runnable name.
  private final Map<String, Map<String, String>> environments;
  private final TwillRuntimeSpecification twillRuntimeSpec;

  // Why the service stopped; null means it was stopped by a stop command (doStop() then calls killed()).
  private volatile StopStatus stopStatus;
  // Set by triggerShutdown() to end the main loop in doRun().
  private volatile boolean stopped;
  // Ordered queue of pending container requests; created in doStart() and appended to by handleCompleted()
  // when containers have to be re-requested.
  private Queue<RunnableContainerRequest> runnableContainerRequests;
  // Single-thread executor created in doStart() and shut down in doStop().
  private ExecutorService instanceChangeExecutor;

  /**
   * Creates the application master service for the given run.
   *
   * <p>Besides keeping references to its arguments, the constructor loads the credentials and JVM options,
   * builds the placement policy manager from the specification's placement policies, reads the
   * {@link EnvKeys#YARN_APP_ID} and {@link EnvKeys#YARN_APP_ID_CLUSTER_TIME} environment variables for the live
   * node data, and creates the event handler and the running containers tracker.
   *
   * @param runId id of this application run
   * @param zkClient ZooKeeper client, also passed to the superclass
   * @param twillRuntimeSpec runtime specification of the application
   * @param amClient client used to heartbeat to and request containers from the YARN ResourceManager
   * @param config Hadoop configuration, passed to the superclass
   * @param applicationLocation location of the application directory (deleted when the service stops)
   * @throws Exception if any initialization step fails, e.g. reading the JVM options file, loading the event
   *                   handler class or parsing the environment variables
   */
  public ApplicationMasterService(RunId runId, ZKClient zkClient,
                                  TwillRuntimeSpecification twillRuntimeSpec, YarnAMClient amClient,
                                  Configuration config, Location applicationLocation) throws Exception {
    super(zkClient, runId, config, applicationLocation);

    this.runId = runId;
    this.twillRuntimeSpec = twillRuntimeSpec;
    this.zkClient = zkClient;
    this.applicationLocation = applicationLocation;
    this.amClient = amClient;
    this.credentials = createCredentials();
    this.jvmOpts = loadJvmOptions();
    this.twillSpec = twillRuntimeSpec.getTwillSpecification();
    this.placementPolicyManager = new PlacementPolicyManager(twillSpec.getPlacementPolicies());
    this.environments = getEnvironments();

    this.amLiveNode = new ApplicationMasterLiveNodeData(Integer.parseInt(System.getenv(EnvKeys.YARN_APP_ID)),
                                                        Long.parseLong(System.getenv(EnvKeys.YARN_APP_ID_CLUSTER_TIME)),
                                                        amClient.getContainerId().toString(), getLocalizeFiles(),
                                                        twillRuntimeSpec.getKafkaZKConnect());

    this.expectedContainers = new ExpectedContainers(twillSpec);
    this.eventHandler = createEventHandler(twillSpec);
    // Must come after eventHandler is assigned: createRunningContainers() passes that field to RunningContainers.
    this.runningContainers = createRunningContainers(amClient.getContainerId(), amClient.getHost());
  }

  /**
   * Reads the {@link JvmOptions} stored as UTF-8 JSON in the runtime config location.
   *
   * @return the deserialized options, or, when the file does not exist, an instance built from an empty string,
   *         an empty map and {@code JvmOptions.DebugOptions.NO_DEBUG}
   * @throws IOException if the file exists but cannot be read
   */
  private JvmOptions loadJvmOptions() throws IOException {
    File optionsFile = new File(Constants.Files.RUNTIME_CONFIG_JAR, Constants.Files.JVM_OPTIONS);
    if (optionsFile.exists()) {
      try (Reader jsonReader = Files.newBufferedReader(optionsFile.toPath(), StandardCharsets.UTF_8)) {
        return GSON.fromJson(jsonReader, JvmOptions.class);
      }
    }
    return new JvmOptions("", Collections.<String, String>emptyMap(), JvmOptions.DebugOptions.NO_DEBUG);
  }

  /**
   * Creates the {@link EventHandler} declared in the given specification.
   *
   * <p>When the specification declares no handler, a plain {@code EventHandler} with only the base class
   * behavior is returned. Otherwise the handler class is loaded with the class loader of this class and
   * instantiated. The returned handler delegates to that instance and catches and logs any {@link Throwable}
   * raised by every callback except {@link EventHandler#initialize(EventHandlerContext)}. As a result, a faulty
   * user handler cannot break the application master, while a failure in {@code initialize} still propagates.
   *
   * @param twillSpec the application specification
   * @return the event handler to use
   * @throws ClassNotFoundException if the handler class cannot be loaded
   * @throws IllegalArgumentException if the loaded class does not extend/implement {@link EventHandler}
   */
  private EventHandler createEventHandler(TwillSpecification twillSpec) throws ClassNotFoundException {
    EventHandlerSpecification handlerSpec = twillSpec.getEventHandler();
    if (handlerSpec == null) {
      // if no handler is specified, return an EventHandler with no-op
      return new EventHandler() {};
    }

    // Should be able to load by this class ClassLoader, as they are packaged in the same jar.
    Class<?> handlerClass = getClass().getClassLoader().loadClass(handlerSpec.getClassName());
    // Guava's Preconditions only substitutes "%s" placeholders (SLF4J-style "{}" would be left verbatim).
    Preconditions.checkArgument(EventHandler.class.isAssignableFrom(handlerClass),
                                "Class %s does not implement %s",
                                handlerClass.getName(), EventHandler.class.getName());
    // asSubclass is a checked cast, so no unchecked warning needs to be suppressed.
    final EventHandler delegate = Instances.newInstance(handlerClass.asSubclass(EventHandler.class));
    // wrap all calls to the delegate EventHandler methods except initialize so that all errors will be caught
    return new EventHandler() {

      @Override
      public void initialize(EventHandlerContext context) {
        delegate.initialize(context);
      }

      @Override
      public void started() {
        try {
          delegate.started();
        } catch (Throwable t) {
          LOG.warn("Exception raised when calling {}.started()", delegate.getClass().getName(), t);
        }
      }

      @Override
      public void containerLaunched(String runnableName, int instanceId, String containerId) {
        try {
          delegate.containerLaunched(runnableName, instanceId, containerId);
        } catch (Throwable t) {
          LOG.warn("Exception raised when calling {}.containerLaunched(String, int, String)",
                   delegate.getClass().getName(), t);
        }
      }

      @Override
      public void containerStopped(String runnableName, int instanceId, String containerId, int exitStatus) {
        try {
          delegate.containerStopped(runnableName, instanceId, containerId, exitStatus);
        } catch (Throwable t) {
          LOG.warn("Exception raised when calling {}.containerStopped(String, int, String, int)",
                   delegate.getClass().getName(), t);
        }
      }

      @Override
      public void completed() {
        try {
          delegate.completed();
        } catch (Throwable t) {
          LOG.warn("Exception raised when calling {}.completed()", delegate.getClass().getName(), t);
        }
      }

      @Override
      public void killed() {
        try {
          delegate.killed();
        } catch (Throwable t) {
          LOG.warn("Exception raised when calling {}.killed()", delegate.getClass().getName(), t);
        }
      }

      @Override
      public void aborted() {
        try {
          delegate.aborted();
        } catch (Throwable t) {
          LOG.warn("Exception raised when calling {}.aborted()", delegate.getClass().getName(), t);
        }
      }

      @Override
      public void destroy() {
        try {
          delegate.destroy();
        } catch (Throwable t) {
          LOG.warn("Exception raised when calling {}.destroy()", delegate.getClass().getName(), t);
        }
      }

      @Override
      public TimeoutAction launchTimeout(Iterable<TimeoutEvent> timeoutEvents) {
        try {
          return delegate.launchTimeout(timeoutEvents);
        } catch (Throwable t) {
          LOG.warn("Exception raised when calling {}.launchTimeout(Iterable<TimeoutEvent>)",
                   delegate.getClass().getName(), t);
        }
        // call super.launchTimeout in case of any errors from the delegate
        return super.launchTimeout(timeoutEvents);
      }
    };
  }

  /**
   * Creates the tracker for running containers, giving it the resources of this application master's own
   * container.
   *
   * <p>Container memory and virtual cores are read from the {@link EnvKeys#YARN_CONTAINER_MEMORY_MB} and
   * {@link EnvKeys#YARN_CONTAINER_VIRTUAL_CORES} environment variables.
   *
   * @param appMasterContainerId id of the container this application master runs in
   * @param appMasterHost host of that container
   * @return a new {@link RunningContainers} for this application
   * @throws Exception if the tracker cannot be created
   */
  private RunningContainers createRunningContainers(ContainerId appMasterContainerId,
                                                    String appMasterHost) throws Exception {
    int containerMemory = Integer.parseInt(System.getenv(EnvKeys.YARN_CONTAINER_MEMORY_MB));

    // The actual -Xmx of this JVM is not easily available, so derive it exactly like the client did
    int heapMemory = Resources.computeMaxHeapSize(containerMemory, twillRuntimeSpec.getAMReservedMemory(),
                                                  twillRuntimeSpec.getAMMinHeapRatio());
    String amContainerId = appMasterContainerId.toString();
    int virtualCores = Integer.parseInt(System.getenv(EnvKeys.YARN_CONTAINER_VIRTUAL_CORES));
    TwillRunResources amResources = new DefaultTwillRunResources(0, amContainerId, virtualCores, containerMemory,
                                                                 heapMemory, appMasterHost, null);

    String applicationId = appMasterContainerId.getApplicationAttemptId().getApplicationId().toString();
    return new RunningContainers(twillRuntimeSpec, applicationId, amResources, zkClient, applicationLocation,
                                 twillSpec.getRunnables(), eventHandler);
  }

  /**
   * Returns the resource report currently held by the running containers tracker.
   */
  @Override
  public ResourceReport get() {
    ResourceReport report = runningContainers.getResourceReport();
    return report;
  }

  /**
   * Starts the application master.
   *
   * <p>Initializes and starts the event handler, creates the instance change executor, makes sure the
   * runnables ZooKeeper node exists, watches the discovery path and builds the initial container requests.
   *
   * @throws Exception if the event handler fails to initialize or the ZooKeeper node cannot be created
   */
  @Override
  protected void doStart() throws Exception {
    String specJson = TwillRuntimeSpecificationAdapter.create().toJson(twillRuntimeSpec);
    LOG.info("Start application master with spec: {}", specJson);

    // A failure while initializing the event handler propagates and therefore fails the application.
    eventHandler.initialize(new BasicEventHandlerContext(twillRuntimeSpec));
    eventHandler.started();

    instanceChangeExecutor = Executors.newSingleThreadExecutor(Threads.createDaemonThreadFactory("instanceChanger"));

    // The runnables node may already exist, e.g. when the AM got killed and restarted, so NodeExists is ignored.
    String runnablesPath = "/" + runId.getId() + "/runnables";
    ZKOperations.ignoreError(zkClient.create(runnablesPath, null, CreateMode.PERSISTENT),
                             KeeperException.NodeExistsException.class, null).get();

    runningContainers.addWatcher(Constants.DISCOVERY_PATH_PREFIX);
    runnableContainerRequests = initContainerRequests();
  }

  /**
   * Stops the application master.
   *
   * <p>The steps are, in order:
   * <ol>
   *   <li>Shut down the instance change executor.</li>
   *   <li>Stop all running runnables, while a background "stopPoller" thread keeps heartbeating to the RM so
   *       that container completions reach {@link #handleCompleted(List)}.</li>
   *   <li>Delete the application directory.</li>
   *   <li>Notify the event handler: {@code completed()} or {@code aborted()} according to {@code stopStatus}, or
   *       {@code killed()} when no stop status was recorded.</li>
   *   <li>Call {@code destroy()} on the event handler.</li>
   * </ol>
   *
   * @throws Exception if any of the steps above throws (e.g. {@code runningContainers.stopAll()})
   */
  @Override
  protected void doStop() throws Exception {
    Thread.interrupted();     // This is just to clear the interrupt flag

    LOG.info("Stop application master with spec: {}",
             TwillRuntimeSpecificationAdapter.create().toJson(twillRuntimeSpec));

    // The executor is only created in doStart(); guard against a start that failed before creating it.
    if (instanceChangeExecutor != null) {
      instanceChangeExecutor.shutdownNow();
    }

    // For checking if all containers are stopped.
    final Set<String> ids = Sets.newHashSet(runningContainers.getContainerIds());
    final YarnAMClient.AllocateHandler handler = new YarnAMClient.AllocateHandler() {
      @Override
      public void acquired(List<? extends ProcessLauncher<YarnContainerInfo>> launchers) {
        // no-op
      }

      @Override
      public void completed(List<YarnContainerStatus> completed) {
        // Handle the whole batch exactly once; doing it per status would process every completion repeatedly.
        handleCompleted(completed);
        for (YarnContainerStatus status : completed) {
          ids.remove(status.getContainerId());
        }
      }
    };

    // Handle heartbeats during shutdown because runningContainers.stopAll() waits until
    // handleCompleted() is called for every stopped runnable
    ExecutorService stopPoller = Executors.newSingleThreadExecutor(Threads.createDaemonThreadFactory("stopPoller"));
    stopPoller.execute(new Runnable() {
      @Override
      public void run() {
        while (!ids.isEmpty()) {
          try {
            amClient.allocate(0.0f, handler);
            if (!ids.isEmpty()) {
              TimeUnit.SECONDS.sleep(1);
            }
          } catch (InterruptedException e) {
            // Interrupted by stopPoller.shutdownNow(): the poller is no longer needed, so stop polling.
            Thread.currentThread().interrupt();
            break;
          } catch (Exception e) {
            LOG.error("Got exception while getting heartbeat", e);
          }
        }
      }
    });

    try {
      // runningContainers.stopAll() will wait for all the running runnables to stop or kill them after a timeout
      runningContainers.stopAll();
    } finally {
      // Since all the runnables are now stopped (or stopping failed), it is okay to stop the poller.
      stopPoller.shutdownNow();
    }
    cleanupDir();
    if (stopStatus == null) {
      // if finalStatus is not set, the application must be stopped by a SystemMessages#STOP_COMMAND
      eventHandler.killed();
    } else {
      switch (stopStatus) {
        case COMPLETED:
          eventHandler.completed();
          break;
        case ABORTED:
          eventHandler.aborted();
          break;
        default:
          // should never reach here
          LOG.error("Unsupported FinalStatus '{}'", stopStatus.name());
      }
    }
    // call event handler destroy
    eventHandler.destroy();
  }

  /**
   * Returns the data to publish in this application master's live node.
   */
  @Override
  protected Object getLiveNodeData() {
    return this.amLiveNode;
  }

  /**
   * Returns the shared {@link Gson} used to serialize the live node data.
   */
  @Override
  protected Gson getLiveNodeGson() {
    return ApplicationMasterService.GSON;
  }

  /**
   * Handles a message sent to the application master.
   *
   * <p>Messages are routed in this order:
   * <ol>
   *   <li>Secure store updates are sent to all runnables.</li>
   *   <li>Set-instances, restart-instances and log-level messages go to their dedicated handlers.</li>
   *   <li>Any other message is forwarded according to its scope.</li>
   *   <li>A message matching none of the above is ignored.</li>
   * </ol>
   *
   * @param messageId id of the message
   * @param message the message
   * @return a future completed through the callback obtained from {@code getMessageCompletion}, or an already
   *         completed future holding {@code messageId} when the message is ignored
   */
  @Override
  public ListenableFuture<String> onReceived(String messageId, Message message) {
    LOG.debug("Message received: {} {}.", messageId, message);

    SettableFuture<String> result = SettableFuture.create();
    Runnable completion = getMessageCompletion(messageId, result);

    if (handleSecureStoreUpdate(message)) {
      runningContainers.sendToAll(message, completion);
      return result;
    }

    // Short-circuit evaluation keeps the original order: the first handler that accepts the message wins.
    boolean handledLocally = handleSetInstances(message, completion)
      || handleRestartRunnablesInstances(message, completion)
      || handleLogLevelMessages(message, completion);
    if (handledLocally) {
      return result;
    }

    Message.Scope scope = message.getScope();
    if (scope == Message.Scope.ALL_RUNNABLE) {
      // Replicate the message to every runnable
      runningContainers.sendToAll(message, completion);
    } else if (scope == Message.Scope.RUNNABLE) {
      // Replicate the message to the targeted runnable only
      runningContainers.sendToRunnable(message.getRunnableName(), message, completion);
    } else {
      LOG.info("Message ignored. {}", message);
      return Futures.immediateFuture(messageId);
    }
    return result;
  }

  /**
   * Requests shutdown by raising the flag that ends the main loop in {@code doRun()}.
   */
  @Override
  protected void triggerShutdown() {
    this.stopped = true;
  }

  /**
   * Recursively deletes the application directory, logging (never throwing) on failure.
   */
  private void cleanupDir() {
    try {
      boolean deleted = applicationLocation.delete(true);
      if (!deleted) {
        LOG.warn("Failed to cleanup directory {}.", applicationLocation);
        return;
      }
      LOG.info("Application directory deleted: {}", applicationLocation);
    } catch (Exception e) {
      LOG.warn("Exception while cleanup directory {}.", applicationLocation, e);
    }
  }

  /**
   * Runs the main loop of the application master until {@link #triggerShutdown()} sets {@code stopped}, or
   * until nothing is provisioning, no container request is pending and no container is running. In the latter
   * case {@code stopStatus} is set to {@code COMPLETED}.
   *
   * <p>Each iteration sleeps one second and then:
   * <ol>
   *   <li>Heartbeats to the RM through {@code amClient.allocate}. The handler passed to it routes acquired
   *       containers to {@code launchRunnable} and completed ones to {@code handleCompleted}.</li>
   *   <li>Takes the next ready request from {@code runnableContainerRequests}.</li>
   *   <li>Issues YARN container requests when nothing is being provisioned.</li>
   *   <li>Clears the host blacklist for provision requests pending longer than
   *       {@code Constants.CONSTRAINED_PROVISION_REQUEST_TIMEOUT}.</li>
   *   <li>Checks for provisioning timeouts.</li>
   * </ol>
   *
   * @throws Exception if the thread is interrupted while sleeping between heartbeats, or if a step outside the
   *                   heartbeat try/catch throws
   */
  @Override
  protected void doRun() throws Exception {
    // The main loop
    // Request (allocation spec -> runtime specs) taken from the queue but not yet turned into YARN requests.
    Map.Entry<AllocationSpecification, ? extends Collection<RuntimeSpecification>> currentRequest = null;
    // Outstanding YARN container requests. This is a plain LinkedList shared with the allocate handler below;
    // presumably amClient.allocate invokes the handler on this same thread -- TODO confirm.
    final Queue<ProvisionRequest> provisioning = Lists.newLinkedList();

    YarnAMClient.AllocateHandler allocateHandler = new YarnAMClient.AllocateHandler() {
      @Override
      public void acquired(List<? extends ProcessLauncher<YarnContainerInfo>> launchers) {
        launchRunnable(launchers, provisioning);
      }

      @Override
      public void completed(List<YarnContainerStatus> completed) {
        handleCompleted(completed);
      }
    };

    // When the current batch of provision requests was issued (used for the blacklist relaxation timeout).
    long requestStartTime = 0;
    // Whether the blacklist has already been cleared for the current batch.
    boolean isRequestRelaxed = false;
    long nextTimeoutCheck = System.currentTimeMillis() + Constants.PROVISION_TIMEOUT;
    while (!stopped) {
      TimeUnit.SECONDS.sleep(1);

      try {
        // Call allocate. It has to be made at first in order to be able to get cluster resource availability.
        amClient.allocate(0.0f, allocateHandler);
      } catch (Exception e) {
        LOG.warn("Exception raised when making heartbeat to RM. Will be retried in next heartbeat.", e);
      }

      // Looks for containers requests.
      // Nothing provisioning, nothing pending and nothing running: the application is done.
      if (provisioning.isEmpty() && runnableContainerRequests.isEmpty() && runningContainers.isEmpty()) {
        LOG.info("All containers completed. Shutting down application master.");
        stopStatus = StopStatus.COMPLETED;
        break;
      }

      // If nothing is in provisioning, and no pending request, move to next one
      if (provisioning.isEmpty() && currentRequest == null && !runnableContainerRequests.isEmpty()) {
        RunnableContainerRequest containerRequest = runnableContainerRequests.peek();
        // If the request at the head of the request queue is not yet ready, move it to the end of the queue
        // so that it won't block requests that are already ready
        // Note: the continue also skips the relaxation and timeout checks for this iteration.
        if (!containerRequest.isReadyToBeProvisioned()) {
          LOG.debug("Request not ready: {}", containerRequest);
          runnableContainerRequests.add(runnableContainerRequests.poll());
          continue;
        }

        currentRequest = containerRequest.takeRequest();
        if (currentRequest == null) {
          // All different types of resource request from current order is done, move to next one
          // TODO: Need to handle order type as well
          runnableContainerRequests.poll();
          continue;
        }
      }

      // Nothing in provision, makes the next batch of provision request
      if (provisioning.isEmpty() && currentRequest != null) {
        manageBlacklist(currentRequest);
        addContainerRequests(currentRequest.getKey().getResource(), currentRequest.getValue(), provisioning,
                             currentRequest.getKey().getType());
        currentRequest = null;
        requestStartTime = System.currentTimeMillis();
        isRequestRelaxed = false;
      }

      // Check for provision request timeout i.e. check if any provision request has been pending
      // for more than the designated time. On timeout, relax the request constraints.
      if (!provisioning.isEmpty() && !isRequestRelaxed &&
        (System.currentTimeMillis() - requestStartTime) > Constants.CONSTRAINED_PROVISION_REQUEST_TIMEOUT) {
        LOG.info("Relaxing provisioning constraints for request {}", provisioning.peek().getRequestId());
        // Clear the blacklist for the pending provision request(s).
        amClient.clearBlacklist();
        isRequestRelaxed = true;
      }

      nextTimeoutCheck = checkProvisionTimeout(nextTimeoutCheck);
    }
  }

  /**
   * Resets the host blacklist for the given request, then re-populates it when needed.
   *
   * <p>The blacklist is only re-populated for one-instance-at-a-time allocations whose runnable has a
   * {@code DISTRIBUTED} placement policy. In that case every host currently running an instance of a runnable
   * named in that policy is blacklisted, so the new instance lands on a different host.
   *
   * @param request the allocation specification and the runtime specs it covers
   */
  private void manageBlacklist(Map.Entry<AllocationSpecification, ? extends Collection<RuntimeSpecification>> request) {
    amClient.clearBlacklist();

    AllocationSpecification allocationSpec = request.getKey();
    boolean oneInstanceAtATime =
      allocationSpec.getType().equals(AllocationSpecification.Type.ALLOCATE_ONE_INSTANCE_AT_A_TIME);
    if (oneInstanceAtATime) {
      TwillSpecification.PlacementPolicy policy =
        placementPolicyManager.getPlacementPolicy(allocationSpec.getRunnableName());
      if (policy != null && policy.getType() == TwillSpecification.PlacementPolicy.Type.DISTRIBUTED) {
        for (String distributedRunnable : policy.getNames()) {
          for (ContainerInfo info : runningContainers.getContainerInfo(distributedRunnable)) {
            // Depending on YarnConfiguration.RM_SCHEDULER_INCLUDE_PORT_IN_NODE_NAME the RM may use node names
            // with or without the port, so both forms are blacklisted.
            String hostName = info.getHost().getHostName();
            LOG.debug("Adding {} to host blacklist", hostName);
            amClient.addToBlacklist(hostName);
            amClient.addToBlacklist(hostName + ":" + info.getPort());
          }
        }
      }
    }
  }

  /**
   * Processes a batch of completed containers.
   *
   * <p>Each status is logged and handed to the running containers tracker, which collects into a multiset the
   * runnables that need replacement containers. A new container request is then queued for each such runnable,
   * sized by its count, and the expected-count timestamps of those runnables are refreshed so that provisioning
   * timeouts are measured from the re-request.
   *
   * @param completedContainersStatuses statuses of the containers that completed
   */
  private void handleCompleted(List<YarnContainerStatus> completedContainersStatuses) {
    Multiset<String> toRestart = HashMultiset.create();
    for (YarnContainerStatus containerStatus : completedContainersStatuses) {
      LOG.info("Container {} completed with {}:{}.",
               containerStatus.getContainerId(), containerStatus.getState(), containerStatus.getDiagnostics());
      runningContainers.handleCompleted(containerStatus, toRestart);
    }

    Set<String> restartedRunnables = toRestart.elementSet();
    for (String runnableName : restartedRunnables) {
      int instances = toRestart.count(runnableName);
      LOG.info("Re-request container for {} with {} instances.", runnableName, instances);
      runnableContainerRequests.add(createRunnableContainerRequest(runnableName, instances));
    }

    expectedContainers.updateRequestTime(restartedRunnables);
  }

  /**
   * Checks whether any runnable is still short of its expected containers and, if so, asks the event handler
   * what to do.
   *
   * <p>A runnable counts as timed out when its expected count exceeds its running plus completed containers.
   * A negative timeout from the event handler aborts the application.
   *
   * @param nextTimeoutCheck timestamp (ms) at which the check is due; before that nothing is done
   * @return the timestamp for the next time this method needs to be called
   */
  private long checkProvisionTimeout(long nextTimeoutCheck) {
    if (System.currentTimeMillis() < nextTimeoutCheck) {
      return nextTimeoutCheck;
    }

    Map<String, ExpectedContainers.ExpectedCount> expected = expectedContainers.getAll();
    Map<String, Integer> running = runningContainers.countAll();
    Map<String, Integer> completed = runningContainers.getCompletedContainerCount();

    List<EventHandler.TimeoutEvent> timeoutEvents = Lists.newArrayList();
    for (Map.Entry<String, ExpectedContainers.ExpectedCount> expectedEntry : expected.entrySet()) {
      String runnableName = expectedEntry.getKey();
      ExpectedContainers.ExpectedCount expectedCount = expectedEntry.getValue();
      int runningCount = running.containsKey(runnableName) ? running.get(runnableName) : 0;
      int completedCount = completed.containsKey(runnableName) ? completed.get(runnableName) : 0;
      if (runningCount + completedCount < expectedCount.getCount()) {
        timeoutEvents.add(new EventHandler.TimeoutEvent(runnableName, expectedCount.getCount(),
                                                        runningCount, expectedCount.getTimestamp()));
      }
    }

    if (timeoutEvents.isEmpty()) {
      return nextTimeoutCheck + Constants.PROVISION_TIMEOUT;
    }

    EventHandler.TimeoutAction action = eventHandler.launchTimeout(timeoutEvents);
    try {
      long timeout = action.getTimeout();
      if (timeout >= 0) {
        return nextTimeoutCheck + timeout;
      }
      // A negative timeout means: abort the application
      stopStatus = StopStatus.ABORTED;
      stop();
    } catch (Throwable t) {
      LOG.warn("Exception when handling TimeoutAction.", t);
    }
    return nextTimeoutCheck + Constants.PROVISION_TIMEOUT;
  }

  /**
   * Builds the credentials handed to launched containers.
   *
   * <p>When Hadoop security is disabled the result is empty. Otherwise it holds the current user's
   * credentials, minus any AM->RM token. If the current user cannot be obtained, a warning is logged and the
   * (possibly partially filled) credentials are returned.
   *
   * @return the credentials for containers, never {@code null}
   */
  private Credentials createCredentials() {
    Credentials containerCredentials = new Credentials();
    if (UserGroupInformation.isSecurityEnabled()) {
      try {
        containerCredentials.addAll(UserGroupInformation.getCurrentUser().getCredentials());

        // Drop the AM->RM tokens so they are not passed on to the containers
        for (Iterator<Token<?>> tokens = containerCredentials.getAllTokens().iterator(); tokens.hasNext(); ) {
          if (tokens.next().getKind().equals(AMRM_TOKEN_KIND_NAME)) {
            tokens.remove();
          }
        }
      } catch (IOException e) {
        LOG.warn("Failed to get current user. No credentials will be provided to containers.", e);
      }
    }
    return containerCredentials;
  }

  /**
   * Builds the initial queue of container requests, one {@link RunnableContainerRequest} per order of the
   * specification, in order.
   *
   * <p>Runnables of an order that the placement policy manager reports as distributed get one
   * {@code ALLOCATE_ONE_INSTANCE_AT_A_TIME} allocation per instance. All other runnables of the order share
   * allocations keyed by their resource capability.
   *
   * @return a thread-safe queue of the requests
   */
  private Queue<RunnableContainerRequest> initContainerRequests() {
    Queue<RunnableContainerRequest> orderedRequests = new ConcurrentLinkedQueue<>();
    for (TwillSpecification.Order order : twillSpec.getOrders()) {
      Set<String> orderRunnables = order.getNames();
      Set<String> distributed = Sets.intersection(placementPolicyManager.getDistributedRunnables(), orderRunnables);
      Set<String> others = Sets.difference(orderRunnables, distributed);

      Map<AllocationSpecification, Collection<RuntimeSpecification>> specsByAllocation = Maps.newHashMap();

      // One allocation per instance for distributed runnables
      for (String runnable : distributed) {
        RuntimeSpecification runtimeSpec = twillSpec.getRunnables().get(runnable);
        ResourceSpecification resourceSpec = runtimeSpec.getResourceSpecification();
        Resource capability = createCapability(resourceSpec);
        int instances = resourceSpec.getInstances();
        for (int instance = 0; instance < instances; instance++) {
          AllocationSpecification perInstance =
            new AllocationSpecification(capability, AllocationSpecification.Type.ALLOCATE_ONE_INSTANCE_AT_A_TIME,
                                        runnable, instance);
          addAllocationSpecification(perInstance, specsByAllocation, runtimeSpec);
        }
      }

      // A capability-only allocation for everything else
      for (String runnable : others) {
        RuntimeSpecification runtimeSpec = twillSpec.getRunnables().get(runnable);
        AllocationSpecification shared =
          new AllocationSpecification(createCapability(runtimeSpec.getResourceSpecification()));
        addAllocationSpecification(shared, specsByAllocation, runtimeSpec);
      }

      orderedRequests.add(new RunnableContainerRequest(order.getType(), specsByAllocation));
    }
    return orderedRequests;
  }

  /**
   * Appends {@code runtimeSpec} to the collection stored under {@code allocationSpecification} in {@code map},
   * creating an empty list for that key first if it is not present yet.
   *
   * @param allocationSpecification key to add under
   * @param map map being built, modified in place
   * @param runtimeSpec runtime specification to append
   */
  private void addAllocationSpecification(AllocationSpecification allocationSpecification,
                                          Map<AllocationSpecification, Collection<RuntimeSpecification>> map,
                                          RuntimeSpecification runtimeSpec) {
    Collection<RuntimeSpecification> specs;
    if (map.containsKey(allocationSpecification)) {
      specs = map.get(allocationSpecification);
    } else {
      specs = Lists.<RuntimeSpecification>newLinkedList();
      map.put(allocationSpecification, specs);
    }
    specs.add(runtimeSpec);
  }

  /**
   * Issues YARN container requests with the given capability for every runtime spec still missing containers.
   *
   * <p>The number requested is the gap between the expected and the running count, or exactly one for
   * {@code ALLOCATE_ONE_INSTANCE_AT_A_TIME} allocations. The runnable's placement policy, if any, contributes
   * preferred hosts and racks. Each issued request is recorded in {@code provisioning}.
   *
   * @param capability resource capability of each container
   * @param runtimeSpecs runnables to request containers for
   * @param provisioning queue receiving one {@link ProvisionRequest} per issued request
   * @param allocationType allocation strategy of the request
   */
  private void addContainerRequests(Resource capability,
                                    Collection<RuntimeSpecification> runtimeSpecs,
                                    Queue<ProvisionRequest> provisioning,
                                    AllocationSpecification.Type allocationType) {
    for (RuntimeSpecification runtimeSpec : runtimeSpecs) {
      String runnableName = runtimeSpec.getName();
      int missing = expectedContainers.getExpected(runnableName) - runningContainers.count(runnableName);
      if (missing <= 0) {
        continue;
      }

      // Only one instance is spawned at a time for ONE_INSTANCE_AT_A_TIME allocations
      int toRequest =
        allocationType.equals(AllocationSpecification.Type.ALLOCATE_ONE_INSTANCE_AT_A_TIME) ? 1 : missing;

      // TODO: Allow user to set priority?
      LOG.info("Request {} containers with capability {} for runnable {}", toRequest, capability, runnableName);
      YarnAMClient.ContainerRequestBuilder builder = amClient.addContainerRequest(capability, toRequest);
      builder.setPriority(0);

      TwillSpecification.PlacementPolicy policy = placementPolicyManager.getPlacementPolicy(runnableName);
      if (policy != null) {
        builder.addHosts(policy.getHosts()).addRacks(policy.getRacks());
      }

      provisioning.add(new ProvisionRequest(runtimeSpec, builder.apply(), toRequest, allocationType));
    }
  }

  /**
   * Launches runnables in the provisioned containers.
   *
   * <p>Each allocated container is matched against the provision request at the head of {@code provisioning}.
   * A container that arrives while no provision request is pending is skipped by this method.
   *
   * @param launchers launchers for the containers that were allocated
   * @param provisioning queue of outstanding provision requests; the head request is removed once it is satisfied
   */
  private void launchRunnable(List<? extends ProcessLauncher<YarnContainerInfo>> launchers,
                              Queue<ProvisionRequest> provisioning) {
    for (ProcessLauncher<YarnContainerInfo> processLauncher : launchers) {
      LOG.info("Container allocated: {}", processLauncher.getContainerInfo().<Object>getContainer());
      ProvisionRequest provisionRequest = provisioning.peek();
      if (provisionRequest == null) {
        continue;
      }

      String runnableName = provisionRequest.getRuntimeSpec().getName();
      LOG.info("Starting runnable {} in {}", runnableName, processLauncher.getContainerInfo().getContainer());

      int containerCount = expectedContainers.getExpected(runnableName);

      // Setup container environment variables
      Map<String, String> env = new LinkedHashMap<>();
      if (environments.containsKey(runnableName)) {
        env.putAll(environments.get(runnableName));
      }

      ProcessLauncher.PrepareLaunchContext launchContext = processLauncher.prepareLaunch(env,
                                                                                         amLiveNode.getLocalFiles(),
                                                                                         credentials);
      TwillContainerLauncher launcher = new TwillContainerLauncher(
        twillSpec.getRunnables().get(runnableName), processLauncher.getContainerInfo(), launchContext,
        ZKClients.namespace(zkClient, getZKNamespace(runnableName)),
        containerCount, jvmOpts, twillRuntimeSpec.getReservedMemory(runnableName),
        twillRuntimeSpec.getMinHeapRatio(runnableName), getSecureStoreLocation());

      runningContainers.start(runnableName, processLauncher.getContainerInfo(), launcher);

      // Need to call complete to workaround bug in YARN AMRMClient
      if (provisionRequest.containerAcquired()) {
        amClient.completeContainerRequest(provisionRequest.getRequestId());
      }

      /*
       * The provisionRequest will either contain a single container (ALLOCATE_ONE_INSTANCE_AT_A_TIME), or all the
       * containers to satisfy the expectedContainers count. In the latter case, the provision request is complete
       * once all the containers have run, at which point we poll() to remove the provisioning request.
       * The request held locally is the one at the head of the queue, so it is used directly instead of peeking again.
       */
      boolean fullyProvisioned =
        expectedContainers.getExpected(runnableName) == runningContainers.count(runnableName);
      if (fullyProvisioned
        || provisionRequest.getType().equals(AllocationSpecification.Type.ALLOCATE_ONE_INSTANCE_AT_A_TIME)) {
        provisioning.poll();
      }
      if (fullyProvisioned) {
        LOG.info("Runnable {} fully provisioned with {} instances.", runnableName, containerCount);
      }
    }
  }

  /**
   * Reads the list of files to localize from the JSON file named by {@code Constants.Files.LOCALIZE_FILES}.
   *
   * <p>{@link LocalFile} entries are decoded with a {@link LocalFileCodec} registered on the Gson instance.
   *
   * @return the decoded list of local files
   * @throws IOException if the file cannot be opened or read
   */
  private List<LocalFile> getLocalizeFiles() throws IOException {
    Path localizeFilesPath = Paths.get(Constants.Files.LOCALIZE_FILES);
    try (Reader jsonReader = Files.newBufferedReader(localizeFilesPath, StandardCharsets.UTF_8)) {
      Gson localFileGson = new GsonBuilder()
        .registerTypeAdapter(LocalFile.class, new LocalFileCodec())
        .create();
      return localFileGson.fromJson(jsonReader, new TypeToken<List<LocalFile>>() { }.getType());
    }
  }

  /**
   * Loads the per-runnable environment variables stored as JSON under the runtime config directory.
   *
   * @return a mutable map from runnable name to its environment variables; empty when the environments file does
   *         not exist or contains no JSON value
   * @throws IOException if the file exists but cannot be read
   */
  private Map<String, Map<String, String>> getEnvironments() throws IOException {
    Path envFile = Paths.get(Constants.Files.RUNTIME_CONFIG_JAR, Constants.Files.ENVIRONMENTS);
    if (!Files.exists(envFile)) {
      return new HashMap<>();
    }

    Map<String, Map<String, String>> loaded;
    try (Reader reader = Files.newBufferedReader(envFile, StandardCharsets.UTF_8)) {
      loaded = new Gson().fromJson(reader, new TypeToken<Map<String, Map<String, String>>>() {
      }.getType());
    }
    // Gson yields null for an empty document (or a literal "null"). Normalize to an empty map so that callers see
    // the same result as for a missing file instead of hitting a NullPointerException later.
    return loaded == null ? new HashMap<String, Map<String, String>>() : loaded;
  }

  /**
   * Builds the ZooKeeper path used for the given runnable: {@code /<runId>/runnables/<runnableName>}.
   */
  private String getZKNamespace(String runnableName) {
    String runRoot = "/" + runId.getId();
    return runRoot + "/runnables/" + runnableName;
  }

  /**
   * Attempts to change the number of running instances of a runnable.
   *
   * <p>Only a {@link Message.Type#SYSTEM} message scoped to a single runnable, whose command is {@code instances}
   * with a {@code count} option and which names a runnable from the application spec, is accepted. If the requested
   * count equals the current expected count, {@code completion} runs immediately. Otherwise the change is submitted
   * to {@code instanceChangeExecutor}.
   *
   * @param message the incoming message
   * @param completion callback to run once the request has been processed
   * @return {@code true} if the message does requests for changes in number of running instances of a runnable,
   * {@code false} otherwise.
   */
  private boolean handleSetInstances(Message message, Runnable completion) {
    boolean isRunnableSystemMessage = message.getType() == Message.Type.SYSTEM
      && message.getScope() == Message.Scope.RUNNABLE;
    if (!isRunnableSystemMessage) {
      return false;
    }

    Command command = message.getCommand();
    Map<String, String> options = command.getOptions();
    boolean isInstancesCommand = "instances".equals(command.getCommand()) && options.containsKey("count");
    if (!isInstancesCommand) {
      return false;
    }

    final String runnableName = message.getRunnableName();
    boolean isKnownRunnable = !Strings.isNullOrEmpty(runnableName)
      && twillSpec.getRunnables().containsKey(runnableName);
    if (!isKnownRunnable) {
      LOG.info("Unknown runnable {}", runnableName);
      return false;
    }

    final int newCount = Integer.parseInt(options.get("count"));
    final int oldCount = expectedContainers.getExpected(runnableName);
    LOG.info("Received change instances request for {}, from {} to {}.", runnableName, oldCount, newCount);

    if (newCount != oldCount) {
      instanceChangeExecutor.execute(createSetInstanceRunnable(message, completion, oldCount, newCount));
    } else {
      // Already at the requested count; acknowledge the request right away.
      completion.run();
    }
    return true;
  }

  /**
   * Creates a Runnable for execution of change instance request.
   *
   * <p>The returned task waits until {@code oldCount} containers of the runnable are running, then sets the
   * expected count to {@code newCount}. If the count shrinks, it stops the surplus containers one by one. If the
   * count grows, it queues a container request for the additional instances. Afterwards it forwards
   * {@code message} to the runnable with {@code completion}. If an {@link InterruptedException} escapes, it runs
   * {@code completion} directly and restores the thread's interrupt status.
   *
   * @param message the change instances message
   * @param completion callback signalling that the request has been handled
   * @param oldCount expected instance count when the request was received
   * @param newCount requested instance count
   * @return the task to submit to the instance change executor
   */
  private Runnable createSetInstanceRunnable(final Message message, final Runnable completion,
                                             final int oldCount, final int newCount) {
    return new Runnable() {
      @Override
      public void run() {
        final String runnableName = message.getRunnableName();

        LOG.info("Processing change instance request for {}, from {} to {}.", runnableName, oldCount, newCount);
        try {
          // Wait until running container count is the same as old count
          runningContainers.waitForCount(runnableName, oldCount);
          LOG.info("Confirmed {} containers running for {}.", oldCount, runnableName);

          expectedContainers.setExpected(runnableName, newCount);

          try {
            if (newCount < oldCount) {
              // Shutdown some running containers
              for (int i = 0; i < oldCount - newCount; i++) {
                runningContainers.stopLastAndWait(runnableName);
              }
            } else {
              // Increase the number of instances
              runnableContainerRequests.add(createRunnableContainerRequest(runnableName, newCount - oldCount));
            }
          } finally {
            // Send a message to all running runnables that number of instances have changed
            runningContainers.sendToRunnable(runnableName, message, completion);
            LOG.info("Change instances request completed. From {} to {}.", oldCount, newCount);
          }
        } catch (InterruptedException e) {
          // If the wait is being interrupted, discard the message.
          // NOTE(review): if stopLastAndWait can also throw InterruptedException, the finally block above has already
          // forwarded the message, and completion runs a second time here. Confirm whether that can happen.
          LOG.info("Change instances request for {} interrupted; discarding message.", runnableName);
          completion.run();
          // Restore the interrupt status only after completion has run, so that code run by completion is not
          // affected by the flag. The owning thread/executor can still observe the interruption.
          Thread.currentThread().interrupt();
        }
      }
    };
  }

  /**
   * Creates a container request for {@code numberOfInstances} instances of the given runnable.
   * Equivalent to the three-argument overload with {@code isProvisioned} set to {@code true}.
   */
  private RunnableContainerRequest createRunnableContainerRequest(final String runnableName,
                                                                  final int numberOfInstances) {
    final boolean provisionedRequest = true;
    return createRunnableContainerRequest(runnableName, numberOfInstances, provisionedRequest);
  }

  /**
   * Builds a {@link RunnableContainerRequest} for the given runnable.
   *
   * <p>For a runnable listed by {@code placementPolicyManager.getDistributedRunnables()}, one
   * ALLOCATE_ONE_INSTANCE_AT_A_TIME allocation specification is added per instance id in
   * {@code [0, numberOfInstances)}. For any other runnable, a single allocation specification is added:
   * <ul>
   *   <li>the capability-only one when more than one instance is requested;</li>
   *   <li>otherwise an ALLOCATE_ONE_INSTANCE_AT_A_TIME one for instance 0.</li>
   * </ul>
   *
   * @param runnableName name of the runnable; it must belong to one of the application's orders
   * @param numberOfInstances number of instances to request
   * @param isProvisioned passed through to the created request
   * @return the request, typed after the order that contains the runnable
   * @throws java.util.NoSuchElementException if no order contains the runnable (from {@code Iterables.find})
   */
  private RunnableContainerRequest createRunnableContainerRequest(final String runnableName,
                                                                  final int numberOfInstances,
                                                                  final boolean isProvisioned) {
    // The order containing this runnable determines the type of the resulting request.
    TwillSpecification.Order order = Iterables.find(twillSpec.getOrders(), new Predicate<TwillSpecification.Order>() {
      @Override
      public boolean apply(TwillSpecification.Order candidate) {
        return candidate.getNames().contains(runnableName);
      }
    });

    RuntimeSpecification runtimeSpec = twillSpec.getRunnables().get(runnableName);
    Resource capability = createCapability(runtimeSpec.getResourceSpecification());
    Map<AllocationSpecification, Collection<RuntimeSpecification>> requestsMap = Maps.newHashMap();

    if (!placementPolicyManager.getDistributedRunnables().contains(runnableName)) {
      // A single instance is always requested one instance at a time; for multi-instance runnables this case
      // occurs during retries.
      AllocationSpecification singleSpec = numberOfInstances > 1
        ? new AllocationSpecification(capability)
        : new AllocationSpecification(capability, AllocationSpecification.Type.ALLOCATE_ONE_INSTANCE_AT_A_TIME,
                                      runnableName, 0);
      addAllocationSpecification(singleSpec, requestsMap, runtimeSpec);
    } else {
      // Distributed runnables get a dedicated allocation specification per instance.
      int instanceId = 0;
      while (instanceId < numberOfInstances) {
        AllocationSpecification perInstanceSpec =
          new AllocationSpecification(capability, AllocationSpecification.Type.ALLOCATE_ONE_INSTANCE_AT_A_TIME,
                                      runnableName, instanceId);
        addAllocationSpecification(perInstanceSpec, requestsMap, runtimeSpec);
        instanceId++;
      }
    }
    return new RunnableContainerRequest(order.getType(), requestsMap, isProvisioned);
  }

  /**
   * Returns a callback that, when run, sets {@code future} to {@code messageId}.
   */
  private Runnable getMessageCompletion(final String messageId, final SettableFuture<String> future) {
    Runnable completeWithMessageId = new Runnable() {
      @Override
      public void run() {
        future.set(messageId);
      }
    };
    return completeWithMessageId;
  }

  /**
   * Creates a YARN {@link Resource} with the memory size from {@code resourceSpec}.
   *
   * <p>The virtual core count is applied only when {@code YarnUtils.setVirtualCores} reports success; otherwise a
   * debug message is logged and the vcore setting is left untouched.
   */
  private Resource createCapability(ResourceSpecification resourceSpec) {
    Resource resource = Records.newRecord(Resource.class);

    boolean vcoresApplied = YarnUtils.setVirtualCores(resource, resourceSpec.getVirtualCores());
    if (!vcoresApplied) {
      LOG.debug("Virtual cores limit not supported.");
    }

    int memorySize = resourceSpec.getMemorySize();
    resource.setMemory(memorySize);
    return resource;
  }

  /**
   * Attempt to restart some instances from a runnable or some runnables.
   *
   * <p>The message must satisfy all of the following:
   * <ul>
   *   <li>its type is {@link Message.Type#SYSTEM};</li>
   *   <li>its scope is {@link Message.Scope#RUNNABLE} or {@link Message.Scope#RUNNABLES};</li>
   *   <li>its command is {@code Constants.RESTART_ALL_RUNNABLE_INSTANCES} or
   *   {@code Constants.RESTART_RUNNABLES_INSTANCES}.</li>
   * </ul>
   * For a RUNNABLE-scoped message that names a runnable, all instances of that runnable are restarted. Otherwise
   * each command option maps a runnable name to a JSON array of instance ids to restart. Unknown runnables are
   * skipped. {@code completion} runs once, after every scheduled restart has finished, or immediately when there
   * is nothing to restart.
   *
   * @return {@code true} if the message requests restarting some instances and {@code false} otherwise.
   */
  private boolean handleRestartRunnablesInstances(Message message, Runnable completion) {
    LOG.debug("Check if it should process a restart runnable instances.");

    if (message.getType() != Message.Type.SYSTEM) {
      return false;
    }

    Message.Scope messageScope = message.getScope();
    if (messageScope != Message.Scope.RUNNABLE && messageScope != Message.Scope.RUNNABLES) {
      return false;
    }

    Command requestCommand = message.getCommand();
    if (!Constants.RESTART_ALL_RUNNABLE_INSTANCES.equals(requestCommand.getCommand()) &&
      !Constants.RESTART_RUNNABLES_INSTANCES.equals(requestCommand.getCommand())) {
      return false;
    }

    LOG.debug("Processing restart runnable instances message {}.", message);

    if (!Strings.isNullOrEmpty(message.getRunnableName()) && messageScope == Message.Scope.RUNNABLE) {
      // ... for a runnable ...
      String runnableName = message.getRunnableName();
      LOG.debug("Start restarting all runnable {} instances.", runnableName);
      restartRunnableInstances(runnableName, null, completion);
      return true;
    }

    // ... or maybe some runnables. Parse and validate every option before scheduling anything, so that a malformed
    // value does not leave the request half scheduled.
    Map<String, Set<Integer>> restarts = new LinkedHashMap<>();
    for (Map.Entry<String, String> option : requestCommand.getOptions().entrySet()) {
      String runnableName = option.getKey();
      if (!twillSpec.getRunnables().containsKey(runnableName)) {
        // Building a container request for an unknown runnable fails inside the executor before completion runs.
        LOG.info("Unknown runnable {}", runnableName);
        continue;
      }
      Set<Integer> restartedInstanceIds = GSON.fromJson(option.getValue(),
                                                        new TypeToken<Set<Integer>>() {
                                                        }.getType());
      restarts.put(runnableName, restartedInstanceIds);
    }

    if (restarts.isEmpty()) {
      // Nothing to restart: complete the request instead of leaving it pending forever.
      completion.run();
      return true;
    }

    // Every restart task runs its completion when done. The caller must only be signalled after the last one.
    Runnable allRestartsDone = createAggregatedCompletion(restarts.size(), completion);
    for (Map.Entry<String, Set<Integer>> restart : restarts.entrySet()) {
      LOG.debug("Start restarting runnable {} instances {}", restart.getKey(), restart.getValue());
      restartRunnableInstances(restart.getKey(), restart.getValue(), allRestartsDone);
    }
    return true;
  }

  /**
   * Returns a Runnable that runs {@code completion} exactly once, on its {@code count}-th invocation.
   * Safe to invoke from multiple threads.
   */
  private static Runnable createAggregatedCompletion(int count, final Runnable completion) {
    final java.util.concurrent.atomic.AtomicInteger remaining = new java.util.concurrent.atomic.AtomicInteger(count);
    return new Runnable() {
      @Override
      public void run() {
        if (remaining.decrementAndGet() == 0) {
          completion.run();
        }
      }
    };
  }

  /**
   * Helper method to restart instances of runnables.
   *
   * <p>The work runs as a task on {@code instanceChangeExecutor}. The task proceeds in four steps:
   * <ol>
   *   <li>Queue a container request for the replacement instances, created with {@code isProvisioned = false}.</li>
   *   <li>Stop each targeted instance.</li>
   *   <li>Mark the request ready to be provisioned.</li>
   *   <li>Update the expected-count request time, then run {@code completion}.</li>
   * </ol>
   *
   * @param runnableName name of the runnable whose instances are restarted
   * @param instanceIds ids of the instances to restart, or {@code null} to restart instance ids
   *                    {@code 0 .. n-1}, where {@code n} is the running container count when the task runs
   * @param completion run after the replacement request has been made ready
   */
  private void restartRunnableInstances(final String runnableName, @Nullable final Set<Integer> instanceIds,
                                        final Runnable completion) {
    instanceChangeExecutor.execute(new Runnable() {
      @Override
      public void run() {
        LOG.debug("Begin restart runnable {} instances.", runnableName);
        int runningCount = runningContainers.count(runnableName);
        Set<Integer> instancesToRemove = instanceIds == null ? null : ImmutableSet.copyOf(instanceIds);
        if (instancesToRemove == null) {
          instancesToRemove = Ranges.closedOpen(0, runningCount).asSet(DiscreteDomains.integers());
        }

        LOG.info("Restarting instances {} for runnable {}", instancesToRemove, runnableName);
        RunnableContainerRequest containerRequest =
          createRunnableContainerRequest(runnableName, instancesToRemove.size(), false);
        runnableContainerRequests.add(containerRequest);

        for (int instanceId : instancesToRemove) {
          LOG.debug("Stop instance {} for runnable {}", instanceId, runnableName);
          try {
            runningContainers.stopByIdAndWait(runnableName, instanceId);
          } catch (Exception ex) {
            // could be thrown if the container already stopped.
            LOG.info("Exception thrown when stopping instance {} probably already stopped.", instanceId);
            // Keep the cause available for diagnosis without putting stack traces into the info log.
            LOG.debug("Failed to stop instance {} of runnable {}", instanceId, runnableName, ex);
          }
        }

        LOG.info("All instances in {} for runnable {} are stopped. Ready to provision",
                 instancesToRemove, runnableName);

        // set the container request to be ready
        containerRequest.setReadyToBeProvisioned();

        // For all runnables that needs to re-request for containers, update the expected count timestamp
        // so that the EventHandler would be triggered with the right expiration timestamp.
        expectedContainers.updateRequestTime(Collections.singleton(runnableName));

        completion.run();
      }
    });
  }

  /**
   * Attempt to change the log level from a runnable or all runnables.
   *
   * <p>Accepts {@link Message.Type#SYSTEM} messages whose command is {@code SystemMessages.SET_LOG_LEVEL} or
   * {@code SystemMessages.RESET_LOG_LEVEL}.
   * <ul>
   *   <li>An {@link Message.Scope#ALL_RUNNABLE} message is forwarded with {@code runningContainers.sendToAll}.</li>
   *   <li>A {@link Message.Scope#RUNNABLE} message is forwarded with {@code runningContainers.sendToRunnable},
   *   provided it names a runnable from the application spec.</li>
   * </ul>
   *
   * @return {@code true} if the message requests changing log levels and {@code false} otherwise.
   */
  private boolean handleLogLevelMessages(Message message, Runnable completion) {
    Message.Scope scope = message.getScope();
    if (message.getType() != Message.Type.SYSTEM ||
      (scope != Message.Scope.RUNNABLE && scope != Message.Scope.ALL_RUNNABLE)) {
      return false;
    }

    // Constant-first comparison: a command without a name is simply not a log level message (no NPE).
    String command = message.getCommand().getCommand();
    if (!SystemMessages.SET_LOG_LEVEL.equals(command) && !SystemMessages.RESET_LOG_LEVEL.equals(command)) {
      return false;
    }

    if (scope == Message.Scope.ALL_RUNNABLE) {
      runningContainers.sendToAll(message, completion);
      return true;
    }

    final String runnableName = message.getRunnableName();
    if (runnableName == null || !twillSpec.getRunnables().containsKey(runnableName)) {
      LOG.info("Unknown runnable {}", runnableName);
      return false;
    }
    runningContainers.sendToRunnable(runnableName, message, completion);
    return true;
  }
}