import tensorflow as tf


def aggregate(data, agg_idx, new_size, method="sum"):
  """Aggregate entries of `data` that share the same segment index.

  Thin wrapper around TensorFlow's `unsorted_segment_*` ops that selects the
  reduction by name.

  Args:
    data: tf tensor, see "unsorted_segment_x" in tf documents for more detail
    agg_idx: tf tensor of int, index (segment id) for aggregation
    new_size: tf tensor of int, size of the data after aggregation
    method: aggregation method, one of "sum", "avg", "max" or "min"

  Returns:
    agg_data: tf tensor, aggregated data

  Raises:
    ValueError: if `method` is not one of the supported names.
  """

  if method == "sum":
    agg_data = tf.unsorted_segment_sum(data, agg_idx, new_size)
  elif method == "avg":
    segment_total = tf.unsorted_segment_sum(data, agg_idx, new_size)
    # Summing ones with the same indices counts the entries in each segment.
    denom_const = tf.unsorted_segment_sum(
        tf.ones_like(data), agg_idx, new_size)
    # A small epsilon keeps the division finite for segments that received
    # no entries. It is created in the denominator's dtype so inputs other
    # than float32 (e.g. float64) do not fail with a dtype mismatch.
    epsilon = tf.constant(1.0e-10, dtype=denom_const.dtype)
    agg_data = tf.div(segment_total, denom_const + epsilon)
  elif method == "max":
    agg_data = tf.unsorted_segment_max(data, agg_idx, new_size)
  elif method == "min":
    # min(x) == -max(-x): the result must be negated back, otherwise the
    # caller receives the negated minimum instead of the minimum.
    agg_data = -tf.unsorted_segment_max(-data, agg_idx, new_size)
  else:
    raise ValueError("Unsupported aggregation method!")

  return agg_data