/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tajo.querymaster;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.util.RackResolver;
import org.apache.tajo.TaskAttemptId;
import org.apache.tajo.conf.TajoConf;
import org.apache.tajo.engine.planner.global.ExecutionBlock;
import org.apache.tajo.engine.planner.global.MasterPlan;
import org.apache.tajo.engine.query.TaskRequest;
import org.apache.tajo.engine.query.TaskRequestImpl;
import org.apache.tajo.exception.TajoInternalError;
import org.apache.tajo.ipc.QueryCoordinatorProtocol;
import org.apache.tajo.ipc.QueryCoordinatorProtocol.QueryCoordinatorProtocolService;
import org.apache.tajo.ipc.TajoWorkerProtocol;
import org.apache.tajo.master.cluster.WorkerConnectionInfo;
import org.apache.tajo.master.event.*;
import org.apache.tajo.master.event.TaskAttemptToSchedulerEvent.TaskAttemptScheduleContext;
import org.apache.tajo.master.event.TaskSchedulerEvent.EventType;
import org.apache.tajo.plan.serder.LogicalNodeSerializer;
import org.apache.tajo.resource.NodeResource;
import org.apache.tajo.resource.NodeResources;
import org.apache.tajo.rpc.AsyncRpcClient;
import org.apache.tajo.rpc.CallFuture;
import org.apache.tajo.rpc.NettyClientBase;
import org.apache.tajo.rpc.RpcClientManager;
import org.apache.tajo.service.ServiceTracker;
import org.apache.tajo.storage.DataLocation;
import org.apache.tajo.storage.fragment.Fragment;
import org.apache.tajo.util.NetUtils;
import org.apache.tajo.util.RpcParameterFactory;
import org.apache.tajo.util.TUtil;

import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.util.*;
import java.util.Map.Entry;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static org.apache.tajo.ResourceProtos.*;

public class DefaultTaskScheduler extends AbstractTaskScheduler {
  private static final Log LOG = LogFactory.getLog(DefaultTaskScheduler.class);

  private final TaskSchedulerContext context;
  private Stage stage;
  private TajoConf tajoConf;
  private Properties rpcParams;

  private Thread schedulingThread;
  private volatile boolean isStopped;
  private AtomicBoolean needWakeup = new AtomicBoolean();

  private ScheduledRequests scheduledRequests;

  private int minTaskMemory;
  private int nextTaskId = 0;
  private int scheduledObjectNum = 0;
  private boolean isLeaf;
  private int schedulerDelay;
  private int maximumRequestContainer;

  // candidate workers for locality of high priority
  private Set<Integer> candidateWorkers = Sets.newHashSet();

  /**
   * Creates a scheduler bound to a single stage.
   *
   * @param context scheduler context, used to reach the query master and to identify the execution block
   * @param stage   the stage whose tasks this scheduler assigns
   */
  public DefaultTaskScheduler(TaskSchedulerContext context, Stage stage) {
    super(DefaultTaskScheduler.class.getName());
    this.stage = stage;
    this.context = context;
  }

  /**
   * Reads the scheduler settings from the configuration and prepares (but does not start) the
   * scheduling thread.
   *
   * <p>The thread calls {@link #schedule()} in a loop until the scheduler is stopped, the thread is
   * interrupted, or scheduling fails. An unexpected interruption or any other failure aborts the
   * stage with {@code StageState.ERROR} and ends the loop.</p>
   *
   * @param conf configuration; it is passed through {@code TUtil.checkTypeAndGet} as a {@link TajoConf}
   */
  @Override
  public void init(Configuration conf) {
    tajoConf = TUtil.checkTypeAndGet(conf, TajoConf.class);
    rpcParams = RpcParameterFactory.get(tajoConf);

    scheduledRequests = new ScheduledRequests();
    minTaskMemory = tajoConf.getIntVar(TajoConf.ConfVars.TASK_RESOURCE_MINIMUM_MEMORY);
    schedulerDelay = tajoConf.getIntVar(TajoConf.ConfVars.QUERYMASTER_TASK_SCHEDULER_DELAY);
    isLeaf = stage.getMasterPlan().isLeaf(stage.getBlock());

    this.schedulingThread = new Thread() {
      @Override
      public void run() {

        while (!isStopped && !Thread.currentThread().isInterrupted()) {

          try {
            schedule();
          } catch (InterruptedException e) {
            if (!isStopped) {
              // Interrupted without a stop request: treat it as a scheduling failure.
              LOG.fatal(e.getMessage(), e);
              stage.abort(StageState.ERROR, e);
            }
            // Catching the exception cleared the interrupt flag. Restore it and leave the loop,
            // otherwise the thread would keep scheduling for a stage that was just aborted.
            Thread.currentThread().interrupt();
            break;
          } catch (Throwable e) {
            LOG.fatal(e.getMessage(), e);
            stage.abort(StageState.ERROR, e);
            break;
          }
        }
        info(LOG, "TaskScheduler schedulingThread stopped");
      }
    };
    super.init(conf);
  }

  /**
   * Computes the per-request container limit, collects the candidate workers used as locality
   * hints, and starts the scheduling thread.
   *
   * <p>For a leaf stage the candidates are the workers whose host appears in the leaf task hosts;
   * otherwise they are the workers already assigned to the child stages.</p>
   */
  @Override
  public void start() {
    info(LOG, "Start TaskScheduler");

    final int configuredLimit =
        tajoConf.getIntVar(TajoConf.ConfVars.QUERYMASTER_TASK_SCHEDULER_REQUEST_MAX_NUM);
    final int workerCount = stage.getContext().getWorkerMap().size();
    // never ask for more containers per request than there are known workers
    maximumRequestContainer = Math.min(configuredLimit, workerCount);

    if (!isLeaf) {
      // non-leaf locality: reuse the workers that ran the child execution blocks
      final List<ExecutionBlock> childBlocks = stage.getMasterPlan().getChilds(stage.getBlock());
      for (ExecutionBlock childBlock : childBlocks) {
        final Stage childStage = stage.getContext().getStage(childBlock.getId());
        candidateWorkers.addAll(childStage.getAssignedWorkerMap().keySet());
      }
    } else {
      candidateWorkers.addAll(getWorkerIds(getLeafTaskHosts()));
    }

    this.schedulingThread.start();
    super.start();
  }

  /**
   * Marks the scheduler as stopped, interrupts the scheduling thread if one was created, and
   * discards the candidate workers and all pending requests.
   */
  @Override
  public void stop() {
    isStopped = true;

    final Thread worker = schedulingThread;
    if (worker != null) {
      // the scheduling thread waits on its own monitor, so interrupt it while holding that monitor
      synchronized (worker) {
        worker.interrupt();
      }
    }

    candidateWorkers.clear();
    scheduledRequests.clear();
    info(LOG, "Task Scheduler stopped");
    super.stop();
  }

  /**
   * Writes {@code message} to {@code log} at INFO level, prefixed with the stage id in brackets.
   *
   * @param log     target logger
   * @param message text to log
   */
  protected void info(Log log, String message) {
    final String line = String.format("[%s] %s", stage.getId(), message);
    log.info(line);
  }

  /**
   * Writes {@code message} to {@code log} at WARN level, prefixed with the stage id in brackets.
   *
   * @param log     target logger
   * @param message text to log
   */
  protected void warn(Log log, String message) {
    final String line = String.format("[%s] %s", stage.getId(), message);
    log.warn(line);
  }

  private Fragment[] fragmentsForNonLeafTask;
  private Fragment[] broadcastFragmentsForNonLeafTask;

  /**
   * Runs one scheduling round.
   *
   * <ul>
   *   <li>With no pending task attempts, the calling thread waits on the scheduling thread's monitor
   *       for up to one second, or until {@link #handle(TaskSchedulerEvent)} signals new work.</li>
   *   <li>Otherwise resources are requested via {@link #createTaskRequest(int)}. If none are granted
   *       the thread waits for {@code schedulerDelay} milliseconds; if some are granted they are
   *       handed to the leaf or non-leaf assignment routine.</li>
   * </ul>
   *
   * @throws Exception any failure of the resource request or the assignment, including
   *                   {@link InterruptedException} when a wait is interrupted; a
   *                   {@link TimeoutException} is logged and swallowed
   */
  public void schedule() throws Exception {
    try {
      final int incompleteTaskNum = scheduledRequests.leafTaskNum() + scheduledRequests.nonLeafTaskNum();
      if (incompleteTaskNum == 0) {
        // all tasks are done, or no task has been scheduled yet
        needWakeup.set(true);
        synchronized (schedulingThread) {
          // Re-check under the monitor: handle() clears the flag before it notifies, so a cleared
          // flag here means the notification was already sent and waiting would only add delay.
          if (needWakeup.get()) {
            schedulingThread.wait(1000);
          }
        }
        return;
      }

      LinkedList<TaskRequestEvent> taskRequests = createTaskRequest(incompleteTaskNum);

      if (taskRequests.isEmpty()) {
        // no resources granted this round; back off before asking again
        synchronized (schedulingThread) {
          schedulingThread.wait(schedulerDelay);
        }
        return;
      }

      if (LOG.isDebugEnabled()) {
        LOG.debug("Get " + taskRequests.size() + " taskRequestEvents ");
      }

      if (isLeaf) {
        scheduledRequests.assignToLeafTasks(taskRequests);
      } else {
        scheduledRequests.assignToNonLeafTasks(taskRequests);
      }
    } catch (TimeoutException e) {
      // keep the stack trace so the timed-out call can be located
      LOG.error(e.getMessage(), e);
    }
  }

  /**
   * Dispatches scheduler events.
   *
   * <p>{@code T_SCHEDULE} events:</p>
   * <ul>
   *   <li>{@link FragmentScheduleEvent}: for a leaf query a new task is created from the fragments and
   *       a {@code T_SCHEDULE} task event is emitted; otherwise the fragments are only remembered for
   *       the tasks created by later fetch events.</li>
   *   <li>{@link FetchScheduleEvent}: a new task is created from the fetches plus the remembered
   *       fragments and a {@code T_SCHEDULE} task event is emitted.</li>
   *   <li>{@link TaskAttemptToSchedulerEvent}: the attempt is queued for assignment and the
   *       scheduling thread is woken up if it is idle.</li>
   * </ul>
   *
   * <p>{@code T_SCHEDULE_CANCEL} events remove a not-yet-assigned attempt from the pending sets and
   * notify the attempt with {@code TA_SCHEDULE_CANCELED}.</p>
   *
   * @param event the scheduler event to process
   */
  @Override
  public void handle(TaskSchedulerEvent event) {
    if (event.getType() == EventType.T_SCHEDULE) {
      if (event instanceof FragmentScheduleEvent) {
        FragmentScheduleEvent castEvent = (FragmentScheduleEvent) event;
        if (context.isLeafQuery()) {
          TaskAttemptScheduleContext taskContext = new TaskAttemptScheduleContext();
          Task task = Stage.newEmptyTask(context, taskContext, stage, nextTaskId++);
          task.addFragment(castEvent.getLeftFragment(), true);
          scheduledObjectNum++;
          if (castEvent.hasRightFragments()) {
            task.addFragments(castEvent.getRightFragments());
          }
          stage.getEventHandler().handle(new TaskEvent(task.getId(), TaskEventType.T_SCHEDULE));
        } else {
          fragmentsForNonLeafTask = new Fragment[2];
          fragmentsForNonLeafTask[0] = castEvent.getLeftFragment();

          // Start from "no broadcast fragments" so that fragments remembered from an earlier event
          // cannot leak into tasks built for this one.
          Fragment[] broadcastFragments = null;
          if (castEvent.hasRightFragments()) {
            Collection<Fragment> rightFragmentList = castEvent.getRightFragments();
            Fragment[] rightFragments = rightFragmentList.toArray(new Fragment[rightFragmentList.size()]);
            fragmentsForNonLeafTask[1] = rightFragments[0];
            if (rightFragments.length > 1) {
              // everything after the first right fragment is kept as broadcast fragments
              broadcastFragments = Arrays.copyOfRange(rightFragments, 1, rightFragments.length);
            }
          }
          broadcastFragmentsForNonLeafTask = broadcastFragments;
        }
      } else if (event instanceof FetchScheduleEvent) {
        FetchScheduleEvent castEvent = (FetchScheduleEvent) event;
        Map<String, List<FetchProto>> fetches = castEvent.getFetches();
        TaskAttemptScheduleContext taskScheduleContext = new TaskAttemptScheduleContext();
        Task task = Stage.newEmptyTask(context, taskScheduleContext, stage, nextTaskId++);
        scheduledObjectNum++;
        for (Entry<String, List<FetchProto>> eachFetch : fetches.entrySet()) {
          task.addFetches(eachFetch.getKey(), eachFetch.getValue());
          task.addFragment(fragmentsForNonLeafTask[0], true);
          if (fragmentsForNonLeafTask[1] != null) {
            task.addFragment(fragmentsForNonLeafTask[1], true);
          }
        }
        if (broadcastFragmentsForNonLeafTask != null && broadcastFragmentsForNonLeafTask.length > 0) {
          task.addFragments(Arrays.asList(broadcastFragmentsForNonLeafTask));
        }
        stage.getEventHandler().handle(new TaskEvent(task.getId(), TaskEventType.T_SCHEDULE));
      } else if (event instanceof TaskAttemptToSchedulerEvent) {
        TaskAttemptToSchedulerEvent castEvent = (TaskAttemptToSchedulerEvent) event;
        if (context.isLeafQuery()) {
          scheduledRequests.addLeafTask(castEvent);
        } else {
          scheduledRequests.addNonLeafTask(castEvent);
        }

        if (needWakeup.getAndSet(false)) {
          // wake up the scheduler thread now that there is something to schedule
          synchronized (schedulingThread) {
            schedulingThread.notifyAll();
          }
        }
      }
    } else if (event.getType() == EventType.T_SCHEDULE_CANCEL) {
      // When a stage is killed, unassigned task attempts are canceled from the scheduler.
      // This event is triggered by TaskAttempt.
      TaskAttemptToSchedulerEvent castedEvent = (TaskAttemptToSchedulerEvent) event;
      TaskAttempt canceledAttempt = castedEvent.getTaskAttempt();
      TaskAttemptId canceledId = canceledAttempt.getId();

      // The attempt may be pending in either set; removing an absent id is a no-op.
      scheduledRequests.leafTasks.remove(canceledId);
      scheduledRequests.nonLeafTasks.remove(canceledId);

      LOG.info(canceledId + " is canceled from " + this.getClass().getSimpleName());
      canceledAttempt.handle(new TaskAttemptEvent(canceledId, TaskAttemptEventType.TA_SCHEDULE_CANCELED));
    }
  }

  /**
   * Collects the ids of all known workers whose host name is contained in {@code hosts}.
   *
   * @param hosts host names to match against the stage's worker map
   * @return a new mutable set of matching worker ids; empty when {@code hosts} is empty
   */
  private Set<Integer> getWorkerIds(Collection<String> hosts){
    final Set<Integer> matchedIds = Sets.newHashSet();

    if (!hosts.isEmpty()) {
      for (WorkerConnectionInfo candidate : stage.getContext().getWorkerMap().values()) {
        if (hosts.contains(candidate.getHost())) {
          matchedIds.add(candidate.getId());
        }
      }
    }
    return matchedIds;
  }


  /**
   * Asks the query coordinator to reserve node resources for pending tasks and converts every
   * granted allocation into a {@link TaskRequestEvent}.
   *
   * <p>The number of containers requested is capped by {@code maximumRequestContainer}: with
   * long-running tasks, requesting everything at once can lead to poor load balance across the
   * cluster. The method waits on the call future for the coordinator's response.</p>
   *
   * @param incompleteTaskNum number of task attempts still waiting to be assigned
   * @return one event per allocated resource, possibly empty
   * @throws Exception if the RPC client cannot be obtained or the call fails
   */
  protected LinkedList<TaskRequestEvent> createTaskRequest(final int incompleteTaskNum) throws Exception {
    // throttle the maximum number of containers required per request
    final int containersToRequest = Math.min(incompleteTaskNum, maximumRequestContainer);
    if (LOG.isDebugEnabled()) {
      LOG.debug("Try to schedule task resources: " + containersToRequest);
    }

    final ServiceTracker tracker =
        context.getMasterContext().getQueryMasterContext().getWorkerContext().getServiceTracker();
    final NettyClientBase rpcClient = RpcClientManager.getInstance()
        .getClient(tracker.getUmbilicalAddress(), QueryCoordinatorProtocol.class, true, rpcParams);
    final QueryCoordinatorProtocolService coordinator = rpcClient.getStub();

    final CallFuture<NodeResourceResponse> future = new CallFuture<>();
    final NodeResourceRequest.Builder requestBuilder = NodeResourceRequest.newBuilder();
    requestBuilder.setCapacity(NodeResources.createResource(minTaskMemory).getProto());
    requestBuilder.setNumContainers(containersToRequest);
    requestBuilder.setPriority(stage.getPriority());
    requestBuilder.setQueryId(context.getMasterContext().getQueryId().getProto());
    requestBuilder.setType(isLeaf ? ResourceType.LEAF : ResourceType.INTERMEDIATE);
    requestBuilder.setUserId(context.getMasterContext().getQueryContext().getUser());
    requestBuilder.setRunningTasks(stage.getTotalScheduledObjectsCount() - stage.getCompletedTaskCount());
    requestBuilder.addAllCandidateNodes(candidateWorkers);
    requestBuilder.setQueue(context.getMasterContext().getQueryContext().get("queue", "default")); //TODO set queue

    coordinator.reserveNodeResources(future.getController(), requestBuilder.build(), future);
    final NodeResourceResponse response = future.get();

    final LinkedList<TaskRequestEvent> grantedRequests = new LinkedList<>();
    for (AllocationResourceProto allocation : response.getResourceList()) {
      grantedRequests.add(new TaskRequestEvent(allocation.getWorkerId(), allocation, context.getBlockId()));
    }
    return grantedRequests;
  }

  /**
   * Returns the scheduler's scheduled-object counter, which {@link #handle(TaskSchedulerEvent)}
   * increments once for every task it creates.
   */
  @Override
  public int remainingScheduledObjectNum() {
    return this.scheduledObjectNum;
  }

  /**
   * Releases the volume slot held by a leaf task attempt: if the host mapping of the attempt's
   * worker recorded an assigned volume for it, that record is dropped and the volume's load is
   * decreased. Null, non-leaf or not-yet-placed attempts are ignored.
   *
   * @param taskAttempt the attempt to release; may be {@code null}
   */
  public void releaseTaskAttempt(TaskAttempt taskAttempt) {
    if (taskAttempt == null || !taskAttempt.isLeafTask()) {
      return;
    }
    final WorkerConnectionInfo worker = taskAttempt.getWorkerConnectionInfo();
    if (worker == null) {
      return;
    }

    final HostVolumeMapping mapping = scheduledRequests.leafTaskHostMapping.get(worker.getHost());
    if (mapping == null) {
      return;
    }

    // remove() hands back the recorded volume, or null when nothing was recorded for this attempt
    final Integer assignedVolume = mapping.lastAssignedVolumeId.remove(taskAttempt.getId());
    if (assignedVolume != null) {
      mapping.decreaseConcurrency(assignedVolume);
    }
  }
  /**
   * One worker can have multiple running task runners. <code>HostVolumeMapping</code>
   * describes various information for one worker, including :
   * <ul>
   *  <li>host name</li>
   *  <li>rack name</li>
   *  <li>unassigned tasks for each disk volume</li>
   *  <li>last assigned volume id - it can be used for assigning task in a round-robin manner</li>
   *  <li>the number of running tasks for each volume</li>
   * </ul>, each task runner and the concurrency number of running tasks for volumes.
   *
   * Here, we identify a task runner by {@link ContainerId}, and we use volume ids to identify
   * all disks in this node. Actually, each volume id is only used to distinguish disks, and we don't
   * know which disk a certain volume id indicates. If you want to know more about volume ids, please read the section below.
   *
   * <h3>Volume id</h3>
   * Volume id is an integer. Each volume id identifies each disk volume.
   *
   * This volume id can be obtained from {@code org.apache.hadoop.fs.BlockStorageLocation#getVolumeIds()}.
   * HDFS may be unable to provide a volume id, either for an unknown reason or because the config
   * 'dfs.client.file-block-locations.enabled' is disabled.
   * In this case, the volume id will be -1 or another negative integer.
   *
   * <h3>See Also</h3>
   * <ul>
   *   <li>HDFS-3672 (https://issues.apache.org/jira/browse/HDFS-3672).</li>
   * </ul>
   */
  public class HostVolumeMapping {
    private final String host;
    private final String rack;
    /** A key is disk volume, and a value is a list of tasks to be scheduled. */
    private Map<Integer, LinkedHashSet<TaskAttempt>> unassignedTaskForEachVolume =
        Collections.synchronizedMap(new HashMap<>());
    /** A value is last assigned volume id for each task runner */
    private HashMap<TaskAttemptId, Integer> lastAssignedVolumeId = Maps.newHashMap();
    /**
     * A key is disk volume id, and a value is the load of this volume.
     * This load is measured by counting how many number of tasks are running.
     *
     * These disk volumes are kept in an order of ascending order of the volume id.
     * In other words, the head volume ids are likely to -1, meaning no given volume id.
     */
    private SortedMap<Integer, Integer> diskVolumeLoads = new TreeMap<>();
    /** The total number of remain tasks in this host */
    private AtomicInteger remainTasksNum = new AtomicInteger(0);

    /**
     * Creates an empty mapping for one host.
     *
     * @param host host name of the worker
     * @param rack rack name the host belongs to
     */
    public HostVolumeMapping(String host, String rack){
      this.rack = rack;
      this.host = host;
    }

    /**
     * Registers a task attempt as pending on the given volume of this host, bumps the remaining
     * task counter, and makes sure the volume has a load entry (starting at 0).
     *
     * @param volumeId  disk volume the attempt's data resides on
     * @param attemptId the task attempt to queue
     */
    public synchronized void addTaskAttempt(int volumeId, TaskAttempt attemptId){
      synchronized (unassignedTaskForEachVolume) {
        LinkedHashSet<TaskAttempt> pending = unassignedTaskForEachVolume.get(volumeId);
        if (pending == null) {
          // first attempt seen for this volume
          pending = new LinkedHashSet<>();
          unassignedTaskForEachVolume.put(volumeId, pending);
        }
        pending.add(attemptId);
      }

      remainTasksNum.incrementAndGet();

      if (!diskVolumeLoads.containsKey(volumeId)) {
        diskVolumeLoads.put(volumeId, 0);
      }
    }

    /**
     * Picks a pending task for this host, trying volumes in the order given by
     * {@link #getLowestVolumeId()}. Priorities:
     * <ol>
     *   <li>a task list in a volume of the host</li>
     *   <li>an unknown block or non-splittable task in the host</li>
     *   <li>remote tasks; {@code unassignedTaskForEachVolume} only contains local tasks, so the result
     *       is {@code null} in this case</li>
     * </ol>
     * When nothing is pending at all, the remaining task counter is reset to 0.
     *
     * @return the id of the chosen attempt, or {@code null} if no local task could be found
     */
    public synchronized TaskAttemptId getLocalTask() {
      int candidateVolume = getLowestVolumeId();

      if (unassignedTaskForEachVolume.isEmpty()) {
        this.remainTasksNum.set(0);
        return null;
      }

      // try at most once per known volume (and always at least once)
      int attemptsLeft = diskVolumeLoads.size();
      while (true) {
        // clean and get a remaining local task
        final TaskAttemptId picked = getAndRemove(candidateVolume);
        if (picked != null) {
          lastAssignedVolumeId.put(picked, candidateVolume);
          return picked;
        }

        // nothing on that volume: move on to the next candidate volume
        candidateVolume = getLowestVolumeId();
        attemptsLeft--;
        if (attemptsLeft <= 0) {
          return null;
        }
      }
    }

    /**
     * Picks a pending task of this host on behalf of a requester in the given rack. Nothing is
     * returned unless this host is in that rack and still has pending tasks.
     *
     * @param rack rack name of the requesting worker
     * @return the id of the chosen attempt, or {@code null} if none is available
     */
    public synchronized TaskAttemptId getTaskAttemptIdByRack(String rack) {
      if (unassignedTaskForEachVolume.isEmpty() || !this.rack.equals(rack)) {
        return null;
      }

      // one try per volume that still has a pending-task entry
      int attemptsLeft = unassignedTaskForEachVolume.size();
      while (true) {
        // clean and get a remaining task
        final TaskAttemptId picked = getAndRemove(getLowestVolumeId());
        if (picked != null) {
          return picked;
        }
        attemptsLeft--;
        if (attemptsLeft <= 0) {
          return null;
        }
      }
    }

    /**
     * Takes one pending task attempt from the given volume of this host.
     *
     * <p>If a task is found, it is removed from the pending sets of every host mapping that matches
     * one of the task's data locations, and the load of {@code volumeId} is increased. If the volume
     * has no pending entry, or its entry is empty, the stale bookkeeping for the volume is cleaned
     * up instead.</p>
     *
     * @param volumeId volume to take a task from
     * @return the id of the removed attempt, or {@code null} if the volume has no pending task
     */
    private synchronized TaskAttemptId getAndRemove(int volumeId){
      TaskAttemptId taskAttemptId = null;
      if(!unassignedTaskForEachVolume.containsKey(volumeId)) {
        // Nothing is pending on this volume: forget its load entry so it is not offered again.
        // Ids not greater than REMOTE_VOLUME_ID keep their entry.
        if (volumeId > DataLocation.REMOTE_VOLUME_ID) {
          diskVolumeLoads.remove(volumeId);
        }
        return taskAttemptId;
      }

      LinkedHashSet<TaskAttempt> list = unassignedTaskForEachVolume.get(volumeId);
      if (list != null && !list.isEmpty()) {
        TaskAttempt taskAttempt;
        synchronized (unassignedTaskForEachVolume) {
          // take the first attempt in the set's iteration order
          Iterator<TaskAttempt> iterator = list.iterator();
          taskAttempt = iterator.next();
          iterator.remove();
          remainTasksNum.decrementAndGet();
        }

        taskAttemptId = taskAttempt.getId();
        // The same attempt is registered once per data location; drop it from each matching
        // host mapping that is known to the scheduler.
        for (DataLocation location : taskAttempt.getTask().getDataLocations()) {
          HostVolumeMapping volumeMapping = scheduledRequests.leafTaskHostMapping.get(location.getHost());
          if (volumeMapping != null) {
            volumeMapping.removeTaskAttempt(location.getVolumeId(), taskAttempt);
          }
        }

        increaseConcurrency(volumeId);
      } else {
        // the volume's pending set is empty: remove the empty entry
        unassignedTaskForEachVolume.remove(volumeId);
      }

      return taskAttemptId;
    }

    /**
     * Drops a task attempt from the pending set of the given volume, if it is there. When the set
     * becomes empty, the volume's pending entry is removed, and its load entry too for volume ids
     * greater than {@code DataLocation.REMOTE_VOLUME_ID}.
     *
     * @param volumeId    volume the attempt was registered on
     * @param taskAttempt the attempt to remove
     */
    private synchronized void removeTaskAttempt(int volumeId, TaskAttempt taskAttempt){
      if (!unassignedTaskForEachVolume.containsKey(volumeId)) {
        return;
      }

      final LinkedHashSet<TaskAttempt> pending = unassignedTaskForEachVolume.get(volumeId);
      final boolean wasPending = pending.remove(taskAttempt);
      if (wasPending) {
        remainTasksNum.decrementAndGet();
      }

      if (!pending.isEmpty()) {
        return;
      }

      // last pending attempt of this volume is gone: clean up its bookkeeping
      unassignedTaskForEachVolume.remove(volumeId);
      if (volumeId > DataLocation.REMOTE_VOLUME_ID) {
        diskVolumeLoads.remove(volumeId);
      }
    }

    /**
     * Raises the recorded load of a volume by one and logs the assignment. A volume without a
     * load entry starts at 1.
     *
     * @param volumeId Volume identifier
     * @return the new volume load (i.e., how many running tasks use this volume)
     */
    private synchronized int increaseConcurrency(int volumeId) {

      final int concurrency = diskVolumeLoads.containsKey(volumeId) ? diskVolumeLoads.get(volumeId) + 1 : 1;

      if (volumeId > DataLocation.UNKNOWN_VOLUME_ID) {
        // a regular disk volume
        info(LOG, "Assigned host : " + host + ", Volume : " + volumeId + ", Concurrency : " + concurrency);
      } else if (volumeId == DataLocation.UNKNOWN_VOLUME_ID) {
        // namenode block metadata is disabled, or the input is a compressed text file or on amazon s3
        info(LOG, "Assigned host : " + host + ", Unknown Volume : " + volumeId + ", Concurrency : " + concurrency);
      } else if (volumeId == DataLocation.REMOTE_VOLUME_ID) {
        // all blocks on this host have been processed, so the task is a remote one
        info(LOG, "Assigned host : " + host + ", Remaining local tasks : " + getRemainingLocalTaskSize()
            + ", Remote Concurrency : " + concurrency + ", Unassigned volumes: " + unassignedTaskForEachVolume.size());
      }

      diskVolumeLoads.put(volumeId, concurrency);
      return concurrency;
    }

    /**
     * Lowers the recorded load of a volume by one. Volumes without a load entry, or whose load is
     * already 0 or less, are left untouched.
     *
     * @param volumeId Volume identifier
     */
    private synchronized void decreaseConcurrency(int volumeId){
      if (!diskVolumeLoads.containsKey(volumeId)) {
        return;
      }

      final int running = diskVolumeLoads.get(volumeId);
      if (running <= 0) {
        return;
      }
      diskVolumeLoads.put(volumeId, running - 1);
    }

    /**
     * Returns the volume with the lowest recorded load, ignoring the remote pseudo-volume.
     * Among volumes with equal load, the one with the highest volume id wins.
     *
     * <p>Volume id conventions:</p>
     * <ul>
     *   <li>volume of a host : 0 ~ n</li>
     *   <li>compressed task, amazon s3, unknown volume : -1</li>
     *   <li>remote task : -2</li>
     * </ul>
     *
     * @return the id of the least loaded non-remote volume, or
     *         {@code DataLocation.REMOTE_VOLUME_ID} when no other volume has a load entry
     */
    public int getLowestVolumeId(){
      Map.Entry<Integer, Integer> lowest = null;

      for (Map.Entry<Integer, Integer> entry : diskVolumeLoads.entrySet()) {
        // The remote pseudo-volume must never be a candidate. It has the smallest key and is
        // therefore visited first, so it has to be skipped before a candidate is seeded;
        // otherwise it would win whenever its load is lower than that of every real volume.
        if (entry.getKey() == DataLocation.REMOTE_VOLUME_ID) {
          continue;
        }

        if (lowest == null || lowest.getValue() >= entry.getValue()) {
          lowest = entry;
        }
      }

      return lowest != null ? lowest.getKey() : DataLocation.REMOTE_VOLUME_ID;
    }

    /** Returns the load recorded for the remote pseudo-volume of this host. */
    public int getRemoteConcurrency(){
      final int remoteLoad = getVolumeConcurrency(DataLocation.REMOTE_VOLUME_ID);
      return remoteLoad;
    }

    /**
     * Returns the load recorded for a volume.
     *
     * @param volumeId Volume identifier
     * @return the recorded load, or 0 when the volume has no load entry
     */
    public int getVolumeConcurrency(int volumeId){
      final Integer load = diskVolumeLoads.get(volumeId);
      return load == null ? 0 : load;
    }

    /** Returns the current value of this host's remaining task counter. */
    public int getRemainingLocalTaskSize(){
      return this.remainTasksNum.get();
    }

    /** Returns the host name this mapping was created for. */
    public String getHost() {
      return this.host;
    }

    /** Returns the rack name this mapping was created with. */
    public String getRack() {
      return this.rack;
    }
  }

  /**
   * Takes back an assignment: the attempt is queued again for scheduling (after releasing its
   * volume slot if it is a leaf task), and a {@code TA_ASSIGN_CANCEL} event is sent for it.
   *
   * @param taskAttempt the attempt whose assignment is canceled
   */
  protected void cancel(TaskAttempt taskAttempt) {

    final TaskAttemptToSchedulerEvent rescheduleEvent = new TaskAttemptToSchedulerEvent(
        EventType.T_SCHEDULE, taskAttempt.getTask().getId().getExecutionBlockId(),
        null, taskAttempt);

    if (!taskAttempt.isLeafTask()) {
      scheduledRequests.addNonLeafTask(rescheduleEvent);
    } else {
      // give the volume slot back before the attempt is queued again
      releaseTaskAttempt(taskAttempt);

      scheduledRequests.addLeafTask(rescheduleEvent);
    }

    final TaskAttemptEvent cancelNotice =
        new TaskAttemptEvent(taskAttempt.getId(), TaskAttemptEventType.TA_ASSIGN_CANCEL);
    context.getMasterContext().getEventHandler().handle(cancelNotice);
  }

  /**
   * Cancels the assignment of every task in the given allocation list.
   *
   * @param tasks allocations whose task attempts should be handed back to the scheduler
   * @return the number of canceled attempts
   */
  protected int cancel(List<TaskAllocationProto> tasks) {
    int canceledCount = 0;
    for (TaskAllocationProto allocation : tasks) {
      final TaskAttemptId attemptId = new TaskAttemptId(allocation.getTaskRequest().getId());
      final TaskAttempt attempt = stage.getTask(attemptId.getTaskId()).getAttempt(attemptId);
      cancel(attempt);
      canceledCount++;
    }
    return canceledCount;
  }

  private class ScheduledRequests {
    // two list leafTasks and nonLeafTasks keep all tasks to be scheduled. Even though some task is included in
    // leafTaskHostMapping or leafTasksRackMapping, some task T will not be sent to a task runner
    // if the task is not included in leafTasks and nonLeafTasks.
    private final Set<TaskAttemptId> leafTasks = Collections.synchronizedSet(new HashSet<>());
    private final Set<TaskAttemptId> nonLeafTasks = Collections.synchronizedSet(new HashSet<>());
    private Map<String, HostVolumeMapping> leafTaskHostMapping = Maps.newConcurrentMap();
    private final Map<String, HashSet<TaskAttemptId>> leafTasksRackMapping = Maps.newConcurrentMap();

    /** Drops all pending leaf and non-leaf tasks together with the host and rack mappings. */
    protected void clear() {
      this.leafTasks.clear();
      this.nonLeafTasks.clear();
      this.leafTaskHostMapping.clear();
      this.leafTasksRackMapping.clear();
    }

    /**
     * Queues a leaf task attempt. For every data location of the task, the attempt is registered
     * in the host's volume mapping (created on first use, with the rack resolved through
     * {@link RackResolver}) and in the rack mapping; finally its id is added to {@code leafTasks}.
     *
     * @param event event carrying the task attempt to queue
     */
    private void addLeafTask(TaskAttemptToSchedulerEvent event) {
      final TaskAttempt attempt = event.getTaskAttempt();

      for (DataLocation location : attempt.getTask().getDataLocations()) {
        final String host = location.getHost();
        leafTaskHosts.add(host);

        // host-level registration
        HostVolumeMapping mapping = leafTaskHostMapping.get(host);
        if (mapping == null) {
          final String resolvedRack = RackResolver.resolve(host).getNetworkLocation();
          mapping = new HostVolumeMapping(host, resolvedRack);
          leafTaskHostMapping.put(host, mapping);
        }
        mapping.addTaskAttempt(location.getVolumeId(), attempt);

        if (LOG.isDebugEnabled()) {
          LOG.debug("Added attempt req to host " + host);
        }

        // rack-level registration
        final String rack = mapping.getRack();
        HashSet<TaskAttemptId> rackTasks = leafTasksRackMapping.get(rack);
        if (rackTasks == null) {
          rackTasks = new HashSet<>();
          leafTasksRackMapping.put(rack, rackTasks);
        }
        rackTasks.add(attempt.getId());

        if (LOG.isDebugEnabled()) {
          LOG.debug("Added attempt req to rack " + rack);
        }
      }

      leafTasks.add(attempt.getId());
    }

    /**
     * Queues a non-leaf task attempt by adding its id to {@code nonLeafTasks}.
     *
     * @param event event carrying the task attempt to queue
     */
    private void addNonLeafTask(TaskAttemptToSchedulerEvent event) {
      final TaskAttemptId attemptId = event.getTaskAttempt().getId();
      nonLeafTasks.add(attemptId);
    }

    /** Returns the number of leaf task attempts waiting to be assigned. */
    public int leafTaskNum() {
      return this.leafTasks.size();
    }

    /** Returns the number of non-leaf task attempts waiting to be assigned. */
    public int nonLeafTaskNum() {
      return this.nonLeafTasks.size();
    }

    /**
     * Tries to find a pending leaf task whose data resides on the given host.
     *
     * <p>Candidates proposed by the host mapping are only accepted if they can actually be removed
     * from {@code leafTasks}. Using the result of {@code remove} instead of a separate
     * {@code contains} check keeps the claim atomic with respect to the event thread, which removes
     * canceled attempts from the same set.</p>
     *
     * @param host host name of the requesting worker
     * @return the id of a host-local task, or {@code null} if the host is unknown or has none left
     */
    private TaskAttemptId allocateLocalTask(String host){
      HostVolumeMapping hostVolumeMapping = leafTaskHostMapping.get(host);
      if (hostVolumeMapping == null) {
        // the worker does not run on a node that holds any of the blocks
        return null;
      }

      // the bound is re-evaluated on purpose: getLocalTask() shrinks the remaining task count
      for (int i = 0; i < hostVolumeMapping.getRemainingLocalTaskSize(); i++) {
        TaskAttemptId attemptId = hostVolumeMapping.getLocalTask();
        if (attemptId == null) {
          break;
        }

        // claim the task; a failed remove means it is no longer schedulable, so try the next one
        if (leafTasks.remove(attemptId)) {
          return attemptId;
        }
      }
      return null;
    }

    /**
     * Tries to find a pending leaf task located in the same rack as the given host.
     *
     * <p>Host mappings are visited in descending order of their remaining task count, and each is
     * asked for a task in the requester's rack. If that yields nothing, the rack mapping is
     * drained until an id is found that is still pending. A candidate is accepted only if it can
     * actually be removed from {@code leafTasks}, which keeps the claim atomic with respect to the
     * event thread that removes canceled attempts from the same set.</p>
     *
     * @param host host name of the requesting worker
     * @return the id of a rack-local task, or {@code null} if none is available
     */
    private TaskAttemptId allocateRackTask(String host) {

      List<HostVolumeMapping> remainingTasks = Lists.newArrayList(leafTaskHostMapping.values());
      String rack = RackResolver.resolve(host).getNetworkLocation();

      if (!remainingTasks.isEmpty()) {
        synchronized (scheduledRequests) {
          // hosts with the most remaining tasks come first
          Collections.sort(remainingTasks, new Comparator<HostVolumeMapping>() {
            @Override
            public int compare(HostVolumeMapping v1, HostVolumeMapping v2) {
              // descending remaining tasks; each counter is read exactly once per comparison
              return Integer.compare(v2.remainTasksNum.get(), v1.remainTasksNum.get());
            }
          });
        }

        for (HostVolumeMapping tasks : remainingTasks) {
          // the bound is re-evaluated on purpose: taking a task shrinks the remaining task count
          for (int i = 0; i < tasks.getRemainingLocalTaskSize(); i++) {
            TaskAttemptId tId = tasks.getTaskAttemptIdByRack(rack);
            if (tId == null) {
              break;
            }

            if (leafTasks.remove(tId)) {
              return tId;
            }
          }
        }
      }

      // fall back to the ids registered for the rack
      HashSet<TaskAttemptId> list = leafTasksRackMapping.get(rack);
      if (list != null) {
        synchronized (list) {
          Iterator<TaskAttemptId> iterator = list.iterator();
          while (iterator.hasNext()) {
            TaskAttemptId tId = iterator.next();
            // every visited id is dropped from the rack list, whether or not it is still pending
            iterator.remove();
            if (leafTasks.remove(tId)) {
              return tId;
            }
          }
        }
      }

      return null;
    }

    public void assignToLeafTasks(LinkedList<TaskRequestEvent> taskRequests) throws InterruptedException {
      Collections.shuffle(taskRequests);
      LinkedList<TaskRequestEvent> remoteTaskRequests = new LinkedList<>();
      String queryMasterHostAndPort = context.getMasterContext().getQueryMasterContext().getWorkerContext().
          getConnectionInfo().getHostAndQMPort();

      TaskRequestEvent taskRequest;
      while (leafTasks.size() > 0 && (!taskRequests.isEmpty() || !remoteTaskRequests.isEmpty())) {
        int localAssign = 0;
        int rackAssign = 0;

        taskRequest = taskRequests.pollFirst();
        if(taskRequest == null) { // if there are only remote task requests
          taskRequest = remoteTaskRequests.pollFirst();
        }

        // checking if this container is still alive.
        // If not, ignore the task request and stop the task runner
        WorkerConnectionInfo connectionInfo = context.getMasterContext().getWorkerMap().get(taskRequest.getWorkerId());
        if(connectionInfo == null) continue;

        // getting the hostname of requested node
        String host = connectionInfo.getHost();

        // if there are no worker matched to the hostname a task request
        if (!leafTaskHostMapping.containsKey(host) && !taskRequests.isEmpty()) {
          String normalizedHost = NetUtils.normalizeHost(host);

          if (!leafTaskHostMapping.containsKey(normalizedHost)) {
            // this case means one of either cases:
            // * there are no blocks which reside in this node.
            // * all blocks which reside in this node are consumed, and this task runner requests a remote task.
            // In this case, we transfer the task request to the remote task request list, and skip the followings.
            remoteTaskRequests.add(taskRequest);
            continue;
          } else {
            host = normalizedHost;
          }
        }

        if (LOG.isDebugEnabled()) {
          LOG.debug("assignToLeafTasks: " + taskRequest.getExecutionBlockId() + "," +
              "worker=" + connectionInfo.getHostAndPeerRpcPort());
        }

        //////////////////////////////////////////////////////////////////////
        // disk or host-local allocation
        //////////////////////////////////////////////////////////////////////
        TaskAttemptId attemptId = allocateLocalTask(host);
        int assignedVolume = DataLocation.REMOTE_VOLUME_ID;
        HostVolumeMapping hostVolumeMapping = leafTaskHostMapping.get(host);

        if (attemptId == null) { // if a local task cannot be found

          if(!taskRequests.isEmpty()) { //if other requests remains, move to remote list for better locality
            remoteTaskRequests.add(taskRequest);
            candidateWorkers.remove(connectionInfo.getId());
            continue;

          } else {
            if(hostVolumeMapping != null) {
              int nodes = context.getMasterContext().getWorkerMap().size();
              //this part is to control the assignment of tail and remote task balancing per node
              int tailLimit = 1;
              if (remainingScheduledObjectNum() > 0 && nodes > 0) {
                tailLimit = Math.max(remainingScheduledObjectNum() / nodes, 1);
              }

              //remote task throttling per node
              if (nodes > 1 && hostVolumeMapping.getRemoteConcurrency() >= tailLimit) {
                continue;
              } else {
                // assign to remote volume
                hostVolumeMapping.increaseConcurrency(assignedVolume);
              }
            }
          }

          //////////////////////////////////////////////////////////////////////
          // rack-local allocation
          //////////////////////////////////////////////////////////////////////
          attemptId = allocateRackTask(host);

          //////////////////////////////////////////////////////////////////////
          // random node allocation
          //////////////////////////////////////////////////////////////////////
          if (attemptId == null && leafTaskNum() > 0) {
            synchronized (leafTasks){
              attemptId = leafTasks.iterator().next();
              leafTasks.remove(attemptId);
            }
          }

          if (attemptId != null && hostVolumeMapping != null) {
            hostVolumeMapping.lastAssignedVolumeId.put(attemptId, assignedVolume);
          }
          rackAssign++;
        } else {
          if(hostVolumeMapping != null){
            //Set to real volume id
            assignedVolume = hostVolumeMapping.lastAssignedVolumeId.get(attemptId);
          }

          localAssign++;
        }

        if (attemptId != null) {
          Task task = stage.getTask(attemptId.getTaskId());
          TaskRequest taskAssign = new TaskRequestImpl(
              attemptId,
                  new ArrayList<>(task.getAllFragments()),
              "",
              false,
              LogicalNodeSerializer.serialize(task.getLogicalPlan()),
              context.getMasterContext().getQueryContext(),
              stage.getDataChannel(), stage.getBlock().getEnforcer(),
              queryMasterHostAndPort);

          NodeResource resource = new NodeResource(taskRequest.getResponseProto().getResource());

          if (checkIfInterQuery(stage.getMasterPlan(), stage.getBlock())) {
            taskAssign.setInterQuery();
          }

          //TODO send batch request
          BatchAllocationRequest.Builder requestProto = BatchAllocationRequest.newBuilder();
          requestProto.addTaskRequest(TaskAllocationProto.newBuilder()
              .setResource(resource.getProto()).setVolumeId(assignedVolume)
              .setTaskRequest(taskAssign.getProto()).build());

          requestProto.setExecutionBlockId(attemptId.getTaskId().getExecutionBlockId().getProto());
          context.getMasterContext().getEventHandler().handle(new TaskAttemptAssignedEvent(attemptId, connectionInfo));

          InetSocketAddress addr = stage.getAssignedWorkerMap().get(connectionInfo.getId());
          if (addr == null) addr = new InetSocketAddress(connectionInfo.getHost(), connectionInfo.getPeerRpcPort());

          AsyncRpcClient tajoWorkerRpc = null;
          CallFuture<BatchAllocationResponse> callFuture = new CallFuture<>();
          totalAttempts++;
          try {
            tajoWorkerRpc = RpcClientManager.getInstance().getClient(addr, TajoWorkerProtocol.class, true,
                rpcParams);

            TajoWorkerProtocol.TajoWorkerProtocolService tajoWorkerRpcClient = tajoWorkerRpc.getStub();
            tajoWorkerRpcClient.allocateTasks(callFuture.getController(), requestProto.build(), callFuture);

            BatchAllocationResponse responseProto = callFuture.get();

            if (responseProto.getCancellationTaskCount() > 0) {
              cancellation += cancel(responseProto.getCancellationTaskList());
              info(LOG, "Canceled requests: " + responseProto.getCancellationTaskCount() + " from " +  addr);
              continue;
            }
          } catch (ExecutionException | ConnectException e) {
            cancellation += cancel(requestProto.getTaskRequestList());

            warn(LOG, "Canceled requests: " + requestProto.getTaskRequestCount()
                + " by " + ExceptionUtils.getFullStackTrace(e));
            continue;
          } catch (InterruptedException e) {
            throw e;
          } catch (Exception e) {
            throw new TajoInternalError(e);
          }

          scheduledObjectNum--;
          totalAssigned++;
          hostLocalAssigned += localAssign;
          rackLocalAssigned += rackAssign;

          if (rackAssign > 0) {
            info(LOG, String.format("Assigned Local/Rack/Total: (%d/%d/%d), " +
                    "Attempted Cancel/Assign/Total: (%d/%d/%d), " +
                    "Locality: %.2f%%, Rack host: %s",
                hostLocalAssigned, rackLocalAssigned, totalAssigned,
                cancellation, totalAssigned, totalAttempts,
                ((double) hostLocalAssigned / (double) totalAssigned) * 100, host));
          }

        } else {
          throw new RuntimeException("Illegal State!!!!!!!!!!!!!!!!!!!!!");
        }
      }
    }

    /**
     * Decides whether tasks of the given execution block should be flagged as inter-query
     * (callers invoke {@code setInterQuery()} on the task request when this returns true).
     *
     * @param masterPlan the plan used to look up the root block and the parent of {@code block}
     * @param block the execution block being scheduled
     * @return {@code false} if {@code block} is the root of the plan, or if its parent is the root
     *         and that parent reports {@code isUnionOnly()}; {@code true} in every other case
     */
    private boolean checkIfInterQuery(MasterPlan masterPlan, ExecutionBlock block) {
      // The root block is never treated as inter-query.
      if (masterPlan.isRoot(block)) {
        return false;
      }

      final ExecutionBlock parentBlock = masterPlan.getParent(block);
      // A child of a union-only root is exempt as well.
      final boolean childOfUnionOnlyRoot = masterPlan.isRoot(parentBlock) && parentBlock.isUnionOnly();
      return !childOfUnionOnlyRoot;
    }

    /**
     * Assigns pending non-leaf task attempts to the workers that sent the given task requests.
     *
     * <p>The requests are shuffled and then consumed one at a time. For each request, while
     * {@code nonLeafTasks} still has entries, one attempt is taken from it, wrapped into a
     * single-task {@code BatchAllocationRequest} and sent to the requesting worker through
     * {@code allocateTasks}; the call waits for the worker's response. Requests polled while
     * {@code nonLeafTasks} is empty are dropped without any action.
     *
     * <p>A request whose worker id is not present in the worker map is skipped <em>before</em> an
     * attempt is taken from {@code nonLeafTasks}, so no attempt is removed and no assignment event
     * is published for a worker that cannot be reached.
     *
     * @param taskRequests task requests to serve; the list is shuffled and drained by this method
     * @throws InterruptedException if the thread is interrupted while waiting for a worker response
     * @throws TajoInternalError if the allocation call fails with an exception other than
     *         {@code ExecutionException}, {@code ConnectException} or {@code InterruptedException}
     */
    public void assignToNonLeafTasks(LinkedList<TaskRequestEvent> taskRequests) throws InterruptedException {
      Collections.shuffle(taskRequests);
      String queryMasterHostAndPort = context.getMasterContext().getQueryMasterContext().getWorkerContext().
          getConnectionInfo().getHostAndQMPort();

      TaskRequestEvent taskRequest;
      while (!taskRequests.isEmpty()) {
        taskRequest = taskRequests.pollFirst();
        if (LOG.isDebugEnabled()) {
          LOG.debug("assignToNonLeafTasks: " + taskRequest.getExecutionBlockId());
        }

        TaskAttemptId attemptId;
        // random allocation
        if (nonLeafTasks.size() > 0) {
          // Resolve the requesting worker before taking an attempt out of nonLeafTasks.
          // If the worker is not in the worker map, ignore this request (same handling as
          // the leaf-task assignment) instead of failing later with a NullPointerException.
          WorkerConnectionInfo connectionInfo =
              context.getMasterContext().getWorkerMap().get(taskRequest.getWorkerId());
          if (connectionInfo == null) {
            continue;
          }

          synchronized (nonLeafTasks){
            attemptId = nonLeafTasks.iterator().next();
            nonLeafTasks.remove(attemptId);
          }
          LOG.debug("Assigned based on * match");

          Task task;
          task = stage.getTask(attemptId.getTaskId());

          TaskRequest taskAssign = new TaskRequestImpl(
              attemptId,
              Lists.newArrayList(task.getAllFragments()),
              "",
              false,
              LogicalNodeSerializer.serialize(task.getLogicalPlan()),
              context.getMasterContext().getQueryContext(),
              stage.getDataChannel(),
              stage.getBlock().getEnforcer(),
              queryMasterHostAndPort);

          if (checkIfInterQuery(stage.getMasterPlan(), stage.getBlock())) {
            taskAssign.setInterQuery();
          }

          // Attach every fetch registered for this task to the outgoing request.
          for(Map.Entry<String, Set<FetchProto>> entry: task.getFetchMap().entrySet()) {
            Collection<FetchProto> fetches = entry.getValue();
            if (fetches != null) {
              for (FetchProto fetch : fetches) {
                taskAssign.addFetch(fetch);
              }
            }
          }

          //TODO send batch request
          BatchAllocationRequest.Builder requestProto = BatchAllocationRequest.newBuilder();
          requestProto.addTaskRequest(TaskAllocationProto.newBuilder()
              .setResource(taskRequest.getResponseProto().getResource())
              .setTaskRequest(taskAssign.getProto()).build());

          requestProto.setExecutionBlockId(attemptId.getTaskId().getExecutionBlockId().getProto());
          context.getMasterContext().getEventHandler().handle(new TaskAttemptAssignedEvent(attemptId, connectionInfo));

          CallFuture<BatchAllocationResponse> callFuture = new CallFuture<>();

          // Prefer the address recorded in the stage's assigned-worker map; otherwise build it
          // from the worker's connection info.
          InetSocketAddress addr = stage.getAssignedWorkerMap().get(connectionInfo.getId());
          if (addr == null) addr = new InetSocketAddress(connectionInfo.getHost(), connectionInfo.getPeerRpcPort());

          AsyncRpcClient tajoWorkerRpc;
          try {
            tajoWorkerRpc = RpcClientManager.getInstance().getClient(addr, TajoWorkerProtocol.class, true,
                rpcParams);

            TajoWorkerProtocol.TajoWorkerProtocolService tajoWorkerRpcClient = tajoWorkerRpc.getStub();
            tajoWorkerRpcClient.allocateTasks(callFuture.getController(), requestProto.build(), callFuture);

            // Wait for the worker's answer to this allocation.
            BatchAllocationResponse responseProto = callFuture.get();

            if(responseProto.getCancellationTaskCount() > 0) {
              // The worker returned tasks in its cancellation list: cancel them and do not
              // count an assignment.
              cancellation += cancel(responseProto.getCancellationTaskList());
              info(LOG, "Canceled requests: " + responseProto.getCancellationTaskCount() + " from " +  addr);
              continue;
            }

          } catch (ExecutionException | ConnectException e) {
            // The call failed: cancel everything that was in the request.
            cancellation += cancel(requestProto.getTaskRequestList());
            warn(LOG, "Canceled requests: " + requestProto.getTaskRequestCount()
                + " by " + ExceptionUtils.getFullStackTrace(e));
            continue;
          } catch (InterruptedException e) {
            throw e;
          } catch (Exception e) {
            throw new TajoInternalError(e);
          }

          totalAssigned++;
          scheduledObjectNum--;
        }
      }
    }
  }
}