from .grammar import ProcessAsmLine, grammar, GenCode, ctrl_re, pred_re
from itertools import accumulate
import re

def StripSpace(file):
  '''Pre-clean assembly text before it is split into lines.

  Three regex passes run in order over the whole text:
    1. each pair of consecutive newlines becomes a single newline
       (one non-overlapping pass, so longer runs are shortened, not
       fully collapsed);
    2. everything from a '#' to the end of its line is removed;
    3. whatever follows the first ';' on a line is removed.

  Returns the cleaned text.
  '''
  passes = (
    (r'\n\n', r'\n'),   # Blank-line pairs.
    (r'#.*', ''),       # Comments.
    (r'(?<=;).*', ''),  # Trailing text after the ';' terminator.
  )
  for pattern, replacement in passes:
    file = re.sub(pattern, replacement, file)
  return file

def Assemble(file, include=None):
  '''Assemble preprocessed assembly text into encoded instructions.

  Every non-empty line left after StripSpace and <SMEM> extraction must be
  either an instruction accepted by ProcessAsmLine or a label of the form
  `name:`. A `BRA -0x10;` is appended after the last instruction and the
  listing is padded with NOPs to a multiple of 8 instructions. The label
  operand of every BRA found in the input is then rewritten to a relative
  byte offset, and each instruction is matched against the grammar rules of
  its opcode and encoded with GenCode.

  Args:
    file: Assembly text.
    include: Not used by this function.

  Returns:
    A dict with the keys
      'RegCnt'     : highest register index seen in rd/rs0/rs1/rs2 plus one,
                     never below 8, and increased by 2 when above 8.
      'BarCnt'     : highest barrier index used by a BAR instruction plus one.
      'SmemSize'   : integer found in the <SMEM> tag, 0 when there is none.
      'ConstSize'  : always 0.
      'ExitOffset' : list of byte offsets (instruction index * 16) of EXITs.
      'KernelData' : list of GenCode results, one per instruction.

  Raises:
    KeyError: a BRA refers to a label that is never defined.
    Exception: a line is neither an instruction nor a label, an instruction
      matches none of its grammar rules, or a barrier index is 15 or more.
  '''
  # After preprocess, for each line in the file:
  # 1. ProcessAsmLine
  #    Parse line to get: {ctrl}, {pred}, {op}, rest
  # 2. Match op(flags) & operands against the grammar rules.
  # 3. Generate binary code.
  #    Op | Flags | Operands
  file = StripSpace(file)
  num_registers = 8  # The reported register count never drops below 8.
  num_barriers  = 0
  const_size    = 0  # Never updated here; returned as-is.
  exit_offsets  = []
  labels = {}        # Label name => index of the instruction that follows it.
  branches = []      # BRA instructions whose label operand is resolved later.
  line_num = 0       # Index of the next instruction; labels do not advance it.

  def GetSmemSize(file):
    # Read the integer from the first <SMEM>...</SMEM> tag and strip all such tags.
    smem_re = re.compile(r'^[\t ]*<SMEM>(.*?)\s*</SMEM>\n?', re.S | re.M | re.IGNORECASE)
    match = smem_re.search(file)
    if match:
      return smem_re.sub(r'', file), int(match.group(1))
    return file, 0
  file, smem_size = GetSmemSize(file)

  instructions = []
  for file_line_num, line in enumerate(file.split('\n')): # TODO:
    if line == '':
      continue
    line_result = ProcessAsmLine(line, line_num)
    if line_result:
      # Push instruction data to list
      instructions.append(line_result)
      if line_result['op'] == 'BRA':
        branches.append(line_result)
      if line_result['op'] == 'EXIT':
        exit_offsets.append(line_num * 16)
      line_num += 1
      continue
    # Not an instruction: the only other thing allowed is a label.
    # TODO: Move this to preprocess.
    label_result = re.match(r'(^[a-zA-Z]\w*):', line)
    if label_result:
      labels[label_result.group(1)] = line_num
    else:
      print(line)
      raise Exception(f'Cannot recogonize {line} at line{file_line_num}.\n')

  # Append the tail BRA. -0x10 is the offset the label remapping below
  # produces for a branch that targets its own instruction.
  instructions.append(ProcessAsmLine('--:-:-:Y:0  BRA -0x10;', len(instructions)+1))

  # Append NOPs to satisfy 128-bytes align (8 instructions of 16 bytes).
  while len(instructions) % 8 != 0:
    instructions.append(ProcessAsmLine('--:-:-:Y:0  NOP;', len(instructions)+1))

  # Remap labels: replace each BRA label operand by a relative byte offset.
  for bra_instr in branches:
    # Strip on both sides so that `BRA LOOP ;` resolves like `BRA LOOP;`.
    label = bra_instr['rest'].split(';')[0].strip()
    if label not in labels:
      bra_index = bra_instr['line_num']
      raise KeyError(f'Undefined label {label!r} used by BRA at instruction {bra_index}.')
    relative_offset = (labels[label] - bra_instr['line_num'] - 1) * 0x10
    bra_instr['rest'] = ' ' + hex(relative_offset) + ';'

  # Parse instructions.
  # Generate binary code. And insert to the instructions list.
  codes = []
  for instr in instructions:
    op = instr['op']
    rest = instr['rest']
    # Reset per instruction so that a match from a previous instruction can
    # never be reused when no rule is tried or none matches.
    result = None
    c_gram = None
    gram = None
    for gram in grammar[op]:
      result = re.match(gram['rule'], op + rest)
      if result is not None:
        c_gram = gram # Current grammar.
        break
    if result is None:
      print(repr(gram))
      raise Exception(f'Cannot recognize instruction {op+rest}')

    operands = result.groupdict()

    # FIXME (JO): Not all instructions use only 1 register. This part did not take that into account.

    # Update register count
    for reg in ['rd', 'rs0', 'rs1', 'rs2']:
      reg_data = operands.get(reg)
      if reg_data is None or reg_data == 'RZ':
        continue
      reg_idx = int(reg_data[1:])
      num_registers = max(num_registers, reg_idx + 1)

    # Update barrier count.
    if op == 'BAR':
      barrier_idx = int(operands['ibar'], 0)
      if barrier_idx >= 0xf:
        # TODO: Add line number here.
        raise Exception(f'Barrier index must be smaller than 15. {barrier_idx} found.')
      num_barriers = max(num_barriers, barrier_idx + 1)

    codes.append(GenCode(op, c_gram, operands, instr))

  # TODO: For some reasons, we need larger register count.
  if num_registers > 8:
    num_registers += 2

  return {
    'RegCnt'     : num_registers,
    'BarCnt'     : num_barriers,
    'SmemSize'   : smem_size,
    'ConstSize'  : const_size,
    'ExitOffset' : exit_offsets,
    'KernelData' : codes
  }
    
# Tag blocks recognised by the preprocessing helpers below. Tag names are
# matched case-insensitively and the opening tag must start its line.
register_map_re = re.compile(r'^[\t ]*<REGS>(.*?)\s*</REGS>\n?', re.S | re.M | re.IGNORECASE)
parameter_map_re = re.compile(r'^[\t ]*<PARAMS>(.*?)^\s*</PARAMS>\n?', re.S | re.M | re.IGNORECASE)
constant_map_re = re.compile(r'^[\t ]*<CONSTS>(.*?)^\s*</CONSTS>\n?', re.S | re.M | re.IGNORECASE)
def SetRegisterMap(file):
  '''Parse every <REGS> block and remove the blocks from the text.

  Each non-empty line of a block has the form

      <indices> : <names>      exactly one name per index
      <indices> ~ <names>      at least as many indices as names

  <indices> is a comma-separated list of numbers or ranges (`0-63`).
  <names> is a comma-separated list of identifiers or expansions
  (`a<0-3>x` yields a0x, a1x, a2x, a3x). Names are paired with indices in
  order. `#` starts a comment and whitespace is ignored.

  Returns:
    (file, reg_map): the text without <REGS> blocks, and a dict mapping
    each register name to its index as an int.

  Raises:
    ValueError: a line does not contain exactly one `:` or `~` separator.
    Exception: an index or name token is malformed, the counts do not
      fit the separator used, or a name is defined twice. Line numbers in
      the messages count from the line holding the opening tag.
  '''
  reg_map = {}
  for block in register_map_re.findall(file):
    for line_num, line in enumerate(block.split('\n')):
      # Strip comments, then every whitespace character.
      line = re.sub(r'#.*', '', line)
      line = re.sub(r'\s*', '', line)
      if line == '':
        continue

      fields = re.split('[:~]', line)
      if len(fields) != 2:
        raise ValueError(f'Expected "<indices> : <names>" or "<indices> ~ <names>" at line {line_num+1}.\n')
      idx_field, name_field = fields
      auto = '~' in line

      idx_list = []
      for item in idx_field.split(','):
        idx_range = re.fullmatch(r'(\d+)-(\d+)', item)
        if re.fullmatch(r'\d+', item):
          # Stored as int so single indices and range indices share a type.
          idx_list.append(int(item))
        elif idx_range: # Support for range. E.g., 0-63
          idx_list.extend(range(int(idx_range[1]), int(idx_range[2]) + 1))
        elif item != '':
          raise Exception(f'Invalid register index {item}, at line {line_num+1}.\n')

      name_list = []
      for item in name_field.split(','):
        expansion = re.fullmatch(r'(\w+)<(\d+)-(\d+)>(\w*)', item)
        if re.fullmatch(r'\w+', item):
          name_list.append(item)
        elif expansion:
          prefix, start, end, suffix = expansion.groups()
          for i in range(int(start), int(end) + 1):
            name_list.append(prefix + str(i) + suffix)
        elif item != '':
          raise Exception(f'Invalid register name {item}, at line {line_num+1}.\n')

      if len(idx_list) != len(name_list) and not auto:
        raise Exception(f'Number of registers != number of register names at line {line_num+1}.\n')
      elif len(idx_list) < len(name_list) and auto:
        raise Exception(f'Number of registers < number of register names at line {line_num+1}.\n')
      # zip stops at the last name; with `~` the surplus indices stay unused.
      for name, idx in zip(name_list, idx_list):
        if name in reg_map:
          raise Exception(f'Register name {name} already defined at line {line_num+1}.\n')
        reg_map[name] = idx

  # Replace <regs> with ''
  file = register_map_re.sub('', file)

  return file, reg_map
def SetParameterMap(file):
  '''Parse every <PARAMS> block and remove the blocks from the text.

  <PARAMS>
  input,  8
  output, 8
  </PARAMS>

  Each non-empty line is `name, size`; `#` starts a comment and whitespace
  is ignored. The size must be an integer multiple of 4.

  Returns:
    (file, param_dict): the text without <PARAMS> blocks, and a dict with
    two parallel lists, 'name_list' and 'size_list', in declaration order.

  Raises:
    ValueError: a line is not `name, size` or the size is not an integer.
    Exception: a name is repeated, a name is not made of word characters
      only, or a size is not a multiple of 4.
  '''
  name_list = []
  size_list = []
  # Cannot use dict. Order information is needed.
  param_dict = {'name_list' : name_list, 'size_list' : size_list}
  for block in parameter_map_re.findall(file):
    for line_num, line in enumerate(block.split('\n')):
      # Remove comments and whitespace.
      line = re.sub(r'#.*', '', line)
      line = re.sub(r'\s*', '', line)
      if line == '':
        continue
      fields = line.split(',')
      if len(fields) != 2:
        raise ValueError(f'Expected "name, size" at line {line_num+1}, found {line}.\n')
      name, size = fields
      if name in name_list:
        raise Exception(f'Parameter name {name} already defined.\n')
      # fullmatch: the whole name must be word characters, not just a prefix.
      if not re.fullmatch(r'\w+', name):
        raise Exception(f'Invalid parameter name {name}, at line {line_num+1}.\n')
      try:
        size = int(size)
      except ValueError:
        raise ValueError(f'Size of parameter {name} is not an integer, at line {line_num+1}.\n') from None
      if size % 4 != 0:
        raise Exception(f'Size of parameter {name} is not a multiplication of 4. Not supported.\n')
      name_list.append(name)
      size_list.append(size)

  # Delete parameter text.
  file = parameter_map_re.sub('', file)

  return file, param_dict

def SetConstsMap(file):
  '''Parse every <CONSTS> block and remove the blocks from the text.

  <CONSTS>
  CONST_A,  8
  CONST_B, 8
  </CONSTS>

  Each non-empty line is `name, size`; `#` starts a comment and whitespace
  is ignored. The size must be an integer multiple of 4.

  Returns:
    (file, const_dict): the text without <CONSTS> blocks, and a dict with
    two parallel lists, 'name_list' and 'size_list', in declaration order.

  Raises:
    ValueError: a line is not `name, size` or the size is not an integer.
    Exception: a name is repeated, a name is not made of word characters
      only, or a size is not a multiple of 4.
  '''
  name_list = []
  size_list = []
  # Cannot use dict. Order information is needed.
  const_dict = {'name_list' : name_list, 'size_list' : size_list}
  for block in constant_map_re.findall(file):
    for line_num, line in enumerate(block.split('\n')):
      # Remove comments and whitespace.
      line = re.sub(r'#.*', '', line)
      line = re.sub(r'\s*', '', line)
      if line == '':
        continue
      fields = line.split(',')
      if len(fields) != 2:
        raise ValueError(f'Expected "name, size" at line {line_num+1}, found {line}.\n')
      name, size = fields
      if name in name_list:
        raise Exception(f'Constant name {name} already defined.\n')
      # fullmatch: the whole name must be word characters, not just a prefix.
      if not re.fullmatch(r'\w+', name):
        raise Exception(f'Invalid Constant name {name}, at line {line_num+1}.\n')
      try:
        size = int(size)
      except ValueError:
        raise ValueError(f'Size of Constant {name} is not an integer, at line {line_num+1}.\n') from None
      if size % 4 != 0:
        raise Exception(f'Size of Constant {name} is not a multiplication of 4. Not supported.\n')
      name_list.append(name)
      size_list.append(size)

  # Delete the <CONSTS> text.
  file = constant_map_re.sub('', file)

  return file, const_dict

def GetParameterConstant(var_name, var_dict, bank, base, offset=0):
  '''Return the `c[bank][address]` operand text for a named variable.

  The address is `base`, plus the sizes of all variables declared before
  `var_name`, plus `offset * 4`.

  Args:
    var_name: Name to look up in var_dict['name_list'].
    var_dict: Dict with parallel 'name_list' and 'size_list' entries.
    bank: Number printed in the first bracket.
    base: Address of the first declared variable.
    offset: Element index inside the variable, in units of 4 bytes.

  Raises:
    ValueError: var_name is not in var_dict['name_list'].
    Exception: the element selected by `offset` lies outside the variable.
  '''
  sizes = var_dict['size_list']
  index = var_dict['name_list'].index(var_name) # Use .index() is safe here. Elements are unique.
  size = sizes[index]

  # The element occupies bytes [offset*4, offset*4 + 4); offset*4 == size is
  # already one element past the end, so it is rejected as well.
  if offset < 0 or offset * 4 >= size:
    raise Exception(f'Parameter {var_name} is of size {size}. Cannot have offset {offset}.')

  # FIXME: Currently we assume elements of all arrays are 4 Bytes in size.
  start = sum(sizes[:index])

  return 'c[0x{:x}][0x{:x}]'.format(bank, base + start + offset * 4)

# Replace register and parameter.
def ReplaceRegParamConstMap(file, reg_map, param_dict, const_dict):
  '''Substitute named registers, parameters, constants and built-ins.

  Identifiers (optionally followed by `[n]`) are looked up in this order:
    - a key of `grammar`: left untouched;
    - a register name: replaced by `R<index>`;
    - a parameter name: replaced by its `c[0x0][...]` operand, base 0x160;
    - a constant name: replaced by its `c[0x3][...]` operand, base 0x0;
    - anything else: left untouched, including any `[n]` suffix.
  Afterwards blockDim.x/y/z and gridDim.x/y/z are replaced by fixed
  `c[0x0][...]` operands.

  Args:
    file: Assembly text.
    reg_map: Dict of register name => index.
    param_dict, const_dict: Dicts with 'name_list' and 'size_list'.

  Returns:
    The text with all substitutions applied.

  Raises:
    Exception: a register name is also a parameter or constant name.
  '''
  for key in reg_map:
    if key in param_dict['name_list']:
      raise Exception(f'Name {key} defined both in register and parameters.\n')
    if key in const_dict['name_list']:
      raise Exception(f'Name {key} defined both in register and constants.\n')

  # An identifier that is not preceded by '.', optionally followed by a
  # decimal `[n]` index, and not followed by `[0x` (already a c[..][..] operand).
  var_re = re.compile(r'(?<!(?:\.))\b([a-zA-Z_]\w*)(?:\[(\d+)\]|\b)(?!\[0x)')

  def ReplaceVar(match, regs, params, consts):
    var = match.group(1)
    index_text = match.group(2)
    offset = int(index_text) if index_text is not None else 0

    if var in grammar:
      return match.group(0)
    if var in regs:
      # NOTE(review): an `[n]` suffix after a register name is discarded
      # here; confirm that this is intended.
      return 'R' + str(regs[var])
    if var in params['name_list']:
      return GetParameterConstant(var, params, 0, 0x160, offset)
    if var in consts['name_list']:
      return GetParameterConstant(var, consts, 3, 0x0, offset)
    # TODO: Or not to allow use RX in the code and raise exeception here.
    # In case of R0-R255, RZ, PR: hand back the matched text unchanged, so
    # that an `[n]` suffix on an unknown name is not lost.
    return match.group(0)
  # Match rest first.
  file = var_re.sub(lambda match : ReplaceVar(match, reg_map, param_dict, const_dict), file)

  # Replace interior constant map
  constants = {
    'blockDim.x' : 'c[0x0][0x0]',
    'blockDim.y' : 'c[0x0][0x4]',
    'blockDim.z' : 'c[0x0][0x8]',
    'gridDim.x'  : 'c[0x0][0xc]',
    'gridDim.y'  : 'c[0x0][0x10]',
    'gridDim.z'  : 'c[0x0][0x14]'
  }
  # Escape the names: an unescaped '.' matches any character, which let text
  # such as `blockDim_x` match and then fail the dictionary lookup.
  const_re = re.compile('(' + '|'.join(re.escape(name) for name in constants) + ')')
  def ReplaceInteriorConst(match):
    return constants[match.group(1)]
  file = const_re.sub(ReplaceInteriorConst, file)

  return file
    
code_re = re.compile(r"^[\t ]*<CODE>(.*?)^\s*<\/CODE>\n?", re.MULTILINE|re.DOTALL|re.IGNORECASE)
def ExpandCode(file, include=None): # TODO: Better way to do this.
  '''Replace every <CODE>...</CODE> block by the text it generates.

  The body of each block is executed as Python in this module's global
  namespace, so names defined by include files or by earlier blocks stay
  visible to later ones. A block must assign its output text to the global
  `out_`; that value replaces the block, including the newline after the
  closing tag.

  SECURITY: include files and block bodies are run with exec(). Only use
  this on sources you trust.

  Args:
    file: Source text.
    include: Optional iterable of paths to Python files executed first.

  Returns:
    The text with every <CODE> block expanded.

  Raises:
    NameError: a block finishes without assigning `out_`.
  '''
  # Execute include files.
  if include is not None:
    for include_file in include:
      with open(include_file, 'r') as f:
        source = f.read()
      exec(source, globals())
  # Execute <CODE> block.
  def ReplaceCode(matchobj):
    module_globals = globals()
    # Forget the previous block's output first; otherwise a block that does
    # not assign `out_` would silently repeat the previous block's text.
    module_globals.pop('out_', None)
    exec(matchobj.group(1), module_globals)
    if 'out_' not in module_globals:
      raise NameError("<CODE> block did not assign 'out_'")
    return module_globals['out_']
  return code_re.sub(ReplaceCode, file)

# One `{expression}` on a single line. The expression may itself contain one
# level of balanced braces (e.g. a dict literal). Each brace group is matched
# separately, so several inline expressions can share a line.
inline_re = re.compile(r'\{((?:[^{}\n]|\{[^{}\n]*\})*)\}', re.M)
def ExpandInline(file, include=None):
  '''Replace every inline `{expression}` by str() of its value.

  Expressions are evaluated as Python in this module's global namespace,
  after the optional include files have been executed there.

  SECURITY: include files are run with exec() and expressions with eval().
  Only use this on sources you trust.

  Args:
    file: Source text.
    include: Optional iterable of paths to Python files executed first.

  Returns:
    The text with every inline expression replaced.
  '''
  # Execute include files.
  if include is not None:
    for include_file in include:
      with open(include_file, 'r') as f:
        source = f.read()
      exec(source, globals())
  def ReplaceCode(matchobj):
    return str(eval(matchobj.group(1), globals()))
  return inline_re.sub(ReplaceCode, file)


  

if __name__ == '__main__':
  # Sample listing kept for manual experiments. This branch only binds the
  # string; nothing is assembled or printed.
  input_str = '''--:-:-:-:2    MOV R0, c[0x0][0x160];
--:-:-:-:2    MOV R1, c[0x0][0x164];
--:-:-:-:2    MOV R2, c[0x0][0x168];
--:-:-:-:5    MOV R3, c[0x0][0x16c];
--:-:-:-:2    STG.E.SYS [R0], R0;
--:-:-:-:2    STG.E.SYS [R0+4], R1;
--:-:-:-:2    STG.E.SYS [R2], R2;
--:-:-:-:2    STG.E.SYS [R2+4], R3;
--:-:-:-:2    EXIT;'''
  # NOTE(review): the disabled call below passes a single argument, while
  # ReplaceRegParamConstMap takes (file, reg_map, param_dict, const_dict).
  # ReplaceRegParamConstMap(input_str)