package exm.stc.ic.opt;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;

import org.apache.log4j.Logger;

import exm.stc.common.Settings;
import exm.stc.common.lang.Arg;
import exm.stc.common.lang.FnID;
import exm.stc.common.lang.ForeignFunctions;
import exm.stc.common.lang.Semantics;
import exm.stc.common.lang.TaskProp.TaskProps;
import exm.stc.common.lang.Types;
import exm.stc.common.lang.Types.Type;
import exm.stc.common.lang.Var;
import exm.stc.common.lang.Var.Alloc;
import exm.stc.common.lang.Var.DefType;
import exm.stc.common.lang.Var.VarProvenance;
import exm.stc.common.lang.WaitVar;
import exm.stc.common.util.Pair;
import exm.stc.ic.WrapUtil;
import exm.stc.ic.tree.ICInstructions.FunctionCall;
import exm.stc.ic.tree.ICInstructions.Instruction;
import exm.stc.ic.tree.ICTree.Block;
import exm.stc.ic.tree.ICTree.Function;
import exm.stc.ic.tree.ICTree.Program;
import exm.stc.ic.tree.ICTree.RenameMode;
import exm.stc.ic.tree.TurbineOp;

/**
 * Optimize function signature
 */
public class FunctionSignature implements OptimizerPass {

  @Override
  public String getPassName() {
    return "Function signature changing";
  }

  @Override
  public String getConfigEnabledKey() {
    return Settings.OPT_FUNCTION_SIGNATURE;
  }


  /**
   * Switch to passing values directly.
   * Do this before function inlining since function inlining
   * will clean it up.
   * @param logger
   * @param program
   */
  @Override
  public void optimize(Logger logger, Program program) {
    Set<FnID> usedFnIDs = new HashSet<FnID>(program.getFunctionMap().keySet());
    Map<FnID, Function> toInline = new HashMap<FnID, Function>();
    ListIterator<Function> fnIt = program.functionIterator();
    while (fnIt.hasNext()) {
      Function fn = fnIt.next();
      Function newFn = switchToValuePassing(logger, program.foreignFunctions(),
                                            fn, usedFnIDs);
      if (newFn != null) {
        fnIt.remove(); // Remove old function
        fnIt.add(newFn);
        usedFnIDs.add(newFn.id());

        // We should inline
        toInline.put(fn.id(), fn);
      }
    }

    // Inline all calls to the old function
    FunctionInline.inlineAllOccurrences(logger, program, toInline);
  }

  private Function switchToValuePassing(Logger logger, ForeignFunctions foreignFuncs,
            Function fn, Set<FnID> usedFnIDs) {
    if (fn.blockingInputs().isEmpty())
      return null;

    // Collect list of variables we could switch
    List<Var> switchVars = new ArrayList<Var>();

    for (WaitVar input: fn.blockingInputs()) {
      // See if we can switch to value version
      if (Types.isPrimFuture(input.var)) {
        Type valueT = Types.retrievedType(input.var.type());
        if (Semantics.canPassToChildTask(valueT)) {
          switchVars.add(input.var);
        }
      }
    }

    if (switchVars.isEmpty())
      return null;

    List<Pair<Var, Var>> futValPairs = createValueVars(fn, switchVars);
    Map<Var, Var> switched = new HashMap<Var, Var>();
    for (Pair<Var, Var> fv: futValPairs) {
      switched.put(fv.val1, fv.val2);
      assert(fv.val2 != null);
    }
    List<Var> newIList = buildNewInputList(fn, switched);
    FnID newID = selectUniqueID(fn.id(), usedFnIDs);

    // Block that calls into new version
    Block callNewFunction = callNewFunctionCode(foreignFuncs, fn, newID,
                                                switchVars);
    Block newBlock = fn.swapBlock(callNewFunction);


    // Declare variables in new block and load values
    // Other optimization passes will clear up later
    for (Pair<Var, Var> fv: futValPairs) {
      // declare local stack var and replace argument in
      Var tmpfuture = new Var(fv.val1.type(), fv.val1.name(),
                       Alloc.STACK, DefType.LOCAL_USER,
                       VarProvenance.renamed(fv.val1));
      newBlock.renameVars(fn.id(),
            Collections.singletonMap(fv.val1, tmpfuture.asArg()),
            RenameMode.REPLACE_VAR, true);
      newBlock.addVariable(tmpfuture);
      Instruction store =
          TurbineOp.storePrim(tmpfuture, fv.val2.asArg());
      newBlock.addInstructionFront(store);
    }

    List<WaitVar> newBlocking = new ArrayList<WaitVar>();
    for (WaitVar wv: fn.blockingInputs()) {
      if (!switchVars.contains(wv.var)) {
        newBlocking.add(wv);
      }
    }
    return new Function(newID, newIList, newBlocking,
                        fn.getOutputList(), fn.mode(), newBlock);
  }

  /**
   *
   * @return new main block for fn
   */
  private Block callNewFunctionCode(ForeignFunctions foreignFuncs,
                              Function fn, FnID newFunctionID,
                                    List<Var> switched) {
    Block main = new Block(fn);
    // these vars should already be closed.
    // load values and call new function

    List<Arg> fetched = OptUtil.fetchValuesOf(main, switched, false, false);

    List<Arg> callInputs = new ArrayList<Arg>(fn.getInputList().size());
    for (Var inArg: fn.getInputList()) {
      int ix = switched.indexOf(inArg);
      if (ix >= 0) {
        callInputs.add(fetched.get(ix));
      } else {
        callInputs.add(inArg.asArg());
      }
    }

    FunctionCall callNew = FunctionCall.createFunctionCall(newFunctionID,
                            fn.getOutputList(), callInputs, fn.mode(),
                            new TaskProps(), foreignFuncs);
    main.addInstruction(callNew);
    return main;
  }

  private List<Var> buildNewInputList(Function fn, Map<Var, Var> switched) {
    List<Var> newIList = new ArrayList<Var>();
    for (Var oldInput: fn.getInputList()) {
      if (switched.containsKey(oldInput)) {
        newIList.add(switched.get(oldInput));
      } else {
        newIList.add(oldInput);
      }
    }
    return newIList;
  }

  private List<Pair<Var, Var>> createValueVars(Function fn, List<Var> switchVars) {
    List<Pair<Var, Var>> futValPairs = new ArrayList<Pair<Var,Var>>();
    // Create value vars
    for (Var toSwitch: switchVars) {
      // a value var that will have unique name in new context
      String valVarName = OptUtil.optVPrefix(fn.mainBlock(), toSwitch);
      Var valVar = WrapUtil.createValueVar(valVarName,
          Types.retrievedType(toSwitch), toSwitch);

      futValPairs.add(Pair.create(toSwitch, valVar));
    }
    return futValPairs;
  }

  private FnID selectUniqueID(FnID id, Set<FnID> used) {
    int nameCounter = 1;
    String prefix = id.uniqueName();
    String newName;
    FnID newID;
    do {
      newName = prefix + "-" + nameCounter;
      nameCounter++;

      newID = new FnID(newName, id.originalName());
    } while (used.contains(newID));
    return newID;
  }

}