import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils.dif_fms import fast_dif_fms
from .aggregators.AnyNet import AnyNetAggregator

class AnyNetProcessor(nn.Module):
    """
    Cost processing for AnyNet: build a raw cost volume from left/right
    features at sampled disparities, then aggregate it with a per-stage
    aggregator.

    Args:
        cfg: configuration object supporting ``.copy()`` and attribute access.
            Fields read here:
              - ``model.batch_norm``: forwarded to every aggregator
              - ``model.stage``: iterable of stage names; one aggregator is
                built per name
              - ``model.cost_processor.cost_computation.{max_disp, start_disp,
                dilation}``: each is indexed by stage name in ``forward``
              - ``model.cost_processor.cost_aggregator.{type, in_planes,
                agg_planes, num}``: ``in_planes`` / ``agg_planes`` are indexed
                by stage name

    Inputs:
        stage, (str): 'init_guess', the coarsest disparity estimation,
                      'warp_level_8', refine the disparity estimation with feature warp at resolution=1/8
                      'warp_level_4', refine the disparity estimation with feature warp at resolution=1/4
        left, (tensor): Left image feature, in [BatchSize, Channels, Height, Width] layout
        right, (tensor): Right image feature, in [BatchSize, Channels, Height, Width] layout
        disp, (tensor): Disparity map outputted from last stage, in [BatchSize, 1, Height, Width] layout.
                        Optional; when None the disparity samples are not shifted.

    Outputs:
        cost: whatever the aggregator of the requested stage returns. The
            original comments describe it as a list of cost volumes
            in [BatchSize, MaxDisparity, Height, Width] layout
            -- TODO confirm against AnyNetAggregator.

    """

    def __init__(self, cfg):
        super(AnyNetProcessor, self).__init__()
        self.cfg = cfg.copy()
        self.batch_norm = cfg.model.batch_norm

        # names of the stages this processor can serve
        self.stage = self.cfg.model.stage

        # cost computation parameters, each one indexed by stage name
        computation_cfg = self.cfg.model.cost_processor.cost_computation
        self.max_disp = computation_cfg.max_disp
        self.start_disp = computation_cfg.start_disp
        self.dilation = computation_cfg.dilation

        # cost aggregation: one aggregator per stage, registered in a
        # ModuleDict so that their parameters belong to this module
        aggregator_cfg = self.cfg.model.cost_processor.cost_aggregator
        self.aggregator_type = aggregator_cfg.type
        self.aggregator = nn.ModuleDict()
        for st in self.stage:
            self.aggregator[st] = AnyNetAggregator(
                in_planes=aggregator_cfg.in_planes[st],
                agg_planes=aggregator_cfg.agg_planes[st],
                num=aggregator_cfg.num,
                batch_norm=self.batch_norm,
            )

    def forward(self, stage, left, right, disp=None):
        B, _, H, W = left.shape

        start_disp = self.start_disp[stage]
        max_disp = self.max_disp[stage]
        dilation = self.dilation[stage]

        # last disparity covered by the search range (inclusive)
        end_disp = start_disp + max_disp - 1

        # disparity sample number: ceil(max_disp / dilation)
        D = (max_disp + dilation - 1) // dilation

        # Generate disparity samples, in [B, D, H, W] layout.
        # The D values are created directly on the device of `left` and then
        # broadcast with `expand`, so no [B, D, H, W] tensor has to be built
        # on the host and copied over on every call.
        # NOTE(review): the linspace spacing equals `dilation` only when
        # (max_disp - 1) is a multiple of `dilation`; otherwise the samples
        # are fractional. Kept as in the original.
        disp_sample = torch.linspace(start_disp, end_disp, D, device=left.device).float()
        disp_sample = disp_sample.view(1, D, 1, 1).expand(B, D, H, W)

        # if initial disparity guessed, used for warping
        if disp is not None:
            # up-sample the disparity map to the size of left; disparity
            # values are rescaled by the same factor as the width
            scale = W / disp.shape[-1]
            disp = F.interpolate(disp * scale, size=(H, W), mode='bilinear', align_corners=False)
            # shift the disparity samples to be centered at the given disparity map
            disp_sample = disp_sample + disp

        # raw cost volume; [B, C, D, H, W] according to the original comment
        raw_cost = fast_dif_fms(left, right, disp_sample=disp_sample)

        # aggregate with the aggregator registered for this stage
        cost = self.aggregator[stage](raw_cost)

        return cost