import tensorflow as tf
import tensorflow.contrib.layers as ly
from util import lrelu
import cv2
import math
from pdf_sample_layer import pdf_sample
from util import enrich_image_input
from util import STATE_DROPOUT_BEGIN, STATE_REWARD_DIM, STATE_STEP_DIM, STATE_STOPPED_DIM


def feature_extractor(net, output_dim, cfg):
  """Build a stack of stride-2 convolutions and flatten it to `output_dim`.

  The input is shifted by -0.5 and passed through 4x4, stride-2 `ly.conv2d`
  layers with `lrelu` activations. A Python-side `size` counter is halved for
  every layer and layers are added until it is no larger than 4. The layer
  that takes the counter from 8 to 4 gets `output_dim // 16` channels so that
  a 4x4 map flattens to exactly `output_dim` values; the other layers in the
  loop double the channel count. Dropout is applied to the flattened result.

  Args:
    net: Input tensor. Dimension 2 of its static shape is used as the
      spatial size (presumably NHWC with square images -- TODO confirm).
    output_dim: Number of features per example after flattening. Must be a
      multiple of 16 (4 * 4).
    cfg: Config object; `base_channels` and `dropout_keep_prob` are read.

  Returns:
    The features reshaped to `[-1, output_dim]`, after dropout.

  Raises:
    AssertionError: If `output_dim` is not a multiple of 16, or if the
      tracked size is odd when it is about to be halved.
  """
  net = net - 0.5
  min_feature_map_size = 4
  assert output_dim % (
      min_feature_map_size**2) == 0, 'output dim=%d' % output_dim
  size = int(net.get_shape()[2])
  print('Agent CNN:')
  channels = cfg.base_channels
  print('    ', str(net.get_shape()))
  # Floor division keeps `size` (and `channels` below) integral on Python 3,
  # where `/` would silently turn them into floats. The parity check makes
  # sure flooring never hides an odd size.
  assert size % 2 == 0, 'input size=%d' % size
  size //= 2
  net = ly.conv2d(
      net, num_outputs=channels, kernel_size=4, stride=2, activation_fn=lrelu)
  print('    ', str(net.get_shape()))
  while size > min_feature_map_size:
    if size == min_feature_map_size * 2:
      # Last layer: choose the channel count so the final
      # min_feature_map_size x min_feature_map_size map holds output_dim values.
      channels = output_dim // (min_feature_map_size**2)
    else:
      channels *= 2
    assert size % 2 == 0
    size //= 2
    net = ly.conv2d(
        net, num_outputs=channels, kernel_size=4, stride=2, activation_fn=lrelu)
    print('    ', str(net.get_shape()))
  print('before fc: ', net.get_shape()[1])
  net = tf.reshape(net, [-1, output_dim])
  net = tf.nn.dropout(net, cfg.dropout_keep_prob)
  return net


# Output: float \in [0, 1]
def agent_generator(inp, is_train, progress, cfg, high_res=None, alex_in=None):
  """Build the agent graph: run every filter, select one, update the states.

  Every filter in `cfg.filters` is applied to the input image. A selector
  network produces a probability distribution over the filters, one filter
  id is chosen per example, and the output image is the result of the chosen
  filter. The function also builds the updated state vector, the
  log-probability surrogate of the chosen action, and a penalty term.

  Args:
    inp: Tuple `(net, z, states)`. `net` is the image batch, column 0 of `z`
      is used as the noise for sampling the filter, and `states` is indexed
      with the `STATE_*` constants.
    is_train: Blends the sampled filter id (weight `is_train`) with the
      argmax filter id (weight `1 - is_train`). It is multiplied with int32
      ids, so presumably a 0/1 integer value -- TODO confirm.
    progress: Scales the entropy penalty by `(1.0 - progress)`.
    cfg: Config object. Reads `filters`, `shared_feature_extractor`,
      `feature_extractor_dims`, `fc1_size`, `exploration`, `test_steps`,
      `early_stop_penalty`, `clamp`, `exploration_penalty` and
      `filter_usage_penalty`.
    high_res: Optional value forwarded to every filter's `apply`. When it is
      not None, the per-filter high-res outputs are combined with the same
      one-hot selection and the return tuple changes (see Returns).
    alex_in: Not used by this function.

  Returns:
    `((net, new_states, surrogate, penalty), debug_info, debugger)` when
    `high_res` is None, otherwise
    `((net, new_states, high_res_output), debug_info, debugger)`.
    `debug_info` is a dict of tensors describing the first example of the
    batch; `debugger` is a function that draws a visualization from the
    evaluated `debug_info`.
  """
  net, z, states = inp

  # cfg.filters holds callables; each is invoked with (net, cfg) to obtain
  # the filter object used below.
  filters = [filter_cls(net, cfg) for filter_cls in cfg.filters]

  selection_noise = z[:, 0:1]
  filtered_images = []
  filter_debug_info = []
  high_res_outputs = []

  if cfg.shared_feature_extractor:
    # One feature extractor whose output is fed to every filter.
    filter_features = feature_extractor(
        net=enrich_image_input(cfg, net, states),
        output_dim=cfg.feature_extractor_dims,
        cfg=cfg)
  for j, image_filter in enumerate(filters):
    with tf.variable_scope('filter_%d' % j):
      print('    creating filter:', j, 'name:', str(image_filter.__class__),
            'abbr.', image_filter.get_short_name())
      if not cfg.shared_feature_extractor:
        # Separate extractor per filter. `states` is passed just like in the
        # shared and selector extractors, so all three see the same input.
        filter_features = feature_extractor(
            net=enrich_image_input(cfg, net, states),
            output_dim=cfg.feature_extractor_dims,
            cfg=cfg)
      print('      filter_features:', filter_features.shape)
      filtered_image_batch, high_res_output, per_filter_debug_info = \
          image_filter.apply(net, filter_features, high_res=high_res)
      high_res_outputs.append(high_res_output)
      filtered_images.append(filtered_image_batch)
      filter_debug_info.append(per_filter_debug_info)
      print('      output:', filtered_image_batch.shape)

  # [batch_size, #filters, H, W, C]
  for img in filtered_images:
    print('img', img.shape)
  filtered_images = tf.stack(values=filtered_images, axis=1)
  print('    filtered_images:', filtered_images.shape)

  with tf.variable_scope('action_selection'):
    selector_features = feature_extractor(
        net=enrich_image_input(cfg, net, states),
        output_dim=cfg.feature_extractor_dims,
        cfg=cfg)
    print('    selector features:', selector_features.shape)

    selector_features = ly.fully_connected(
        selector_features,
        num_outputs=cfg.fc1_size,
        scope='selector_fc1',
        activation_fn=lrelu)

    pdf = ly.fully_connected(
        selector_features,
        num_outputs=len(filters),
        activation_fn=None,
        scope='selector_fc2')
    pdf = tf.nn.softmax(pdf) + 1e-37
    print('    pdf_filter', pdf[:, 1:].shape)

    # Mix with a uniform distribution for exploration, then renormalize.
    pdf = pdf * (1 - cfg.exploration) + cfg.exploration * 1.0 / len(filters)
    pdf = pdf / (tf.reduce_sum(pdf, axis=1, keep_dims=True) + 1e-30)
    entropy = -pdf * tf.log(pdf)
    entropy = tf.reduce_sum(entropy, axis=1)[:, None]
    print('    pdf:', pdf.shape)
    print('    entropy:', entropy.shape)
    print('    selection_noise:', selection_noise.shape)
    random_filter_id = pdf_sample(pdf, selection_noise)
    max_filter_id = tf.cast(tf.argmax(pdf, axis=1), tf.int32)
    # Sampled id when is_train is 1, most likely id when it is 0.
    selected_filter_id = is_train * random_filter_id + (
        1 - is_train) * max_filter_id
    print('    selected_filter_id:', selected_filter_id.shape)
    filter_one_hot = tf.one_hot(
        selected_filter_id, depth=len(filters), dtype=tf.float32)
    print('    filter one_hot', filter_one_hot.shape)
    # Log-probability of the selected filter.
    surrogate = tf.reduce_sum(
        filter_one_hot * tf.log(pdf + 1e-10), axis=1, keep_dims=True)

  # Keep only the output of the selected filter for every example.
  net = tf.reduce_sum(
      filtered_images * filter_one_hot[:, :, None, None, None], axis=1)
  if high_res is not None:
    high_res_outputs = tf.stack(values=high_res_outputs, axis=1)
    high_res_output = tf.reduce_sum(
        high_res_outputs * filter_one_hot[:, :, None, None, None], axis=1)

  # only the first image will get debug_info
  debug_info = {
      'state': states,
      'selected_filter_id': selected_filter_id[0],
      'filter_debug_info': filter_debug_info,
      'pdf': pdf[0]
  }

  # Combined: Three in one 64x64 ?
  #           otherwise returns pdf, detail, mask
  def debugger(debug_info, combined=True):
    """Draw the selected filter's mask, its details and the filter pdf.

    Args:
      debug_info: Dict with the keys of the outer `debug_info`; presumably
        the evaluated (non-tensor) values -- TODO confirm against callers.
      combined: If True, all three visualizations are drawn on one image.

    Returns:
      A single image when `combined` is True, otherwise the list
      `[pdf_image, detail_image, mask_image]`.
    """
    size = 8
    img = None
    images = [None] * 3
    selected_id = debug_info['selected_filter_id']

    # Mask of the selected filter, dimmed to 80%.
    for i, image_filter in enumerate(filters):
      if i == selected_id:
        img = image_filter.visualize_mask(debug_info['filter_debug_info'][i],
                                          (64, 64)) * 0.8
    assert img is not None
    if not combined:
      # Mask
      images[2] = img.copy()
      # reset img
      img = img * 0 + 0.5

    # Details of the selected filter, skipped if its probability is
    # negligible. The return value is ignored, so visualize_filter presumably
    # draws on `img` in place.
    for i, image_filter in enumerate(filters):
      if debug_info['pdf'][i] < 1e-10:
        continue
      if i == selected_id:
        image_filter.visualize_filter(debug_info['filter_debug_info'][i], img)
    if not combined:
      # detail
      images[1] = img.copy()
      # reset img
      img = img * 0 + 0.5

    # One label plus probability bar per filter with non-negligible
    # probability, laid out in columns of `per_col` entries. `c` counts the
    # entries drawn so far and determines the position of the next one.
    per_col = 4
    c = 0
    for i, image_filter in enumerate(filters):
      prob = debug_info['pdf'][i]
      if prob < 1e-10:
        continue
      x = c // per_col * 30
      y = size * (c % per_col + 1)
      c += 1
      cv2.putText(img,
                  image_filter.get_short_name(), (x + 6, y + 4),
                  cv2.FONT_HERSHEY_SIMPLEX, 0.233, (255, 255, 255))
      selected = i == selected_id
      color = 1.0 if selected else 0.3
      width = int(prob * 20)
      height = 0.35
      corners = [(x + 16, int(y + (1 - height) * size // 2)),
                 (x + 16 + width, int(y + (1 + height) * size // 2))]
      # White frame one pixel larger than the bar, then the bar itself.
      cv2.rectangle(img, (corners[0][0] - 1, corners[0][1] - 1),
                    (corners[1][0] + 1, corners[1][1] + 1), (1, 1, 1),
                    cv2.FILLED)
      cv2.rectangle(img, corners[0], corners[1], (color, 0.3, 0.3), cv2.FILLED)
    if not combined:
      # pdf
      images[0] = img.copy()

    if combined:
      return img
    else:
      return images

  debugger.width = int(net.shape[1])
  print('    surrogate: ', surrogate.shape)

  # Calculate new states
  new_states = [None for _ in range(STATE_DROPOUT_BEGIN + 1)]
  # 1.0 where the incremented step counter equals cfg.test_steps, else 0.0.
  is_last_step = tf.cast(
      tf.abs(states[:, STATE_STEP_DIM:STATE_STEP_DIM + 1] + 1 - cfg.test_steps)
      < 1e-4,
      dtype=tf.float32)
  submitted = is_last_step

  new_states[STATE_REWARD_DIM] = submitted
  new_states[STATE_STOPPED_DIM] = submitted
  # Increment the step
  new_states[STATE_STEP_DIM] = (states[:, STATE_STEP_DIM] + 1)[:, None]

  # Update filter usage
  filter_usage = states[:, STATE_STEP_DIM + 1:]
  print('usage v.s. onehot', filter_usage.shape, filter_one_hot.shape)
  assert len(filter_usage.shape) == len(filter_one_hot.shape)

  regular_filter_start = 0

  # Penalize submission action that is not the final action.
  # `submitted` is the same tensor as `is_last_step` (values 0 or 1), so this
  # product is always zero here.
  early_stop_penalty = (1 - is_last_step) * submitted * cfg.early_stop_penalty

  # Non-zero when the selected filter is already marked as used.
  usage_penalty = tf.reduce_sum(
      filter_usage * filter_one_hot[:, regular_filter_start:],
      axis=1,
      keep_dims=True)
  new_filter_usage = tf.maximum(filter_usage,
                                filter_one_hot[:, regular_filter_start:])
  new_states[STATE_STEP_DIM + 1] = new_filter_usage

  print(submitted.shape, new_states[STATE_STEP_DIM].shape)
  new_states = tf.concat(new_states, axis=1)
  print('new_states:', new_states.shape)

  if cfg.clamp:
    net = tf.clip_by_value(net, 0.0, 5.0)

  # Zero for a uniform pdf (entropy == log(#filters)), growing as the pdf
  # becomes more peaked; fades out as `progress` approaches 1.
  entropy_penalty = (1.0 - progress) * cfg.exploration_penalty * (
      -entropy + math.log(len(filters)))

  # Mean squared amount by which the output image exceeds 1.
  overflow_penalty = tf.reduce_mean(
      tf.maximum(net - 1, 0)**2, axis=(1, 2, 3))[:, None]

  # Will be subtracted from the reward
  penalty = (overflow_penalty + entropy_penalty +
             usage_penalty * cfg.filter_usage_penalty + early_stop_penalty)

  print('states, new_states:', states.shape, new_states.shape)
  print('penalty:', penalty.shape)

  if high_res is None:
    return (net, new_states, surrogate, penalty), debug_info, debugger
  else:
    return (net, new_states, high_res_output), debug_info, debugger