import torch
import torchvision
from torch import nn

import geometry
import util
from pytorch_prototyping import pytorch_prototyping


def init_recurrent_weights(module):
    # Intended to be passed to nn.Module.apply(): Kaiming init for input-hidden
    # weights, orthogonal init for hidden-hidden weights, zero biases. The
    # *Cell variants are included so the initializer also fires on nn.LSTMCell.
    recurrent_types = (nn.GRU, nn.LSTM, nn.RNN, nn.GRUCell, nn.LSTMCell, nn.RNNCell)
    for m in module.modules():
        if type(m) in recurrent_types:
            for name, param in m.named_parameters():
                if 'weight_ih' in name:
                    nn.init.kaiming_normal_(param.data)
                elif 'weight_hh' in name:
                    nn.init.orthogonal_(param.data)
                elif 'bias' in name:
                    param.data.fill_(0)


def lstm_forget_gate_init(lstm_layer):
    # Initialize the forget-gate bias to 1 so the LSTM starts out retaining its
    # cell state. PyTorch orders gate biases as [input | forget | cell | output],
    # so the forget gate occupies the second quarter of each bias vector.
    for name, parameter in lstm_layer.named_parameters():
        if "bias" not in name:
            continue
        n = parameter.size(0)
        start, end = n // 4, n // 2
        parameter.data[start:end].fill_(1.)


def clip_grad_norm_hook(x, max_norm=10):
    # Gradient hook: rescale the gradient whenever its L2 norm exceeds
    # max_norm. x.norm() already returns the L2 norm, so no further square
    # root is needed. Returning None leaves the gradient unchanged.
    total_norm = x.norm()
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        return x * clip_coef

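# A minimal usage sketch (an assumption for illustration; the hook is not
# registered anywhere in this file):
#   v = torch.randn(8, 16, requires_grad=True)
#   v.register_hook(clip_grad_norm_hook)
# After backward(), v.grad is rescaled whenever its L2 norm exceeds max_norm.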

class DepthSampler(nn.Module):
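    """Lifts pixel coordinates with known depth into world coordinates and
    recomputes the per-point depth in camera space. The ``intersection_net``
    argument is currently unused."""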
    def __init__(self):
        super().__init__()

    def forward(self,
                xy,
                depth,
                cam2world,
                intersection_net,
                intrinsics):
        self.logs = list()

        batch_size, _, _ = cam2world.shape

        intersections = geometry.world_from_xy_depth(xy=xy, depth=depth, cam2world=cam2world, intrinsics=intrinsics)

        depth = geometry.depth_from_world(intersections, cam2world)

        if self.training:
            print("Depth sampler: min depth %0.6f, max depth %0.6f" % (depth.min(), depth.max()))

        return intersections, depth


class Raymarcher(nn.Module):
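    """Differentiable LSTM ray marcher: starting just in front of the camera,
    an LSTM cell repeatedly maps the scene features at the current points to a
    signed step length along each ray, marching the points toward the surface."""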
    def __init__(self,
                 num_feature_channels,
                 raymarch_steps):
        super().__init__()

        self.n_feature_channels = num_feature_channels
        self.steps = raymarch_steps

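        # A small LSTM cell predicts, per ray, the next signed step length from
        # the feature vector returned by the scene representation.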
        hidden_size = 16
        self.lstm = nn.LSTMCell(input_size=self.n_feature_channels,
                                hidden_size=hidden_size)

        self.lstm.apply(init_recurrent_weights)
        lstm_forget_gate_init(self.lstm)

        self.out_layer = nn.Linear(hidden_size, 1)
        self.counter = 0

    def forward(self,
                cam2world,
                phi,
                uv,
                intrinsics):
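        # cam2world: camera-to-world poses; phi: the scene representation,
        # mapping world coordinates to feature vectors; uv: pixel coordinates
        # of the rays to march; intrinsics: camera intrinsics.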
        batch_size, num_samples, _ = uv.shape
        log = list()

        ray_dirs = geometry.get_ray_directions(uv,
                                               cam2world=cam2world,
                                               intrinsics=intrinsics)

        # Start all points slightly in front of the camera, with a little noise;
        # allocate on the input's device rather than hard-coding .cuda().
        initial_depth = torch.zeros((batch_size, num_samples, 1), device=uv.device).normal_(mean=0.05, std=5e-4)
        init_world_coords = geometry.world_from_xy_depth(uv,
                                                         initial_depth,
                                                         intrinsics=intrinsics,
                                                         cam2world=cam2world)

        world_coords = [init_world_coords]
        depths = [initial_depth]
        states = [None]

        for step in range(self.steps):
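            # Query the scene representation at the current points on each ray.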
            v = phi(world_coords[-1])

            state = self.lstm(v.view(-1, self.n_feature_channels), states[-1])

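            # Clamp gradients flowing through the hidden state to stabilize
            # backpropagation through the unrolled marching steps.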
            if state[0].requires_grad:
                state[0].register_hook(lambda x: x.clamp(min=-10, max=10))

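            # The LSTM output is interpreted as a signed step length along the ray.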
            signed_distance = self.out_layer(state[0]).view(batch_size, num_samples, 1)
            new_world_coords = world_coords[-1] + ray_dirs * signed_distance

            states.append(state)
            world_coords.append(new_world_coords)

            depth = geometry.depth_from_world(world_coords[-1], cam2world)

            if self.training:
                print("Raymarch step %d: Min depth %0.6f, max depth %0.6f" %
                      (step, depths[-1].min().detach().cpu().numpy(), depths[-1].max().detach().cpu().numpy()))

            depths.append(depth)

        if self.counter % 100 == 0:
            # Write tensorboard summary for each step of ray-marcher.
            drawing_depths = torch.stack(depths, dim=0)[:, 0, :, :]
            drawing_depths = util.lin2img(drawing_depths).repeat(1, 3, 1, 1)
            log.append(('image', 'raycast_progress',
                        torch.clamp(torchvision.utils.make_grid(drawing_depths, scale_each=False, normalize=True), 0.0,
                                    5),
                        100))

            # Visualize residual step distance (i.e., the size of the final step)
            fig = util.show_images([util.lin2img(signed_distance)[i, :, :, :].detach().cpu().numpy().squeeze()
                                    for i in range(batch_size)])
            log.append(('figure', 'stopping_distances', fig, 100))
        self.counter += 1

        return world_coords[-1], depths[-1], log

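# Hypothetical usage sketch (shapes inferred from forward(); not asserted by
# this file):
#   raymarcher = Raymarcher(num_feature_channels=256, raymarch_steps=10)
#   points, depth, log = raymarcher(cam2world, phi, uv, intrinsics)
# with cam2world of shape (B, 4, 4), uv of shape (B, N, 2), and phi mapping
# (B, N, 3) world coordinates to (B, N, num_feature_channels) features.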

class DeepvoxelsRenderer(nn.Module):
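    """Renders per-pixel features to RGB with a U-Net, followed by a learned
    upsampling stack when the output side length requires more factors of two
    than the input resolution provides."""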
    def __init__(self,
                 nf0,
                 in_channels,
                 input_resolution,
                 img_sidelength):
        super().__init__()

        self.nf0 = nf0
        self.in_channels = in_channels
        self.input_resolution = input_resolution
        self.img_sidelength = img_sidelength

        self.num_down_unet = util.num_divisible_by_2(input_resolution)
        self.num_upsampling = util.num_divisible_by_2(img_sidelength) - self.num_down_unet

        self.build_net()

    def build_net(self):
        # The U-Net emits RGB directly when no upsampling is needed; otherwise
        # it emits a 4 * nf0 feature map for the upsampling stack below.
        self.net = [
            pytorch_prototyping.Unet(in_channels=self.in_channels,
                                     out_channels=3 if self.num_upsampling <= 0 else 4 * self.nf0,
                                     outermost_linear=self.num_upsampling <= 0,
                                     use_dropout=True,
                                     dropout_prob=0.1,
                                     nf0=self.nf0 * (2 ** self.num_upsampling),
                                     norm=nn.BatchNorm2d,
                                     max_channels=8 * self.nf0,
                                     num_down=self.num_down_unet)
        ]

        if self.num_upsampling > 0:
            self.net += [
                pytorch_prototyping.UpsamplingNet(per_layer_out_ch=self.num_upsampling * [self.nf0],
                                                  in_channels=4 * self.nf0,
                                                  upsampling_mode='transpose',
                                                  use_dropout=True,
                                                  dropout_prob=0.1),
                pytorch_prototyping.Conv2dSame(self.nf0, out_channels=self.nf0 // 2, kernel_size=3, bias=False),
                nn.BatchNorm2d(self.nf0 // 2),
                nn.ReLU(True),
                pytorch_prototyping.Conv2dSame(self.nf0 // 2, 3, kernel_size=3)
            ]

        self.net += [nn.Tanh()]
        self.net = nn.Sequential(*self.net)

    def forward(self, input):
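        # input: (batch, num_pixels, channels) per-pixel features. Reshape into
        # an image grid, render, then flatten back to (batch, num_pixels, 3).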
        batch_size, _, ch = input.shape
        input = input.permute(0, 2, 1).view(batch_size, ch, self.img_sidelength, self.img_sidelength)
        out = self.net(input)
        return out.view(batch_size, 3, -1).permute(0, 2, 1)