from pytorch_complex_tensor import ComplexScalar, ComplexGrad

import inspect
import numpy as np
import torch
import re


"""
Complex tensor support for PyTorch.

Uses a regular real tensor in which the first half of the second-to-last dimension holds the real parts and the second half holds the imaginary parts.

Supports only some basic operations without breaking the gradients for complex math.

Supported ops:
1. addition 
    - (tensor, scalar). Both complex and real.
2. subtraction 
    - (tensor, scalar). Both complex and real.
3. multiply
    - (tensor, scalar). Both complex and real.
4. mm (matrix multiply)
    - (tensor). Both complex and real.
5. abs (absolute value)
6. all indexing ops.
7. t (transpose)

>> c = ComplexTensor(10, 20)

>> #  do regular tensor ops now
>> c = c * 4
>> c = c.mm(c.t())
>> print(c.shape, c.size(0))
>> print(c)
>> print(c[0:1, 1:-1])
"""


class ComplexTensor(torch.Tensor):
    """
    Complex-valued tensor stored inside a single real tensor.

    The real and imaginary parts are stacked along the second-to-last
    dimension: the first half of that dimension holds the real parts, the
    second half the imaginary parts. ``size()`` and ``shape`` report the
    logical (complex) shape, i.e. with that dimension halved.
    """

    @staticmethod
    def __new__(cls, x, *args, **kwargs):
        """
        Create the underlying real tensor, doubling the second-to-last dim when
        sizes are given, or validating it when data is given.

        :param x: first size, or the data (tensor, list, real/complex ndarray)
        :param args: remaining sizes, forwarded to torch.Tensor
        :param kwargs: forwarded to torch.Tensor
        :return: the new ComplexTensor
        :raises TypeError: when x is data of an unsupported type
        :raises Exception: when the data's second-to-last dim is odd
        """
        # reformat complex numpy arrays so we can init with them:
        # stack the real part on top of the imaginary part along dim -2
        if isinstance(x, np.ndarray) and np.iscomplexobj(x):
            x = np.concatenate([x.real, x.imag], axis=-2)

        if type(x) is int and len(args) == 1:
            # called as (x, last_dim): x is the second to last dim
            x = x * 2

        elif len(args) >= 2:
            # called with 3+ leading values: double the second to last one,
            # (..., 1, 3, 2) -> (..., 1, 6, 2)
            size_args = list(args)
            size_args[-2] *= 2
            args = tuple(size_args)

        else:
            # called with data: it must already hold both halves
            if isinstance(x, torch.Tensor):
                s = x.size()[-2]
            elif isinstance(x, list):
                # NOTE(review): this is the outermost length, which is the
                # second to last dim only for 2-level nested lists
                s = len(x)
            elif isinstance(x, np.ndarray):
                s = x.shape[-2]
            else:
                # previously fell through with `s` unbound (UnboundLocalError)
                raise TypeError(
                    'ComplexTensor data must be a torch.Tensor, list or numpy array, '
                    'got {}'.format(type(x).__name__))
            if s % 2 != 0:
                raise Exception('second to last dim must be even. ComplexTensor is 2 real matrices under the hood')

        return super().__new__(cls, x, *args, **kwargs)

    def __deepcopy__(self, memo):
        """
        Deep-copy a leaf tensor, keeping the ComplexTensor class.

        :param memo: deepcopy memo dict
        :return: the copied ComplexTensor
        :raises RuntimeError: when self is not a graph leaf
        """
        if not self.is_leaf:
            raise RuntimeError("Only Tensors created explicitly by the user "
                               "(graph leaves) support the deepcopy protocol at the moment")
        if id(self) in memo:
            return memo[id(self)]
        with torch.no_grad():
            if self.is_sparse:
                new_tensor = self.clone()

                # hack tensor to cast as complex
                new_tensor.__class__ = ComplexTensor
            else:
                new_storage = self.storage().__deepcopy__(memo)
                new_tensor = self.new()

                # hack tensor to cast as complex
                new_tensor.__class__ = ComplexTensor

                # use the raw storage shape: self.size() is the halved logical
                # shape and would drop the imaginary half from the copy
                new_tensor.set_(new_storage, self.storage_offset(), self.data.size(), self.stride())
            memo[id(self)] = new_tensor
            new_tensor.requires_grad = self.requires_grad
            return new_tensor

    @property
    def real(self):
        """
        Real part: the first half of the second-to-last dimension.
        :return: result of base-class indexing on self
        """
        half = self.size()[-2]
        # index through the base class directly so the complex-aware
        # __getitem__ (and its caller lookup) is skipped
        return torch.Tensor.__getitem__(self, (Ellipsis, slice(None, half), slice(None)))

    @property
    def imag(self):
        """
        Imaginary part: the second half of the second-to-last dimension.
        :return: result of base-class indexing on self
        """
        half = self.size()[-2]
        return torch.Tensor.__getitem__(self, (Ellipsis, slice(half, None), slice(None)))

    @staticmethod
    def _is_real_operand(other):
        """
        Tell whether `other` has no separate imaginary part to combine.

        :param other: tensor or scalar operand
        :return: False for a ComplexTensor, True for any other tensor,
                 np.isreal(other) for everything else
        """
        if type(other) is ComplexTensor:
            return False
        if isinstance(other, torch.Tensor):
            return True
        return bool(np.isreal(other))

    def __graph_copy__(self, real, imag):
        """
        Join real and imag parts into a ComplexTensor, keeping graph connections.

        :param real: tensor holding the real parts
        :param imag: tensor holding the imaginary parts
        :return: the joined tensor, re-classed as ComplexTensor
        """
        # the halves live on the second to last dim (see real / imag / size);
        # results with fewer than 2 dims keep the historical dim-0 join
        cat_dim = -2 if real.dim() >= 2 else 0
        result = torch.cat([real, imag], dim=cat_dim)

        # force the result to be a ComplexTensor
        result.__class__ = ComplexTensor
        return result

    def __graph_copy_scalar__(self, real, imag):
        """
        Stack real and imag on a new second-to-last dim, keeping graph connections.

        :param real: tensor holding the real part
        :param imag: tensor holding the imaginary part
        :return: the stacked tensor, re-classed as ComplexScalar
        """
        result = torch.stack([real, imag], dim=-2)
        result.__class__ = ComplexScalar
        return result

    def __add__(self, other):
        """
        Handles scalar (real, complex) and tensor (real, complex) addition
        :param other:
        :return:
        """
        real = self.real
        imag = self.imag

        if self._is_real_operand(other):
            # a real operand only shifts the real part
            real = real + other
        else:
            # complex tensor or complex scalar: both expose .real / .imag
            real = real + other.real
            imag = imag + other.imag

        return self.__graph_copy__(real, imag)

    def __radd__(self, other):
        """
        Reflected addition (addition commutes)
        :param other:
        :return:
        """
        return self.__add__(other)

    def __sub__(self, other):
        """
        Handles scalar (real, complex) and tensor (real, complex) subtraction
        :param other:
        :return: self - other
        """
        real = self.real
        imag = self.imag

        if self._is_real_operand(other):
            real = real - other
        else:
            real = real - other.real
            imag = imag - other.imag

        return self.__graph_copy__(real, imag)

    def __rsub__(self, other):
        """
        Reflected subtraction
        :param other:
        :return: other - self
        """
        # subtraction does not commute: other - self == (-self) + other
        return (-self) + other

    def __mul__(self, other):
        """
        Handles scalar (real, complex) and tensor (real, complex) multiplication
        :param other:
        :return:
        """
        real = self.real
        imag = self.imag

        if self._is_real_operand(other):
            return self.__graph_copy__(real * other, imag * other)

        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        ac = real * other.real
        bd = imag * other.imag
        ad = real * other.imag
        bc = imag * other.real
        return self.__graph_copy__(ac - bd, ad + bc)

    def __truediv__(self, other):
        """
        Handles division by a real scalar only
        :param other:
        :return:
        :raises NotImplementedError: for tensor or complex-scalar divisors
        """
        if isinstance(other, torch.Tensor) or not np.isreal(other):
            raise NotImplementedError

        return self.__graph_copy__(self.real / other, self.imag / other)

    def __rmul__(self, other):
        """
        Reflected multiplication (multiplication commutes)
        :param other:
        :return:
        """
        return self.__mul__(other)

    def __neg__(self):
        """
        Negation, implemented as multiplication by -1
        :return:
        """
        return self.__mul__(-1)

    def mm(self, other):
        """
        Handles tensor (real, complex) matrix multiply
        :param other:
        :return:
        :raises TypeError: when other is not a tensor
        """
        if not isinstance(other, torch.Tensor):
            # previously this silently returned a copy of self
            raise TypeError(
                'mm expects a torch.Tensor or ComplexTensor, got {}'.format(type(other).__name__))

        real = self.real
        imag = self.imag

        # given a real tensor
        if type(other) is not ComplexTensor:
            return self.__graph_copy__(real.mm(other), imag.mm(other))

        # given a complex tensor: (A + Bi)(C + Di) = (AC - BD) + (AD + BC)i
        ac = real.mm(other.real)
        bd = imag.mm(other.imag)
        ad = real.mm(other.imag)
        bc = imag.mm(other.real)
        return self.__graph_copy__(ac - bd, ad + bc)

    def t(self):
        """
        Transpose the real and imaginary parts separately and re-join them
        :return:
        """
        return self.__graph_copy__(self.real.t(), self.imag.t())

    def abs(self):
        """
        Element-wise magnitude sqrt(real^2 + imag^2)
        :return: result of torch.sqrt on the real and imag parts
        """
        return torch.sqrt(self.real**2 + self.imag**2)

    def sum(self, *args):
        """
        Sum the real and imaginary parts separately
        :param args: forwarded to the sum of each part
        :return: ComplexScalar built from the two sums
        """
        real_sum = self.real.sum(*args)
        imag_sum = self.imag.sum(*args)
        return ComplexScalar(real_sum, imag_sum)

    def mean(self, *args):
        """
        Average the real and imaginary parts separately
        :param args: forwarded to the mean of each part
        :return: ComplexScalar built from the two means
        """
        real_mean = self.real.mean(*args)
        imag_mean = self.imag.mean(*args)
        return ComplexScalar(real_mean, imag_mean)

    @property
    def grad(self):
        """
        Stored gradient, re-classed as ComplexGrad
        :return: the gradient, or None when there is none yet
        """
        g = self._grad
        if g is None:
            # nothing to re-class; None.__class__ cannot be assigned
            return None
        g.__class__ = ComplexGrad

        return g

    def cuda(self, *args, **kwargs):
        """
        Move both parts to the GPU and re-join them
        :param args: forwarded to the cuda() call of each part
        :param kwargs: forwarded to the cuda() call of each part
        :return:
        """
        real = self.real.cuda(*args, **kwargs)
        imag = self.imag.cuda(*args, **kwargs)

        return self.__graph_copy__(real, imag)

    def __repr__(self):
        """
        Print the flattened values as complex numbers (numpy does the formatting)
        :return:
        """
        real = self.real.flatten()
        imag = self.imag.flatten()

        values = np.asarray([complex(a, b) for a, b in zip(real, imag)]).astype(np.complex64)
        return re.sub('array', 'tensor', repr(values))

    def __str__(self):
        return self.__repr__()

    def is_complex(self):
        return True

    def size(self, *args):
        """
        Logical (complex) size: the raw size with the second-to-last dim halved
        :param args: optionally a single dim; negative values count from the end
        :return: torch.Size, or an int when a dim is given
        :raises TypeError: when more than one dim is given
        """
        dims = list(self.data.size())
        dims[-2] = dims[-2] // 2
        full = torch.Size(dims)

        if not args or args[0] is None:
            return full
        if len(args) > 1:
            raise TypeError('size() takes at most 1 argument ({} given)'.format(len(args)))

        # pick the dim from the already-halved size so size(-2) is logical too
        return full[args[0]]

    @property
    def shape(self):
        """
        Logical (complex) shape: the raw shape with the second-to-last dim halved
        :return: torch.Size
        """
        dims = list(self.data.shape)
        dims[-2] = dims[-2] // 2
        return torch.Size(dims)

    def __getitem__(self, item):
        """
        Index the real and imaginary parts with `item` and re-join them
        :param item:
        :return:
        """
        # when a function named real or imag is the caller return the raw
        # base-class result; only the caller's name is needed, so read it off
        # the parent frame instead of materialising the whole stack
        frame = inspect.currentframe()
        parent = frame.f_back if frame is not None else None
        caller = parent.f_code.co_name if parent is not None else None
        del frame, parent  # do not keep frame objects alive

        if caller == 'real' or caller == 'imag':
            return super(ComplexTensor, self).__getitem__(item)

        # this is a regular index op, select the requested pairs then form a new ComplexTensor
        r = self.real[item]
        c = self.imag[item]

        return self.__graph_copy__(r, c)