import fenics
import fenics_adjoint
import numpy as np


def fenics_to_numpy(fenics_var):
    """Convert FEniCS variable to numpy array"""
    if isinstance(fenics_var, (fenics.Constant, fenics_adjoint.Constant)):
        return fenics_var.values()

    if isinstance(fenics_var, (fenics.Function, fenics_adjoint.Constant)):
        np_array = fenics_var.vector().get_local()
        n_sub = fenics_var.function_space().num_sub_spaces()
        # Reshape if function is multi-component
        if n_sub != 0:
            np_array = np.reshape(np_array, (len(np_array) // n_sub, n_sub))
        return np_array

    if isinstance(fenics_var, fenics.GenericVector):
        return fenics_var.get_local()

    if isinstance(fenics_var, fenics_adjoint.AdjFloat):
        return np.array(float(fenics_var), dtype=np.float_)

    raise ValueError('Cannot convert ' + str(type(fenics_var)))


def numpy_to_fenics(numpy_array, fenics_var_template):
    """Convert numpy array to FEniCS variable"""
    if isinstance(fenics_var_template, (fenics.Constant, fenics_adjoint.Constant)):
        if numpy_array.shape == (1,):
            return type(fenics_var_template)(numpy_array[0])
        else:
            return type(fenics_var_template)(numpy_array)

    if isinstance(fenics_var_template, (fenics.Function, fenics_adjoint.Function)):
        np_n_sub = numpy_array.shape[-1]
        np_size = np.prod(numpy_array.shape)

        function_space = fenics_var_template.function_space()

        u = type(fenics_var_template)(function_space)
        fenics_size = u.vector().local_size()
        fenics_n_sub = function_space.num_sub_spaces()

        if (fenics_n_sub != 0 and np_n_sub != fenics_n_sub) or np_size != fenics_size:
            err_msg = 'Cannot convert numpy array to Function:' \
                      ' Wrong shape {} vs {}'.format(numpy_array.shape, u.vector().get_local().shape)
            raise ValueError(err_msg)

        if numpy_array.dtype != np.float_:
            err_msg = 'The numpy array must be of type {}, ' \
                      'but got {}'.format(np.float_, numpy_array.dtype)
            raise ValueError(err_msg)

        u.vector().set_local(np.reshape(numpy_array, fenics_size))
        u.vector().apply('insert')
        return u

    if isinstance(fenics_var_template, fenics_adjoint.AdjFloat):
        return fenics_adjoint.AdjFloat(numpy_array)

    err_msg = 'Cannot convert numpy array to {}'.format(fenics_var_template)
    raise ValueError(err_msg)