import torch
import torch.nn as nn
import torch.nn.functional as F


def valid_avg_pool(tensor, valid_mask, kernel_size):
    """Average-pool ``tensor`` over non-overlapping windows, counting only valid pixels.

    Each ``kernel_size x kernel_size`` window (stride ``kernel_size``) is
    reduced to the mask-weighted mean of its entries. Trailing rows/columns
    that do not fill a complete window are dropped.

    Args:
        tensor: Floating-point tensor of shape ``(N, C, H, W)``.
        valid_mask: Validity mask, converted to float internally. Accepted
            shapes are ``(N, C, H, W)``, ``(N, 1, H, W)`` or ``(N, H, W)``;
            the latter two are broadcast over the channels of ``tensor``.
            Non-zero entries mark valid pixels and act as weights.
        kernel_size: Integer side length of the square pooling window.

    Returns:
        A tuple ``(pooled_tensor, pooled_mask)``:
        ``pooled_tensor`` has shape ``(N, C, out_H, out_W)`` and is 0 for
        windows without any valid pixel; ``pooled_mask`` is a boolean tensor
        of shape ``(N, out_H, out_W)`` that is True where the window of the
        first mask channel contains at least one valid pixel.

    Raises:
        ValueError: If the mask channel count is neither 1 nor ``C``.
    """
    valid_mask = valid_mask.float()
    if valid_mask.dim() == 3:
        # A channel-less (N, H, W) mask applies to every channel.
        valid_mask = valid_mask.unsqueeze(1)

    N, C, H, W = tensor.shape
    mask_C = valid_mask.shape[1]
    if mask_C not in (1, C):
        raise ValueError(
            f"valid_mask has {mask_C} channels; expected 1 or {C} to match tensor"
        )

    out_H = H // kernel_size
    out_W = W // kernel_size

    # unfold yields (N, channels * k * k, L); split the window entries into
    # their own axis (dim 2) so they can be reduced per output location.
    tensor_patch = F.unfold(
        tensor,
        kernel_size=kernel_size,
        stride=kernel_size
    ).view(N, C, -1, out_H, out_W)

    # Viewing the mask with its own channel count keeps a single-channel mask
    # aligned with the window axis; it then broadcasts across the C channels.
    valid_mask_patch = F.unfold(
        valid_mask,
        kernel_size=kernel_size,
        stride=kernel_size
    ).view(N, mask_C, -1, out_H, out_W)

    # Zero invalid samples before weighting: multiplying NaN/Inf by a zero
    # weight would otherwise still poison the window sum.
    tensor_patch = tensor_patch.masked_fill(valid_mask_patch == 0, 0.0)

    count = torch.sum(valid_mask_patch, dim=2)
    weighted_sum = torch.sum(tensor_patch * valid_mask_patch, dim=2)

    # Empty windows have a zero numerator; dividing by a large constant
    # instead of 0 keeps the result at 0 rather than NaN. full_like allocates
    # directly on count's device with count's dtype.
    safe_count = torch.where(torch.le(count, 1e-5), torch.full_like(count, 1e6), count)
    pooled_tensor = weighted_sum / safe_count   # (N, C, out_H, out_W)

    pooled_mask = torch.gt(count, 1e-3)

    return pooled_tensor, pooled_mask[:, 0, :, :]


def extract_points_validpool(scene_coords, scene_valid_mask, conv_feat, pool_kernel_size=2, point_chw_order=False):
    """Pool scene coordinates over valid pixels and flatten them alongside ``conv_feat``.

    The coordinates are reduced with ``valid_avg_pool`` so that their spatial
    size matches ``conv_feat``; all spatial grids are then flattened into a
    single point axis of length ``H * W``.

    Args:
        scene_coords: Tensor of shape ``(N, D, H_in, W_in)``
            (presumably per-pixel 3D coordinates, D = 3 -- confirm with callers).
        scene_valid_mask: Validity mask forwarded to ``valid_avg_pool``.
        conv_feat: Feature map of shape ``(N, C, H, W)`` whose spatial size
            must equal the pooled coordinate grid.
        pool_kernel_size: Pooling window size / stride applied to the coordinates.
        point_chw_order: If True the points are returned channel-first as
            ``(N, D, H * W)``, otherwise as ``(N, H * W, D)``.

    Returns:
        A tuple ``(avg_scene_coords, conv_feat, pooled_mask)`` with shapes
        ``(N, D, H * W)`` or ``(N, H * W, D)``, ``(N, C, H * W)`` and
        ``(N, H * W)`` respectively.

    Raises:
        AssertionError: If the pooled grid and ``conv_feat`` differ in height or width.
    """
    avg_scene_coords, pooled_mask = valid_avg_pool(scene_coords, scene_valid_mask, kernel_size=pool_kernel_size)
    N, D, H, W = avg_scene_coords.shape

    assert H == conv_feat.shape[2], \
        f"pooled height {H} does not match conv_feat height {conv_feat.shape[2]}"
    assert W == conv_feat.shape[3], \
        f"pooled width {W} does not match conv_feat width {conv_feat.shape[3]}"

    # Use the actual coordinate dimension D instead of assuming 3 channels.
    if point_chw_order:
        avg_scene_coords = avg_scene_coords.reshape(N, D, H * W)
    else:
        avg_scene_coords = avg_scene_coords.permute(0, 2, 3, 1).reshape(N, H * W, D)

    # reshape (unlike view) also accepts non-contiguous inputs such as cropped
    # feature maps; it still returns a view whenever one is possible.
    C = conv_feat.shape[1]
    conv_feat = conv_feat.reshape(N, C, H * W)
    pooled_mask = pooled_mask.reshape(N, H * W)

    return avg_scene_coords, conv_feat, pooled_mask


def extract_points(scene_coords, conv_feat, pool_kernel_size=2, point_chw_order=False):
    """Average-pool scene coordinates and flatten them alongside ``conv_feat``.

    The coordinates are reduced with ``F.avg_pool2d`` (and detached from the
    autograd graph) so that their spatial size matches ``conv_feat``; both
    spatial grids are then flattened into a single point axis of length
    ``H * W``.

    Args:
        scene_coords: Tensor of shape ``(N, D, H_in, W_in)``
            (presumably per-pixel 3D coordinates, D = 3 -- confirm with callers).
        conv_feat: Feature map of shape ``(N, C, H, W)`` whose spatial size
            must equal the pooled coordinate grid.
        pool_kernel_size: Pooling window size applied to the coordinates.
        point_chw_order: If True the points are returned channel-first as
            ``(N, D, H * W)``, otherwise as ``(N, H * W, D)``.

    Returns:
        A tuple ``(avg_scene_coords, conv_feat)`` with shapes
        ``(N, D, H * W)`` or ``(N, H * W, D)``, and ``(N, C, H * W)``.

    Raises:
        AssertionError: If the pooled grid and ``conv_feat`` differ in height or width.
    """
    avg_scene_coords = F.avg_pool2d(scene_coords, kernel_size=pool_kernel_size).detach()
    N, D, H, W = avg_scene_coords.shape

    assert H == conv_feat.shape[2], \
        f"pooled height {H} does not match conv_feat height {conv_feat.shape[2]}"
    assert W == conv_feat.shape[3], \
        f"pooled width {W} does not match conv_feat width {conv_feat.shape[3]}"

    # Use the actual coordinate dimension D instead of assuming 3 channels.
    if point_chw_order:
        avg_scene_coords = avg_scene_coords.reshape(N, D, H * W)
    else:
        avg_scene_coords = avg_scene_coords.permute(0, 2, 3, 1).reshape(N, H * W, D)

    # reshape (unlike view) also accepts non-contiguous inputs such as cropped
    # feature maps; it still returns a view whenever one is possible.
    C = conv_feat.shape[1]
    conv_feat = conv_feat.reshape(N, C, H * W)

    return avg_scene_coords, conv_feat