import math
import tensorflow as tf
import util

import srl_ops


def flatten_emb(emb):
  """Flatten the sentence and word dimensions of a rank-2 or rank-3 tensor."""
  num_sentences = tf.shape(emb)[0]
  max_sentence_length = tf.shape(emb)[1]
  emb_rank = len(emb.get_shape())
  if emb_rank == 2:
    flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length])
  elif emb_rank == 3:
    flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length, util.shape(emb, 2)])
  else:
    raise ValueError("Unsupported rank: {}".format(emb_rank))
  return flattened_emb
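
# Example usage of flatten_emb (illustrative shapes): a [2, 5, 100] batch of
# sentence embeddings flattens to a word-major [10, 100] tensor.
#   emb = tf.zeros([2, 5, 100])
#   flat_emb = flatten_emb(emb)  # [2 * 5, 100]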


def flatten_emb_by_sentence(emb, text_len_mask):
  """Flatten embeddings, keeping only the positions where text_len_mask is True."""
  num_sentences = tf.shape(emb)[0]
  max_sentence_length = tf.shape(emb)[1]
  flattened_emb = flatten_emb(emb)
  return tf.boolean_mask(flattened_emb,
                         tf.reshape(text_len_mask, [num_sentences * max_sentence_length]))
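
# Example (illustrative): with text_len_mask = tf.sequence_mask([2, 1], 3),
# a [2, 3, emb] tensor flattens to [3, emb], keeping only the unpadded words
# of each sentence.
#   emb = tf.zeros([2, 3, 100])
#   mask = tf.sequence_mask([2, 1], 3)  # [2, 3] boolean
#   flat = flatten_emb_by_sentence(emb, mask)  # [3, 100]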


def batch_gather(emb, indices):
  """Gather embeddings at the given per-sentence indices.
  Args:
    emb: Shape of [num_sentences, max_sentence_length, (emb)]
    indices: Shape of [num_sentences, k, (l)]
  """
  # TODO: Merge with util.batch_gather.
  num_sentences = tf.shape(emb)[0] 
  max_sentence_length = tf.shape(emb)[1] 
  flattened_emb = flatten_emb(emb)  # [num_sentences * max_sentence_length, emb]
  offset = tf.expand_dims(tf.range(num_sentences) * max_sentence_length, 1)  # [num_sentences, 1]
  if len(indices.get_shape()) == 3:
    offset = tf.expand_dims(offset, 2)  # [num_sentences, 1, 1]
  return tf.gather(flattened_emb, indices + offset) 
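
# Example (illustrative): with num_sentences = 2 and max_sentence_length = 5,
# index 3 in sentence 1 maps to 1 * 5 + 3 = 8 in the flattened embedding, so
# per-sentence indices can be gathered from a single flat tensor.
#   emb = tf.zeros([2, 5, 100])
#   indices = tf.constant([[0, 3], [1, 4]])
#   gathered = batch_gather(emb, indices)  # [2, 2, 100]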
  

def lstm_contextualize(text_emb, text_len, config, lstm_dropout):
  """Contextualize word embeddings with stacked highway BiLSTMs."""
  num_sentences = tf.shape(text_emb)[0]
  current_inputs = text_emb  # [num_sentences, max_sentence_length, emb]
  for layer in range(config["contextualization_layers"]):
    with tf.variable_scope("layer_{}".format(layer)):
      with tf.variable_scope("fw_cell"):
        cell_fw = util.CustomLSTMCell(config["contextualization_size"], num_sentences, lstm_dropout)
      with tf.variable_scope("bw_cell"):
        cell_bw = util.CustomLSTMCell(config["contextualization_size"], num_sentences, lstm_dropout)
      state_fw = tf.contrib.rnn.LSTMStateTuple(tf.tile(cell_fw.initial_state.c, [num_sentences, 1]),
                                               tf.tile(cell_fw.initial_state.h, [num_sentences, 1]))
      state_bw = tf.contrib.rnn.LSTMStateTuple(tf.tile(cell_bw.initial_state.c, [num_sentences, 1]),
                                               tf.tile(cell_bw.initial_state.h, [num_sentences, 1]))
      (fw_outputs, bw_outputs), _ = tf.nn.bidirectional_dynamic_rnn(
          cell_fw=cell_fw,
          cell_bw=cell_bw,
          inputs=current_inputs,
          sequence_length=text_len,
          initial_state_fw=state_fw,
          initial_state_bw=state_bw)
      text_outputs = tf.concat([fw_outputs, bw_outputs], 2)  # [num_sentences, max_sentence_length, emb]
      text_outputs = tf.nn.dropout(text_outputs, lstm_dropout)
      if layer > 0:
        highway_gates = tf.sigmoid(util.projection(
            text_outputs, util.shape(text_outputs, 2)))  # [num_sentences, max_sentence_length, emb]
        text_outputs = highway_gates * text_outputs + (1 - highway_gates) * current_inputs
      current_inputs = text_outputs

  return text_outputs  # [num_sentences, max_sentence_length, emb]
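
# Example usage (hypothetical config values):
#   config = {"contextualization_layers": 2, "contextualization_size": 200}
#   outputs = lstm_contextualize(text_emb, text_len, config, lstm_dropout)
#   # outputs: [num_sentences, max_sentence_length, 2 * 200]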


def get_span_candidates(text_len, max_sentence_length, max_mention_width):
  """Get a list of candidate spans up to length W.
  Args:
    text_len: Tensor of [num_sentences,]
    max_sentence_length: Integer scalar.
    max_mention_width: Integer.
  """
  num_sentences = util.shape(text_len, 0)
  candidate_starts = tf.tile(
      tf.expand_dims(tf.expand_dims(tf.range(max_sentence_length), 0), 1),
      [num_sentences, max_mention_width, 1])  # [num_sentences, max_mention_width, max_sentence_length]
  candidate_widths = tf.expand_dims(tf.expand_dims(tf.range(max_mention_width), 0), 2)  # [1, max_mention_width, 1]
  candidate_ends = candidate_starts + candidate_widths  # [num_sentences, max_mention_width, max_sentence_length]
  
  candidate_starts = tf.reshape(candidate_starts, [num_sentences, max_mention_width * max_sentence_length])
  candidate_ends = tf.reshape(candidate_ends, [num_sentences, max_mention_width * max_sentence_length])
  candidate_mask = tf.less(
      candidate_ends,
      tf.tile(tf.expand_dims(text_len, 1), [1, max_mention_width * max_sentence_length])
  )  # [num_sentences, max_mention_width * max_sentence_length]

  # Mask to avoid indexing error.
  candidate_starts = tf.multiply(candidate_starts, tf.to_int32(candidate_mask))
  candidate_ends = tf.multiply(candidate_ends, tf.to_int32(candidate_mask))
  return candidate_starts, candidate_ends, candidate_mask  
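
# Example (illustrative): with text_len = [3], max_sentence_length = 3 and
# max_mention_width = 2, the flattened candidates are
#   starts: [0, 1, 2, 0, 1, 2], ends: [0, 1, 2, 1, 2, 3]
# and the mask is [T, T, T, T, T, F]; the out-of-bounds span (2, 3) is zeroed
# to (0, 0) so it can be gathered safely.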


def get_span_emb(head_emb, context_outputs, span_starts, span_ends, config, dropout):
  """Compute span representation shared across tasks.
  Args:
    head_emb: Tensor of [num_words, emb]
    context_outputs: Tensor of [num_words, emb]
    span_starts: [num_spans]
    span_ends: [num_spans]
  """
  text_length = util.shape(context_outputs, 0)
  num_spans = util.shape(span_starts, 0)

  span_start_emb = tf.gather(context_outputs, span_starts)  # [num_spans, emb]
  span_end_emb = tf.gather(context_outputs, span_ends)  # [num_spans, emb]
  span_emb_list = [span_start_emb, span_end_emb]

  span_width = 1 + span_ends - span_starts  # [num_spans]
  max_arg_width = config["max_arg_width"]
  num_heads = config["num_attention_heads"]

  if config["use_features"]:
    span_width_index = span_width - 1  # [num_spans]
    span_width_emb = tf.gather(
        tf.get_variable("span_width_embeddings", [max_arg_width, config["feature_size"]]),
        span_width_index)  # [num_spans, emb]
    span_width_emb = tf.nn.dropout(span_width_emb, dropout)
    span_emb_list.append(span_width_emb)

  head_scores = None
  span_text_emb = None
  span_indices = None
  span_indices_log_mask = None

  if config["model_heads"]:
    span_indices = tf.minimum(
        tf.expand_dims(tf.range(max_arg_width), 0) + tf.expand_dims(span_starts, 1),
        text_length - 1)  # [num_spans, max_span_width]
    span_text_emb = tf.gather(head_emb, span_indices)  # [num_spans, max_arg_width, emb]
    span_indices_log_mask = tf.log(
        tf.sequence_mask(span_width, max_arg_width, dtype=tf.float32)) # [num_spans, max_arg_width]
    with tf.variable_scope("head_scores"):
      head_scores = util.projection(context_outputs, num_heads)  # [num_words, num_heads]
    span_attention = tf.nn.softmax(
      tf.gather(head_scores, span_indices) + tf.expand_dims(span_indices_log_mask, 2),
      dim=1)  # [num_spans, max_arg_width, num_heads]
    span_head_emb = tf.reduce_sum(span_attention * span_text_emb, 1)  # [num_spans, emb]
    span_emb_list.append(span_head_emb)

  span_emb = tf.concat(span_emb_list, 1) # [num_spans, emb]
  return span_emb, head_scores, span_text_emb, span_indices, span_indices_log_mask
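
# The concatenated span representation stacks the two boundary representations
# (2 * context emb), the span width embedding (feature_size, if use_features),
# and the attention-weighted head-word embedding (head emb size, if
# model_heads). Example usage (illustrative):
#   span_emb, head_scores, _, _, _ = get_span_emb(
#       head_emb, context_outputs, span_starts, span_ends, config, dropout)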


def get_unary_scores(span_emb, config, dropout, num_labels=1, name="span_scores"):
  """Compute span score with FFNN(span embedding).
  Args:
    span_emb: Tensor of [num_sentences, num_spans, emb] or [k, emb].
  """
  with tf.variable_scope(name):
    scores = util.ffnn(span_emb, config["ffnn_depth"], config["ffnn_size"], num_labels,
                       dropout)  # [num_sentences, num_spans, num_labels] or [k, num_labels]
  if num_labels == 1:
    scores = tf.squeeze(scores, -1)  # [num_sentences, num_spans] or [k]
  return scores


def get_srl_scores(arg_emb, pred_emb, arg_scores, pred_scores, num_labels, config, dropout):
  """Compute scores for every argument-predicate pair and SRL label.
  Args:
    arg_emb: [num_sentences, num_args, emb]
    pred_emb: [num_sentences, num_preds, emb]
    arg_scores: [num_sentences, num_args]
    pred_scores: [num_sentences, num_preds]
    num_labels: Integer scalar, including the null label.
  """
  num_sentences = util.shape(arg_emb, 0)
  num_args = util.shape(arg_emb, 1)
  num_preds = util.shape(pred_emb, 1)

  arg_emb_expanded = tf.expand_dims(arg_emb, 2)  # [num_sents, num_args, 1, emb]
  pred_emb_expanded = tf.expand_dims(pred_emb, 1)  # [num_sents, 1, num_preds, emb] 
  arg_emb_tiled = tf.tile(arg_emb_expanded, [1, 1, num_preds, 1])  # [num_sentences, num_args, num_preds, emb]
  pred_emb_tiled = tf.tile(pred_emb_expanded, [1, num_args, 1, 1])  # [num_sents, num_args, num_preds, emb]

  pair_emb_list = [arg_emb_tiled, pred_emb_tiled]
  pair_emb = tf.concat(pair_emb_list, 3)  # [num_sentences, num_args, num_preds, emb]
  pair_emb_size = util.shape(pair_emb, 3)
  flat_pair_emb = tf.reshape(pair_emb, [num_sentences * num_args * num_preds, pair_emb_size])

  flat_srl_scores = get_unary_scores(flat_pair_emb, config, dropout, num_labels - 1,
      "predicate_argument_scores")  # [num_sentences * num_args * num_preds, num_labels - 1]
  srl_scores = tf.reshape(flat_srl_scores, [num_sentences, num_args, num_preds, num_labels - 1])
  srl_scores += tf.expand_dims(tf.expand_dims(arg_scores, 2), 3) + tf.expand_dims(
      tf.expand_dims(pred_scores, 1), 3)  # [num_sentences, num_args, num_preds, num_labels - 1]
  
  dummy_scores = tf.zeros([num_sentences, num_args, num_preds, 1], tf.float32)
  srl_scores = tf.concat([dummy_scores, srl_scores], 3)  # [num_sentences, max_num_args, max_num_preds, num_labels] 
  return srl_scores  # [num_sentences, num_args, num_predicates, num_labels]
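
# Label 0 is the null relation: the FFNN scores only the num_labels - 1 real
# labels, and a fixed zero column is prepended so the null label always has
# score 0, acting as the softmax baseline for all argument-predicate pairs.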


def get_batch_topk(candidate_starts, candidate_ends, candidate_scores, topk_ratio, text_len,
                   max_sentence_length, sort_spans=False, enforce_non_crossing=True):
  """
  Args:
    candidate_starts: [num_sentences, max_num_candidates]
    candidate_mask: [num_sentences, max_num_candidates]
    topk_ratio: A float number.
    text_len: [num_sentences,]
    max_sentence_length:
    enforce_non_crossing: Use regular top-k op if set to False.
 """
  num_sentences = util.shape(candidate_starts, 0)
  max_num_candidates = util.shape(candidate_starts, 1)

  topk = tf.maximum(tf.to_int32(tf.floor(tf.to_float(text_len) * topk_ratio)),
                    tf.ones([num_sentences,], dtype=tf.int32))  # [num_sentences]

  predicted_indices = srl_ops.extract_spans(
      candidate_scores, candidate_starts, candidate_ends, topk, max_sentence_length,
      sort_spans, enforce_non_crossing)  # [num_sentences, max_num_predictions]
  predicted_indices.set_shape([None, None])

  predicted_starts = batch_gather(candidate_starts, predicted_indices)  # [num_sentences, max_num_predictions]
  predicted_ends = batch_gather(candidate_ends, predicted_indices)  # [num_sentences, max_num_predictions]
  predicted_scores = batch_gather(candidate_scores, predicted_indices)  # [num_sentences, max_num_predictions]

  return predicted_starts, predicted_ends, predicted_scores, topk, predicted_indices
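
# Example (illustrative): with topk_ratio = 0.8 and text_len = [10, 4],
# topk = [8, 3], so each sentence keeps a number of spans proportional to its
# length (and always at least one).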


def get_srl_labels(arg_starts, arg_ends, predicates, labels, max_sentence_length):
  """
  Args:
    arg_starts: [num_sentences, max_num_args]
    arg_ends: [num_sentences, max_num_args]
    predicates: [num_sentences, max_num_predicates]
    labels: Dictionary of label tensors.
    max_sentence_length: An integer scalar.
  """
  num_sentences = util.shape(arg_starts, 0)
  max_num_args = util.shape(arg_starts, 1)
  max_num_preds = util.shape(predicates, 1)
  sentence_indices_2d = tf.tile(
      tf.expand_dims(tf.expand_dims(tf.range(num_sentences), 1), 2),
      [1, max_num_args, max_num_preds])  # [num_sentences, max_num_args, max_num_preds]
  tiled_arg_starts = tf.tile(
      tf.expand_dims(arg_starts, 2),
      [1, 1, max_num_preds])  # [num_sentences, max_num_args, max_num_preds]
  tiled_arg_ends = tf.tile(
      tf.expand_dims(arg_ends, 2),
      [1, 1, max_num_preds])  # [num_sentences, max_num_args, max_num_preds]
  tiled_predicates = tf.tile(
      tf.expand_dims(predicates, 1),
      [1, max_num_args, 1])  # [num_sentences, max_num_args, max_num_preds]
  pred_indices = tf.concat([
      tf.expand_dims(sentence_indices_2d, 3),
      tf.expand_dims(tiled_arg_starts, 3),
      tf.expand_dims(tiled_arg_ends, 3),
      tf.expand_dims(tiled_predicates, 3)], axis=3)  # [num_sentences, max_num_args, max_num_preds, 4]
 
  dense_srl_labels = get_dense_span_labels(
      labels["arg_starts"], labels["arg_ends"], labels["arg_labels"], labels["srl_len"], max_sentence_length,
      span_parents=labels["predicates"])  # [num_sentences, max_sent_len, max_sent_len, max_sent_len]
 
  srl_labels = tf.gather_nd(params=dense_srl_labels, indices=pred_indices)  # [num_sentences, max_num_args, max_num_preds]
  return srl_labels


def get_dense_span_labels(span_starts, span_ends, span_labels, num_spans, max_sentence_length, span_parents=None):
  """Utility function to get dense span or span-head labels.
  Args:
    span_starts: [num_sentences, max_num_spans]
    span_ends: [num_sentences, max_num_spans]
    span_labels: [num_sentences, max_num_spans]
    num_spans: [num_sentences,]
    max_sentence_length: An integer scalar.
    span_parents: [num_sentences, max_num_spans]. Predicates in SRL.
  """
  num_sentences = util.shape(span_starts, 0)
  max_num_spans = util.shape(span_starts, 1)
  # Shift the starts of padded spans to 1 (their ends stay 0) so they cannot collide with any real span.
  span_starts += (1 - tf.sequence_mask(num_spans, dtype=tf.int32))  # [num_sentences, max_num_spans]
  sentence_indices = tf.tile(
      tf.expand_dims(tf.range(num_sentences), 1),
      [1, max_num_spans])  # [num_sentences, max_num_spans]
  sparse_indices = tf.concat([
      tf.expand_dims(sentence_indices, 2),
      tf.expand_dims(span_starts, 2),
      tf.expand_dims(span_ends, 2)], axis=2)  # [num_sentences, max_num_spans, 3]
  if span_parents is not None:
    sparse_indices = tf.concat([
      sparse_indices, tf.expand_dims(span_parents, 2)], axis=2)  # [num_sentences, max_num_spans, 4]

  rank = 3 if (span_parents is None) else 4
  # (sent_id, span_start, span_end) -> span_label
  dense_labels = tf.sparse_to_dense(
      sparse_indices=tf.reshape(sparse_indices, [num_sentences * max_num_spans, rank]),
      output_shape=[num_sentences] + [max_sentence_length] * (rank - 1),
      sparse_values=tf.reshape(span_labels, [-1]),
      default_value=0,
      validate_indices=False)  # [num_sentences] + [max_sentence_length] * (rank - 1)
  return dense_labels
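
# Example (illustrative): a gold span with start 2, end 4, and label 7 in
# sentence 0 sets dense_labels[0, 2, 4] = 7; with span_parents, a predicate at
# position p sets dense_labels[0, 2, 4, p] = 7. All other cells default to 0,
# so unlabeled argument-predicate pairs map to the null label.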
    

def get_srl_softmax_loss(srl_scores, srl_labels, num_predicted_args, num_predicted_preds):
  """Softmax loss with 2-D masking (for SRL).
  Args:
    srl_scores: [num_sentences, max_num_args, max_num_preds, num_labels]
    srl_labels: [num_sentences, max_num_args, max_num_preds]
    num_predicted_args: [num_sentences]
    num_predicted_preds: [num_sentences]
  """
  max_num_args = util.shape(srl_scores, 1)
  max_num_preds = util.shape(srl_scores, 2)
  num_labels = util.shape(srl_scores, 3)
  args_mask = tf.sequence_mask(num_predicted_args, max_num_args)  # [num_sentences, max_num_args]
  preds_mask = tf.sequence_mask(num_predicted_preds, max_num_preds)  # [num_sentences, max_num_preds]
  srl_loss_mask = tf.logical_and(
      tf.expand_dims(args_mask, 2),  # [num_sentences, max_num_args, 1]
      tf.expand_dims(preds_mask, 1)  # [num_sentences, 1, max_num_preds]
  )  # [num_sentences, max_num_args, max_num_preds]
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=tf.reshape(srl_labels, [-1]),
      logits=tf.reshape(srl_scores, [-1, num_labels]),
      name="srl_softmax_loss")  # [num_sentences * max_num_args * max_num_preds]
  loss = tf.boolean_mask(loss, tf.reshape(srl_loss_mask, [-1]))
  loss.set_shape([None])
  loss = tf.reduce_sum(loss)
  return loss


def get_softmax_loss(scores, labels, candidate_mask):
  """Softmax loss with 1-D masking. (on Unary factors)
  Args:
    scores: [num_sentences, max_num_candidates, num_labels]
    labels: [num_sentences, max_num_candidates]
    candidate_mask: [num_sentences, max_num_candidates]
  """
  max_num_candidates = util.shape(scores, 1)
  num_labels = util.shape(scores, 2)
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=tf.reshape(labels, [-1]),
      logits=tf.reshape(scores, [-1, num_labels]),
      name="softmax_loss")  # [num_sentences * max_num_candidates]
  loss = tf.boolean_mask(loss, tf.reshape(candidate_mask, [-1]))
  loss.set_shape([None])
  return loss
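
# Unlike get_srl_softmax_loss, the masked per-candidate losses are returned
# unreduced; the caller is presumably expected to sum or average them.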