from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Lambda
import tensorflow as tf


def _nms(heat, kernel=3):
  """Suppress every heatmap value that is not a local maximum.

  Arguments
    heat: heatmap tensor accepted by `K.pool2d`.
    kernel: side length of the square max-pooling window.

  Returns
    Tensor shaped like `heat` in which a value is kept only if it equals the
    maximum of its `kernel` x `kernel` neighbourhood; all other values are 0.
  """
  window = (kernel, kernel)
  local_max = K.pool2d(heat, window, padding='same', pool_mode='max')
  # A position is a peak when max-pooling leaves its value unchanged.
  is_peak = K.equal(local_max, heat)
  return heat * K.cast(is_peak, K.floatx())


def _ctdet_decode(hm, reg, wh, k=100, output_stride=4):
  """Decode the detection heads into the top-k boxes of every sample.

  Arguments
    hm: center heatmap logits (a sigmoid is applied here). Axis 2 is used as
      the width and axis 3 as the number of classes.
    reg: center offsets; flattened to (batch, -1, channels), channel 0 is
      added to x and channel 1 to y.
    wh: box sizes; flattened like `reg`, channel 0 is the width and channel 1
      the height.
    k: number of detections kept per sample.
    output_stride: factor that maps heatmap coordinates to image coordinates.

  Returns
    Tensor of shape (batch, k, 6). Each row is
    [x1, y1, x2, y2, score, class], sorted by descending score.
  """
  hm = K.sigmoid(hm)
  hm = _nms(hm)
  hm_shape = K.shape(hm)
  reg_shape = K.shape(reg)
  wh_shape = K.shape(wh)
  batch, width, cat = hm_shape[0], hm_shape[2], hm_shape[3]

  # Flatten the spatial (and, for hm, class) axes so one top_k call ranks
  # every (position, class) pair of a sample.
  hm_flat = K.reshape(hm, (batch, -1))
  reg_flat = K.reshape(reg, (reg_shape[0], -1, reg_shape[-1]))
  wh_flat = K.reshape(wh, (wh_shape[0], -1, wh_shape[-1]))

  def _process_sample(args):
    _hm, _reg, _wh = args
    _scores, _inds = tf.math.top_k(_hm, k=k, sorted=True)
    # A flat index is (y * width + x) * cat + class. Integer floor division
    # keeps the arithmetic in int32; `/` would run a float64 true division
    # that then has to be truncated back to int32.
    _classes = K.cast(_inds % cat, 'float32')
    _inds = _inds // cat
    _xs = K.cast(_inds % width, 'float32')
    _ys = K.cast(_inds // width, 'float32')
    _wh = K.gather(_wh, _inds)
    _reg = K.gather(_reg, _inds)

    # refine the integer peak position with the regressed offset
    _xs = _xs + _reg[..., 0]
    _ys = _ys + _reg[..., 1]

    _x1 = _xs - _wh[..., 0] / 2
    _y1 = _ys - _wh[..., 1] / 2
    _x2 = _xs + _wh[..., 0] / 2
    _y2 = _ys + _wh[..., 1] / 2

    # rescale to image coordinates
    _x1 = output_stride * _x1
    _y1 = output_stride * _y1
    _x2 = output_stride * _x2
    _y2 = output_stride * _y2

    _detection = K.stack([_x1, _y1, _x2, _y2, _scores, _classes], -1)
    return _detection

  detections = K.map_fn(_process_sample, [hm_flat, reg_flat, wh_flat], dtype=K.floatx())
  return detections


def CtDetDecode(model, hm_index=3, reg_index=4, wh_index=5, k=100, output_stride=4):
  """Wrap `model` so that it outputs decoded detections.

  Arguments
    model: Keras model whose `outputs` list contains the detection heads.
    hm_index: position of the center heatmap in `model.outputs`.
    reg_index: position of the center offset head in `model.outputs`.
    wh_index: position of the box size head in `model.outputs`.
    k: number of detections kept per sample.
    output_stride: factor that maps heatmap coordinates to image coordinates.

  Returns
    A new `Model` from `model.input` to the output of `_ctdet_decode`.
  """
  head_indices = (hm_index, reg_index, wh_index)

  def _decode(heads):
    return _ctdet_decode(*heads, k=k, output_stride=output_stride)

  decoded = Lambda(_decode)([model.outputs[idx] for idx in head_indices])
  return Model(model.input, decoded)


def _hpdet_decode(hm, wh, kps, reg, hm_hp, hp_offset, k=100, output_stride=4):
  """Decode the detection and pose heads into the top-k boxes with keypoints.

  Arguments
    hm: center heatmap logits (a sigmoid is applied here). Axis 2 is used as
      the width and axis 3 as the number of classes.
    wh: box sizes; channel 0 is the width and channel 1 the height.
    kps: keypoint offsets relative to the center; even channels are read as
      x offsets and odd channels as y offsets.
    reg: center offsets; channel 0 is added to x and channel 1 to y.
    hm_hp: keypoint heatmap logits (a sigmoid is applied here); the last axis
      is iterated as one channel per keypoint.
    hp_offset: keypoint heatmap offsets; channel 0 is added to x and
      channel 1 to y.
    k: number of detections kept per sample, also the number of heatmap
      candidates considered per keypoint channel.
    output_stride: factor that maps heatmap coordinates to image coordinates.

  Returns
    Tensor of shape (batch, k, 4 + 1 + 2 * J + 1) with J keypoint channels.
    Each row is [x1, y1, x2, y2, score, keypoints..., class], where the
    keypoints are stored as all J x values followed by all J y values.
    Rows are sorted by descending score.
  """
  hm = K.sigmoid(hm)
  hm = _nms(hm)
  hm_shape = K.shape(hm)
  reg_shape = K.shape(reg)
  wh_shape = K.shape(wh)
  kps_shape = K.shape(kps)
  batch, width, cat = hm_shape[0], hm_shape[2], hm_shape[3]

  # Flatten the spatial axes so every head can be indexed with one flat index.
  hm_flat = K.reshape(hm, (batch, -1))
  reg_flat = K.reshape(reg, (reg_shape[0], -1, reg_shape[-1]))
  wh_flat = K.reshape(wh, (wh_shape[0], -1, wh_shape[-1]))
  kps_flat = K.reshape(kps, (kps_shape[0], -1, kps_shape[-1]))

  hm_hp = K.sigmoid(hm_hp)
  hm_hp = _nms(hm_hp)
  hm_hp_shape = K.shape(hm_hp)
  hp_offset_shape = K.shape(hp_offset)

  hm_hp_flat = K.reshape(hm_hp, (hm_hp_shape[0], -1, hm_hp_shape[-1]))
  hp_offset_flat = K.reshape(hp_offset, (hp_offset_shape[0], -1, hp_offset_shape[-1]))

  def _process_sample(args):
    _hm, _reg, _wh, _kps, _hm_hp, _hp_offset = args
    _scores, _inds = tf.math.top_k(_hm, k=k, sorted=True)
    # A flat index is (y * width + x) * cat + class. Integer floor division
    # keeps the arithmetic in int32; `/` would run a float64 true division
    # that then has to be truncated back to int32.
    _classes = K.cast(_inds % cat, 'float32')
    _inds = _inds // cat
    _xs = K.cast(_inds % width, 'float32')
    _ys = K.cast(_inds // width, 'float32')
    _wh = K.gather(_wh, _inds)
    _reg = K.gather(_reg, _inds)
    _kps = K.gather(_kps, _inds)

    # shift keypoints by their (integer) center
    _kps_x = _kps[:, ::2]
    _kps_y = _kps[:, 1::2]
    _kps_x = _kps_x + K.expand_dims(_xs, -1)  # k x J
    _kps_y = _kps_y + K.expand_dims(_ys, -1)  # k x J
    _kps = K.stack([_kps_x, _kps_y], -1)  # k x J x 2

    # refine the integer peak position with the regressed offset
    _xs = _xs + _reg[..., 0]
    _ys = _ys + _reg[..., 1]

    _x1 = _xs - _wh[..., 0] / 2
    _y1 = _ys - _wh[..., 1] / 2
    _x2 = _xs + _wh[..., 0] / 2
    _y2 = _ys + _wh[..., 1] / 2

    # snap center keypoints to the closest heatmap keypoint
    def _process_channel(args):
      # Runs once per keypoint channel: __kps is k x 2, __hm_hp is the
      # flattened heatmap of that channel. The box corners and _wh read here
      # are still in heatmap coordinates (rescaling happens after map_fn).
      __kps, __hm_hp = args
      thresh = 0.1
      __hm_scores, __hm_inds = tf.math.top_k(__hm_hp, k=k, sorted=True)
      __hm_xs = K.cast(__hm_inds % width, 'float32')
      __hm_ys = K.cast(__hm_inds // width, 'float32')
      __hp_offset = K.gather(_hp_offset, __hm_inds)
      __hm_xs = __hm_xs + __hp_offset[..., 0]
      __hm_ys = __hm_ys + __hp_offset[..., 1]
      # Candidates at or below the threshold get score -1 and are moved to
      # (-10000, -10000).
      mask = K.cast(__hm_scores > thresh, 'float32')
      __hm_scores = (1. - mask) * -1. + mask * __hm_scores
      __hm_xs = (1. - mask) * -10000. + mask * __hm_xs
      __hm_ys = (1. - mask) * -10000. + mask * __hm_ys
      __hm_kps = K.stack([__hm_xs, __hm_ys], -1)  # k x 2
      __broadcast_hm_kps = K.expand_dims(__hm_kps, 1)  # k x 1 x 2
      __broadcast_kps = K.expand_dims(__kps, 0)  # 1 x k x 2
      # dist[i, j]: distance from heatmap candidate i to regressed keypoint j
      dist = K.sqrt(K.sum(K.pow(__broadcast_kps - __broadcast_hm_kps, 2), 2))  # k, k
      min_dist = K.min(dist, 0)
      min_ind = K.argmin(dist, 0)
      __hm_scores = K.gather(__hm_scores, min_ind)
      __hm_kps = K.gather(__hm_kps, min_ind)
      # Reject the nearest heatmap keypoint when it lies outside the box, its
      # score is below the threshold, or it is farther than 0.3 * max(w, h)
      # from the regressed keypoint.
      mask = (K.cast(__hm_kps[..., 0] < _x1, 'float32') + K.cast(__hm_kps[..., 0] > _x2, 'float32') +
              K.cast(__hm_kps[..., 1] < _y1, 'float32') + K.cast(__hm_kps[..., 1] > _y2, 'float32') +
              K.cast(__hm_scores < thresh, 'float32') +
              K.cast(min_dist > 0.3 * (K.maximum(_wh[..., 0], _wh[..., 1])), 'float32'))
      mask = K.expand_dims(mask, -1)
      mask = K.cast(mask > 0, 'float32')
      # mask == 1 keeps the regressed keypoint, mask == 0 takes the heatmap one
      __kps = (1. - mask) * __hm_kps + mask * __kps
      return __kps

    _kps = K.permute_dimensions(_kps, (1, 0, 2))  # J x k x 2
    _hm_hp = K.permute_dimensions(_hm_hp, (1, 0))  # J x -1
    _kps = K.map_fn(_process_channel, [_kps, _hm_hp], dtype='float32')
    # J x k x 2 -> k x 2 x J -> k x (2 * J): all x values, then all y values
    _kps = K.reshape(K.permute_dimensions(_kps, (1, 2, 0)), (k, -1))

    # rescale to image coordinates
    _x1 = output_stride * _x1
    _y1 = output_stride * _y1
    _x2 = output_stride * _x2
    _y2 = output_stride * _y2
    _kps = output_stride * _kps

    _boxes = K.stack([_x1, _y1, _x2, _y2], -1)
    _scores = K.expand_dims(_scores, -1)
    _classes = K.expand_dims(_classes, -1)
    _detection = K.concatenate([_boxes, _scores, _kps, _classes], -1)
    return _detection

  detections = K.map_fn(_process_sample,
                        [hm_flat, reg_flat, wh_flat, kps_flat, hm_hp_flat, hp_offset_flat], dtype='float32')
  return detections


def HpDetDecode(model, hm_index=6, wh_index=11, kps_index=9, reg_index=10, hm_hp_index=7, hp_offset_index=8,
                k=100, output_stride=4):
  """Wrap `model` so that it outputs decoded detections with keypoints.

  Arguments
    model: Keras model whose `outputs` list contains the detection and pose heads.
    hm_index: position of the center heatmap in `model.outputs`.
    wh_index: position of the box size head in `model.outputs`.
    kps_index: position of the keypoint offset head in `model.outputs`.
    reg_index: position of the center offset head in `model.outputs`.
    hm_hp_index: position of the keypoint heatmap in `model.outputs`.
    hp_offset_index: position of the keypoint heatmap offset head in `model.outputs`.
    k: number of detections kept per sample.
    output_stride: factor that maps heatmap coordinates to image coordinates.

  Returns
    A new `Model` from `model.input` to the output of `_hpdet_decode`.
  """
  head_indices = (hm_index, wh_index, kps_index, reg_index, hm_hp_index, hp_offset_index)

  def _decode(heads):
    return _hpdet_decode(*heads, k=k, output_stride=output_stride)

  decoded = Lambda(_decode)([model.outputs[idx] for idx in head_indices])
  return Model(model.input, decoded)