/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.dag.app;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.Objects;

import com.google.common.collect.Maps;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.ipc.ProtocolSignature;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.authorize.PolicyProvider;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.tez.common.ContainerContext;
import org.apache.tez.common.ContainerTask;
import org.apache.tez.common.TezConverterUtils;
import org.apache.tez.common.TezLocalResource;
import org.apache.tez.common.TezTaskUmbilicalProtocol;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.security.JobTokenIdentifier;
import org.apache.tez.common.security.JobTokenSecretManager;
import org.apache.tez.common.security.TokenCache;
import org.apache.tez.serviceplugins.api.ContainerEndReason;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.apache.tez.serviceplugins.api.TaskCommunicator;
import org.apache.tez.serviceplugins.api.TaskCommunicatorContext;
import org.apache.tez.serviceplugins.api.TaskHeartbeatRequest;
import org.apache.tez.serviceplugins.api.TaskHeartbeatResponse;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.runtime.api.impl.TezHeartbeatRequest;
import org.apache.tez.runtime.api.impl.TezHeartbeatResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InterfaceAudience.Private
public class TezTaskCommunicatorImpl extends TaskCommunicator {

  private static final Logger LOG = LoggerFactory.getLogger(TezTaskCommunicatorImpl.class);

  // Shared task object returned from getTask() when the request is malformed or the
  // requesting container is not registered.
  // NOTE(review): the second constructor argument (true) is presumably the "should die" flag,
  // since getTask() inspects task.shouldDie() — confirm against ContainerTask.
  private static final ContainerTask TASK_FOR_INVALID_JVM = new ContainerTask(
      null, true, null, null, false);

  // Umbilical implementation; handed to the RPC server as its instance in startRpcServer().
  private final TezTaskUmbilicalProtocol taskUmbilical;

  // Containers currently known to this communicator, keyed by container id. Entries are
  // added by registerRunningContainer() and removed by registerContainerEnd().
  protected final ConcurrentMap<ContainerId, ContainerInfo> registeredContainers =
      new ConcurrentHashMap<>();
  // Task attempt -> container it was registered on. Populated by
  // registerRunningTaskAttempt(); cleared by unregisterRunningTaskAttempt() and
  // registerContainerEnd().
  protected final ConcurrentMap<TezTaskAttemptID, ContainerId> attemptToContainerMap =
      new ConcurrentHashMap<>();


  // String form of the application id; used as the job key when the session token is added
  // to the JobTokenSecretManager in startRpcServer().
  protected final String tokenIdentifier;
  // Session token looked up from the AM credentials in the constructor.
  protected final Token<JobTokenIdentifier> sessionToken;
  // Configuration parsed from the initial user payload in the constructor.
  protected final Configuration conf;
  // Address of the umbilical RPC server; assigned in startRpcServer(), null before that.
  protected InetSocketAddress address;

  // RPC server instance; created in startRpcServer() and set back to null in stopRpcServer().
  protected volatile Server server;

  /**
   * Mutable bookkeeping for a single registered container: where it runs, which task
   * (if any) is currently assigned to it, and the last heartbeat exchanged with it.
   *
   * <p>Code in the enclosing class reads and writes the non-final fields while holding
   * the monitor of the {@code ContainerInfo} instance.
   */
  public static final class ContainerInfo {

    // Identity and location; fixed for the lifetime of the registration.
    final ContainerId containerId;
    public final String host;
    public final int port;

    // Heartbeat tracking. Deliberately NOT cleared by reset().
    TezHeartbeatResponse lastResponse;
    long lastRequestId;

    // State of the currently assigned task; all of it is cleared by reset().
    TaskSpec taskSpec;
    Map<String, LocalResource> additionalLRs;
    Credentials credentials;
    boolean credentialsChanged;
    boolean taskPulled;

    ContainerInfo(ContainerId containerId, String host, int port) {
      this.containerId = containerId;
      this.host = host;
      this.port = port;
    }

    /**
     * Clears the per-task state so that a new task can be assigned to this container.
     * The heartbeat fields ({@code lastResponse}, {@code lastRequestId}) are left untouched.
     */
    void reset() {
      taskPulled = false;
      credentialsChanged = false;
      credentials = null;
      additionalLRs = null;
      taskSpec = null;
    }
  }



  /**
   * Construct the service.
   */
  public TezTaskCommunicatorImpl(TaskCommunicatorContext taskCommunicatorContext) {
    super(taskCommunicatorContext);
    this.taskUmbilical = new TezTaskUmbilicalProtocolImpl();
    this.tokenIdentifier = taskCommunicatorContext.getApplicationAttemptId().getApplicationId().toString();
    this.sessionToken = TokenCache.getSessionToken(taskCommunicatorContext.getAMCredentials());
    try {
      conf = TezUtils.createConfFromUserPayload(getContext().getInitialUserPayload());
    } catch (IOException e) {
      throw new TezUncheckedException(
          "Unable to parse user payload for " + TezTaskCommunicatorImpl.class.getSimpleName(), e);
    }
  }

  /**
   * Starts the communicator by bringing up the umbilical RPC server.
   */
  @Override
  public void start() {
    this.startRpcServer();
  }

  /**
   * Shuts the communicator down by stopping the umbilical RPC server, if one is running.
   */
  @Override
  public void shutdown() {
    this.stopRpcServer();
  }

  /**
   * Builds and starts the umbilical RPC server, then records the address it can be
   * reached on in {@link #address}.
   *
   * <p>The server binds to {@code 0.0.0.0} with port {@code 0}, using the port range
   * configured under {@link TezConfiguration#TEZ_AM_TASK_AM_PORT_RANGE}. The session token
   * is registered with a fresh {@link JobTokenSecretManager} under the token identifier.
   * Service ACLs are applied only when
   * {@link CommonConfigurationKeysPublic#HADOOP_SECURITY_AUTHORIZATION} is set to true.
   *
   * @throws TezUncheckedException wrapping any {@link IOException} raised during setup
   */
  protected void startRpcServer() {
    try {
      JobTokenSecretManager secretManager = new JobTokenSecretManager();
      secretManager.addTokenForJob(tokenIdentifier, sessionToken);

      int handlerCount = conf.getInt(
          TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT,
          TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT_DEFAULT);

      RPC.Builder rpcBuilder = new RPC.Builder(conf)
          .setProtocol(TezTaskUmbilicalProtocol.class)
          .setBindAddress("0.0.0.0")
          .setPort(0)
          .setInstance(taskUmbilical)
          .setNumHandlers(handlerCount)
          .setPortRangeConfig(TezConfiguration.TEZ_AM_TASK_AM_PORT_RANGE)
          .setSecretManager(secretManager);
      server = rpcBuilder.build();

      // Service authorization is opt-in; it defaults to off when the key is absent.
      boolean authorizationEnabled = conf.getBoolean(
          CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION, false);
      if (authorizationEnabled) {
        refreshServiceAcls(conf, new TezAMPolicyProvider());
      }

      server.start();

      // Re-create the socket address from the canonical host name of the connect address.
      InetSocketAddress connectAddress = NetUtils.getConnectAddress(server);
      String canonicalHost = connectAddress.getAddress().getCanonicalHostName();
      int boundPort = connectAddress.getPort();
      this.address = NetUtils.createSocketAddrForHost(canonicalHost, boundPort);
      LOG.info("Instantiated TezTaskCommunicator RPC at " + this.address);
    } catch (IOException e) {
      throw new TezUncheckedException(e);
    }
  }

  /**
   * Stops the umbilical RPC server if one is running and clears the {@link #server} field.
   * Calling this when no server is running is a no-op.
   */
  protected void stopRpcServer() {
    // 'server' is volatile and may be nulled by another caller between a null check and
    // the stop() call. Read it once so the check and the call act on the same reference.
    Server running = server;
    if (running != null) {
      running.stop();
      server = null;
    }
  }

  /**
   * @return the configuration parsed from the initial user payload at construction time
   */
  protected Configuration getConf() {
    return conf;
  }

  /**
   * Applies the given policy provider's service ACLs to the RPC server.
   * Must only be called once {@link #server} has been assigned.
   *
   * @param configuration  configuration handed to the server's ACL refresh
   * @param policyProvider provider handed to the server's ACL refresh
   */
  private void refreshServiceAcls(Configuration configuration,
                                  PolicyProvider policyProvider) {
    Server rpcServer = this.server;
    rpcServer.refreshServiceAcl(configuration, policyProvider);
  }

  /**
   * Records a newly running container so that it may later be assigned tasks and send
   * heartbeats.
   *
   * @param containerId id of the container
   * @param host        host recorded for the container
   * @param port        port recorded for the container
   * @throws TezUncheckedException if the container id has already been registered
   */
  @Override
  public void registerRunningContainer(ContainerId containerId, String host, int port) {
    ContainerInfo candidate = new ContainerInfo(containerId, host, port);
    boolean alreadyRegistered = registeredContainers.putIfAbsent(containerId, candidate) != null;
    if (alreadyRegistered) {
      throw new TezUncheckedException("Multiple registrations for containerId: " + containerId);
    }
  }

  /**
   * Forgets a container. If a task attempt is still assigned to it, the attempt's entry
   * in {@code attemptToContainerMap} is dropped as well. Unknown container ids are ignored.
   *
   * @param containerId id of the container that ended
   * @param endReason   not used by this implementation
   * @param diagnostics not used by this implementation
   */
  @Override
  public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason, String diagnostics) {
    ContainerInfo ended = registeredContainers.remove(containerId);
    if (ended == null) {
      return;
    }
    synchronized (ended) {
      TaskSpec assigned = ended.taskSpec;
      if (assigned == null) {
        return;
      }
      TezTaskAttemptID attemptId = assigned.getTaskAttemptID();
      if (attemptId != null) {
        attemptToContainerMap.remove(attemptId);
      }
    }
  }

  /**
   * Assigns a task attempt to a registered container so that the container's next
   * {@code getTask} call can pick it up.
   *
   * <p>Both pre-conditions (the container has no task assigned, and the attempt is not
   * already mapped to a container) are checked before any state is modified, so a
   * rejected registration leaves the container's bookkeeping untouched.
   *
   * @param containerId         container the attempt should run in; must be registered
   * @param taskSpec            specification of the task attempt
   * @param additionalResources extra local resources to ship with the task; may be null
   * @param credentials         credentials to ship with the task
   * @param credentialsChanged  flag passed through to the task handed to the container
   * @param priority            not used by this implementation
   * @throws NullPointerException  if the container is not registered
   * @throws TezUncheckedException if the container already has a task assigned, or the
   *                               attempt is already registered to a container
   */
  @Override
  public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec,
                                         Map<String, LocalResource> additionalResources,
                                         Credentials credentials, boolean credentialsChanged,
                                         int priority) {

    ContainerInfo containerInfo = registeredContainers.get(containerId);
    Objects.requireNonNull(containerInfo,
        String.format("Cannot register task attempt %s to unknown container %s",
            taskSpec.getTaskAttemptID(), containerId));
    synchronized (containerInfo) {
      if (containerInfo.taskSpec != null) {
        throw new TezUncheckedException(
            "Cannot register task: " + taskSpec.getTaskAttemptID() + " to container: " +
                containerId + " , with pre-existing assignment: " +
                containerInfo.taskSpec.getTaskAttemptID());
      }

      // Claim the attempt -> container mapping BEFORE touching containerInfo. If the attempt
      // is already registered elsewhere we must fail without leaving this container pointing
      // at an attempt it does not own (registerContainerEnd would otherwise remove the other
      // container's mapping later).
      ContainerId oldId = attemptToContainerMap.putIfAbsent(taskSpec.getTaskAttemptID(), containerId);
      if (oldId != null) {
        throw new TezUncheckedException(
            "Attempting to register an already registered taskAttempt with id: " +
                taskSpec.getTaskAttemptID() + " to containerId: " + containerId +
                ". Already registered to containerId: " + oldId);
      }

      containerInfo.taskSpec = taskSpec;
      containerInfo.additionalLRs = additionalResources;
      containerInfo.credentials = credentials;
      containerInfo.credentialsChanged = credentialsChanged;
      containerInfo.taskPulled = false;
    }
  }


  /**
   * Removes a task attempt from the bookkeeping and clears the per-task state of the
   * container it was assigned to. Attempts that are not registered, or whose container
   * is no longer registered, are logged and otherwise ignored.
   *
   * @param taskAttemptID attempt to unregister
   * @param endReason     not used by this implementation
   * @param diagnostics   not used by this implementation
   */
  @Override
  public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID,
                                           TaskAttemptEndReason endReason, String diagnostics) {
    // This remove is the single point at which the attempt -> container mapping is dropped.
    ContainerId containerId = attemptToContainerMap.remove(taskAttemptID);
    if (containerId == null) {
      LOG.warn("Unregister task attempt: " + taskAttemptID + " from unknown container");
      return;
    }
    ContainerInfo containerInfo = registeredContainers.get(containerId);
    if (containerInfo == null) {
      LOG.warn("Unregister task attempt: " + taskAttemptID +
          " from non-registered container: " + containerId);
      return;
    }
    synchronized (containerInfo) {
      // The mapping was already removed above; removing it again here would at best be a
      // no-op and at worst drop a mapping re-created in the meantime.
      containerInfo.reset();
    }
  }

  /**
   * @return the address recorded by {@link #startRpcServer()}, or null if the server has
   *         not been started yet
   */
  @Override
  public InetSocketAddress getAddress() {
    return this.address;
  }

  /**
   * Intentionally a no-op: this communicator does not register for vertex state updates.
   *
   * @param stateUpdate ignored
   */
  @Override
  public void onVertexStateUpdated(VertexStateUpdate stateUpdate) {
    // Empty. Not registering, or expecting any updates.
  }

  /**
   * Intentionally a no-op: no per-DAG state is cleaned up here at present.
   *
   * @param dagIdentifier ignored
   */
  @Override
  public void dagComplete(int dagIdentifier) {
    // Nothing to do at the moment. Some of the TODOs from TaskAttemptListener apply here.
  }

  /**
   * @return the same address object that {@link #getAddress()} returns, typed as
   *         {@link Object}
   */
  @Override
  public Object getMetaInfo() {
    return this.address;
  }

  /**
   * @return the application id string captured at construction time
   */
  protected String getTokenIdentifier() {
    return this.tokenIdentifier;
  }

  /**
   * @return the session token obtained from the AM credentials at construction time
   */
  protected Token<JobTokenIdentifier> getSessionToken() {
    return this.sessionToken;
  }

  /**
   * @return the umbilical protocol implementation that backs the RPC server
   */
  public TezTaskUmbilicalProtocol getUmbilical() {
    return taskUmbilical;
  }

  private class TezTaskUmbilicalProtocolImpl implements TezTaskUmbilicalProtocol {

    /**
     * Serves a container's request for work.
     *
     * <p>A request without a container context or container identifier is answered with
     * the shared invalid-JVM task. Otherwise the container's pending task (if any) is
     * looked up; when a task is returned and it is not marked {@code shouldDie}, the
     * context is told that the attempt was submitted and has started remotely.
     *
     * @param containerContext identifies the requesting container; may be null
     * @return the task for the container, the invalid-JVM task, or null when there is
     *         nothing to hand out
     * @throws IOException if the task cannot be constructed
     */
    @Override
    public ContainerTask getTask(ContainerContext containerContext) throws IOException {
      final ContainerTask assigned;
      boolean malformedRequest =
          containerContext == null || containerContext.getContainerIdentifier() == null;
      if (malformedRequest) {
        LOG.info("Invalid task request with an empty containerContext or containerId");
        assigned = TASK_FOR_INVALID_JVM;
      } else {
        ContainerId requester =
            ConverterUtils.toContainerId(containerContext.getContainerIdentifier());
        if (LOG.isDebugEnabled()) {
          LOG.debug("Container with id: " + requester + " asked for a task");
        }
        assigned = getContainerTask(requester);
        if (assigned != null && !assigned.shouldDie()) {
          TezTaskAttemptID attemptId = assigned.getTaskSpec().getTaskAttemptID();
          getContext().taskSubmitted(attemptId, requester);
          getContext().taskStartedRemotely(attemptId);
        }
      }
      if (LOG.isDebugEnabled()) {
        LOG.debug("getTask returning task: " + assigned);
      }
      return assigned;
    }

    /**
     * Asks the context whether the given attempt may commit.
     *
     * @param taskAttemptId attempt asking for permission to commit
     * @return whatever the context's {@code canCommit} answers for the attempt
     * @throws IOException declared by the protocol
     */
    @Override
    public boolean canCommit(TezTaskAttemptID taskAttemptId) throws IOException {
      boolean allowed = getContext().canCommit(taskAttemptId);
      return allowed;
    }

    /**
     * Handles a heartbeat sent by a container.
     *
     * <ul>
     *   <li>Unknown container: replies with a response flagged "should die".</li>
     *   <li>Request id equal to the last one seen: re-sends the stored last response.</li>
     *   <li>Request carrying a task attempt: verifies that the attempt is mapped to this
     *       container and that the request id is exactly {@code lastRequestId + 1}, then
     *       forwards the heartbeat to the context and copies events and event indices
     *       into the response.</li>
     * </ul>
     * The request id and the response are finally recorded on the container's
     * {@link ContainerInfo} for duplicate detection.
     *
     * @param request heartbeat sent by the container
     * @return the response to send back to the container
     * @throws IOException  declared by the protocol
     * @throws TezException if the attempt is not registered to the calling container, or
     *                      the request id is out of sequence
     */
    @Override
    public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) throws IOException,
        TezException {
      ContainerId containerId = ConverterUtils.toContainerId(request.getContainerIdentifier());
      long requestId = request.getRequestId();
      if (LOG.isDebugEnabled()) {
        LOG.debug("Received heartbeat from container"
            + ", request=" + request);
      }

      ContainerInfo containerInfo = registeredContainers.get(containerId);
      if (containerInfo == null) {
        LOG.warn("Received task heartbeat from unknown container with id: " + containerId +
            ", asking it to die");
        TezHeartbeatResponse response = new TezHeartbeatResponse();
        response.setLastRequestId(requestId);
        response.setShouldDie();
        return response;
      }

      synchronized (containerInfo) {
        if (containerInfo.lastRequestId == requestId) {
          LOG.warn("Old sequenceId received: " + requestId
              + ", Re-sending last response to client");
          return containerInfo.lastResponse;
        }
      }

      TezHeartbeatResponse response = new TezHeartbeatResponse();
      TezTaskAttemptID taskAttemptID = request.getCurrentTaskAttemptID();
      if (taskAttemptID != null) {
        synchronized (containerInfo) {
          ContainerId containerIdFromMap = attemptToContainerMap.get(taskAttemptID);
          if (containerIdFromMap == null || !containerIdFromMap.equals(containerId)) {
            throw new TezException("Attempt " + taskAttemptID
                + " is not recognized for heartbeat");
          }

          if (containerInfo.lastRequestId + 1 != requestId) {
            // Parenthesised so the expected id is computed arithmetically; without the
            // parentheses "+ 1" is string concatenation (id 5 would print as "51").
            throw new TezException("Container " + containerId
                + " has invalid request id. Expected: "
                + (containerInfo.lastRequestId + 1)
                + " and actual: " + requestId);
          }
        }
        // The context call is made outside the containerInfo monitor on purpose.
        TaskHeartbeatRequest tRequest = new TaskHeartbeatRequest(request.getContainerIdentifier(),
            request.getCurrentTaskAttemptID(), request.getEvents(), request.getStartIndex(),
            request.getPreRoutedStartIndex(), request.getMaxEvents());
        TaskHeartbeatResponse tResponse = getContext().heartbeat(tRequest);
        response.setEvents(tResponse.getEvents());
        response.setNextFromEventId(tResponse.getNextFromEventId());
        response.setNextPreRoutedEventId(tResponse.getNextPreRoutedEventId());
      }
      response.setLastRequestId(requestId);
      // Publish under the same monitor that guards every read of these two fields, so the
      // duplicate / sequence checks of a later heartbeat (possibly served by a different
      // RPC handler thread) are guaranteed to observe them.
      synchronized (containerInfo) {
        containerInfo.lastRequestId = requestId;
        containerInfo.lastResponse = response;
      }
      return response;
    }


    // TODO Remove this method once we move to the Protobuf RPC engine
    /**
     * Reports the protocol version. Both arguments are ignored; the inherited
     * {@code versionID} constant is always returned.
     *
     * @param protocol      ignored
     * @param clientVersion ignored
     * @return the {@code versionID} constant
     */
    @Override
    public long getProtocolVersion(String protocol, long clientVersion) throws IOException {
      return versionID;
    }

    // TODO Remove this method once we move to the Protobuf RPC engine
    /**
     * Delegates to {@link ProtocolSignature#getProtocolSignature} with this instance as
     * the server.
     *
     * @param protocol          protocol name passed through
     * @param clientVersion     client version passed through
     * @param clientMethodsHash client methods hash passed through
     * @return the signature computed by {@link ProtocolSignature}
     * @throws IOException if the signature lookup fails
     */
    @Override
    public ProtocolSignature getProtocolSignature(String protocol, long clientVersion,
                                                  int clientMethodsHash) throws IOException {
      ProtocolSignature signature = ProtocolSignature.getProtocolSignature(
          this, protocol, clientVersion, clientMethodsHash);
      return signature;
    }
  }

  /**
   * Looks up the task to hand to a container that asked for work.
   *
   * <p>An unregistered container gets the shared invalid-JVM task. For a registered
   * container the context is notified that the container is alive; the assigned task is
   * then returned exactly once: {@code taskPulled} is set before the task is built, and
   * later calls return null until a new task is registered.
   *
   * @param containerId id of the requesting container
   * @return the task to run, the invalid-JVM task, or null if there is nothing new to send
   * @throws IOException if the container's local resources cannot be converted
   */
  private ContainerTask getContainerTask(ContainerId containerId) throws IOException {
    ContainerInfo info = registeredContainers.get(containerId);
    if (info == null) {
      String verdict = getContext().isKnownContainer(containerId)
          ? " is valid, but no longer registered, and will be killed"
          : " is invalid and will be killed";
      LOG.info("Container with id: " + containerId + verdict);
      return TASK_FOR_INVALID_JVM;
    }

    synchronized (info) {
      getContext().containerAlive(containerId);

      if (info.taskSpec == null) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("No task assigned yet for running container: " + containerId);
        }
        return null;
      }

      if (info.taskPulled) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("Task " + info.taskSpec.getTaskAttemptID() +
              " already sent to container: " + containerId);
        }
        return null;
      }

      // Mark as pulled first, exactly as before, so the flag is set even if construction fails.
      info.taskPulled = true;
      return constructContainerTask(info);
    }
  }

  /**
   * Builds the {@link ContainerTask} for the task currently assigned to a container,
   * converting its additional local resources to their Tez representation.
   *
   * @param containerInfo bookkeeping entry holding the task spec, resources and credentials
   * @return a new task object built with the flag in the second position set to false
   * @throws IOException if a local resource cannot be converted
   */
  private ContainerTask constructContainerTask(ContainerInfo containerInfo) throws IOException {
    Map<String, TezLocalResource> tezResources =
        convertLocalResourceMap(containerInfo.additionalLRs);
    return new ContainerTask(
        containerInfo.taskSpec,
        false,
        tezResources,
        containerInfo.credentials,
        containerInfo.credentialsChanged);
  }

  /**
   * Converts a map of YARN local resources into a map of Tez local resources with the
   * same keys.
   *
   * @param ylrs YARN resources keyed by name; may be null
   * @return a new mutable map; empty when {@code ylrs} is null or empty
   * @throws IOException wrapping the {@link URISyntaxException} raised by a failed conversion
   */
  private Map<String, TezLocalResource> convertLocalResourceMap(Map<String, LocalResource> ylrs)
      throws IOException {
    Map<String, TezLocalResource> converted = Maps.newHashMap();
    if (ylrs == null) {
      return converted;
    }
    try {
      for (Map.Entry<String, LocalResource> entry : ylrs.entrySet()) {
        converted.put(entry.getKey(),
            TezConverterUtils.convertYarnLocalResourceToTez(entry.getValue()));
      }
    } catch (URISyntaxException e) {
      throw new IOException(e);
    }
    return converted;
  }

  /**
   * @param containerId id of the container to look up
   * @return the bookkeeping entry for the container, or null if it is not registered
   */
  protected ContainerInfo getContainerInfo(ContainerId containerId) {
    ContainerInfo info = registeredContainers.get(containerId);
    return info;
  }

  /**
   * @param taskAttemptId id of the task attempt to look up
   * @return the container the attempt is registered to, or null if it is not registered
   */
  protected ContainerId getContainerForAttempt(TezTaskAttemptID taskAttemptId) {
    ContainerId owner = attemptToContainerMap.get(taskAttemptId);
    return owner;
  }
}