# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""ODL integration with shearlab."""

from threading import Lock

import matplotlib.pyplot as plt
import numpy as np
from numpy import ceil
from numpy.fft import fft2, fftshift, ifft2, ifftshift

import julia
import odl

__all__ = ('ShearlabOperator',)


class ShearlabOperator(odl.Operator):

    """Shearlet transform using Shearlab.jl as backend.
    This is the non-compact shearlet transform implemented using the Fourier
    transform.
    """

    def __init__(self, space, num_scales):
        """Initialize a new instance.

        Parameters
        ----------
        space : `DiscretizedSpace`
            The space on which the shearlet transform should act. Must be
            two-dimensional.
        num_scales : nonnegative `int`
            The number of scales for the shearlet transform, higher numbers
            mean better edge resolution but more computational burden.

        Examples
        --------
        Create a 2d-shearlet transform:
        >>> space = odl.uniform_discr([-1, -1], [1, 1], [128, 128])
        >>> shearlet_transform = ShearlabOperator(space, num_scales=2)
        """
        self.shearlet_system = getshearletsystem2D(
            space.shape[0], space.shape[1], num_scales)
        range = space ** self.shearlet_system.nShearlets
        self.mutex = Lock()
        super(ShearlabOperator, self).__init__(space, range, True)

    def _call(self, x):
        """``self(x)``."""
        with self.mutex:
            result = sheardec2D(x, self.shearlet_system)
            return np.moveaxis(result, -1, 0)

    @property
    def adjoint(self):
        """The adjoint operator."""
        op = self

        class ShearlabOperatorAdjoint(odl.Operator):

            """Adjoint of the shearlet transform.

            See Also
            --------
            odl.contrib.shearlab.ShearlabOperator
            """

            def __init__(self):
                """Initialize a new instance."""
                self.mutex = op.mutex
                self.shearlet_system = op.shearlet_system
                super(ShearlabOperatorAdjoint, self).__init__(
                    op.range, op.domain, True)

            def _call(self, x):
                """``self(x)``."""
                with op.mutex:
                    x = np.moveaxis(x, 0, -1)
                    return sheardecadjoint2D(x, op.shearlet_system)

            @property
            def adjoint(self):
                """The adjoint operator."""
                return op

            @property
            def inverse(self):
                """The inverse operator."""
                op = self

                class ShearlabOperatorAdjointInverse(odl.Operator):

                    """
                    Adjoint of the inverse/Inverse of the adjoint
                    of shearlet transform.

                    See Also
                    --------
                    odl.contrib.shearlab.ShearlabOperator
                    """

                    def __init__(self):
                        """Initialize a new instance."""
                        self.mutex = op.mutex
                        self.shearlet_system = op.shearlet_system
                        super(ShearlabOperatorAdjointInverse, self).__init__(
                            op.range, op.domain, True)

                    def _call(self, x):
                        """``self(x)``."""
                        with op.mutex:
                            result = shearrecadjoint2D(x, op.shearlet_system)
                            return np.moveaxis(result, -1, 0)

                    @property
                    def adjoint(self):
                        """The adjoint operator."""
                        return op.adjoint.inverse

                    @property
                    def inverse(self):
                        """The inverse operator."""
                        return op

                return ShearlabOperatorAdjointInverse()

        return ShearlabOperatorAdjoint()

    @property
    def inverse(self):
        """The inverse operator."""
        op = self

        class ShearlabOperatorInverse(odl.Operator):

            """Inverse of the shearlet transform.

            See Also
            --------
            odl.contrib.shearlab.ShearlabOperator
            """

            def __init__(self):
                """Initialize a new instance."""
                self.mutex = op.mutex
                self.shearlet_system = op.shearlet_system
                super(ShearlabOperatorInverse, self).__init__(
                    op.range, op.domain, True)

            def _call(self, x):
                """``self(x)``."""
                with op.mutex:
                    x = np.moveaxis(x, 0, -1)
                    return shearrec2D(x, op.shearlet_system)

            @property
            def adjoint(self):
                """The inverse operator."""
                op = self

                class ShearlabOperatorInverseAdjoint(odl.Operator):

                    """
                    Adjoint of the inverse/Inverse of the adjoint
                    of shearlet transform.

                    See Also
                    --------
                    odl.contrib.shearlab.ShearlabOperator
                    """

                    def __init__(self):
                        """Initialize a new instance."""

                        self.mutex = op.mutex
                        self.shearlet_system = op.shearlet_system
                        super(ShearlabOperatorInverseAdjoint, self).__init__(
                            op.range, op.domain, True)

                    def _call(self, x):
                        """``self(x)``."""
                        with op.mutex:
                            result = shearrecadjoint2D(x, op.shearlet_system)
                            return np.moveaxis(result, -1, 0)

                    @property
                    def adjoint(self):
                        """The adjoint operator."""
                        return op

                    @property
                    def inverse(self):
                        """The inverse operator."""
                        return op.inverse.adjoint

                return ShearlabOperatorInverseAdjoint()

            @property
            def inverse(self):
                """The inverse operator."""
                return op

        return ShearlabOperatorInverse()


# Python library for shearlab.jl

def load_julia_with_Shearlab():
    """Function to load Shearlab."""
    # Importing base
    j = julia.Julia()
    j.eval('using Shearlab')
    j.eval('using PyPlot')
    j.eval('using Images')
    return j


j = load_julia_with_Shearlab()


def load_image(name, n, m=None, gpu=None, square=None):
    """Function to load images with certain size."""
    if m is None:
        m = n
    if gpu is None:
        gpu = 0
    if square is None:
        square = 0
    command = ('Shearlab.load_image("{}", {}, {}, {}, {})'.format(name,
               n, m, gpu, square))
    return j.eval(command)


def imageplot(f, str=None, sbpt=None):
    """Plot an image generated by the library."""
    # Function to plot images
    if str is None:
        str = ''
    if sbpt is None:
        sbpt = []
    if sbpt != []:
        plt.subplot(sbpt[0], sbpt[1], sbpt[2])
    imgplot = plt.imshow(f, interpolation='nearest')
    imgplot.set_cmap('gray')
    plt.axis('off')
    if str != '':
        plt.title(str)


class Shearletsystem2D:
    """Class of shearlet system in 2D."""
    def __init__(self, shearlets, size, shearLevels, full, nShearlets,
                 shearletIdxs, dualFrameWeights, RMS, isComplex):
        self.shearlets = shearlets
        self.size = size
        self.shearLevels = shearLevels
        self.full = full
        self.nShearlets = nShearlets
        self.shearletIdxs = shearletIdxs
        self.dualFrameWeights = dualFrameWeights
        self.RMS = RMS
        self.isComplex = isComplex


def getshearletsystem2D(rows, cols, nScales, shearLevels=None,
                        full=None,
                        directionalFilter=None,
                        quadratureMirrorFilter=None):
    """Function to generate de 2D system."""
    if shearLevels is None:
        shearLevels = [float(ceil(i / 2)) for i in range(1, nScales + 1)]
    if full is None:
        full = 0
    if directionalFilter is None:
        directionalFilter = 'Shearlab.filt_gen("directional_shearlet")'
    if quadratureMirrorFilter is None:
        quadratureMirrorFilter = 'Shearlab.filt_gen("scaling_shearlet")'
    j.eval('rows=' + str(rows))
    j.eval('cols=' + str(cols))
    j.eval('nScales=' + str(nScales))
    j.eval('shearLevels=' + str(shearLevels))
    j.eval('full=' + str(full))
    j.eval('directionalFilter=' + directionalFilter)
    j.eval('quadratureMirrorFilter=' + quadratureMirrorFilter)
    j.eval('shearletsystem=Shearlab.getshearletsystem2D(rows, '
           'cols, nScales, shearLevels, full, directionalFilter, '
           'quadratureMirrorFilter) ')
    shearlets = j.eval('shearletsystem.shearlets')
    size = j.eval('shearletsystem.size')
    shearLevels = j.eval('shearletsystem.shearLevels')
    full = j.eval('shearletsystem.full')
    nShearlets = j.eval('shearletsystem.nShearlets')
    shearletIdxs = j.eval('shearletsystem.shearletIdxs')
    dualFrameWeights = j.eval('shearletsystem.dualFrameWeights')
    RMS = j.eval('shearletsystem.RMS')
    isComplex = j.eval('shearletsystem.isComplex')
    j.eval('shearletsystem = 0')
    return Shearletsystem2D(shearlets, size, shearLevels, full, nShearlets,
                            shearletIdxs, dualFrameWeights, RMS, isComplex)


def sheardec2D(X, shearletsystem):
    """Shearlet Decomposition function."""
    coeffs = np.zeros(shearletsystem.shearlets.shape, dtype=complex)
    Xfreq = fftshift(fft2(ifftshift(X)))
    for i in range(shearletsystem.nShearlets):
        coeffs[:, :, i] = fftshift(ifft2(ifftshift(Xfreq * np.conj(
                                   shearletsystem.shearlets[:, :, i]))))
    return coeffs.real


def shearrec2D(coeffs, shearletsystem):
    """Shearlet Recovery function."""
    X = np.zeros(coeffs.shape[:2], dtype=complex)
    for i in range(shearletsystem.nShearlets):
        X = X + fftshift(fft2(
            ifftshift(coeffs[:, :, i]))) * shearletsystem.shearlets[:, :, i]
    return (fftshift(ifft2(ifftshift((
            X / shearletsystem.dualFrameWeights))))).real


def sheardecadjoint2D(coeffs, shearletsystem):
    """Shearlet Decomposition adjoint function."""
    X = np.zeros(coeffs.shape[:2], dtype=complex)
    for i in range(shearletsystem.nShearlets):
        X = X + fftshift(fft2(
            ifftshift(coeffs[:, :, i]))) * np.conj(
            shearletsystem.shearlets[:, :, i])
    return (fftshift(ifft2(ifftshift(
            X / shearletsystem.dualFrameWeights)))).real


def shearrecadjoint2D(X, shearletsystem):
    """Shearlet Recovery adjoint function."""
    coeffs = np.zeros(shearletsystem.shearlets.shape, dtype=complex)
    Xfreq = fftshift(fft2(ifftshift(X)))
    for i in range(shearletsystem.nShearlets):
        coeffs[:, :, i] = fftshift(ifft2(ifftshift(
            Xfreq * shearletsystem.shearlets[:, :, i])))
    return coeffs.real


if __name__ == '__main__':
    from odl.util.testutils import run_doctests
    run_doctests()