import numpy as np
from pyqmc import pbc


def sherman_morrison_row(e, inv, vec):
    """Rank-one update of a batch of inverse matrices after replacing row e.

    Inputs:
      e: index of the row being replaced
      inv: (nconf, n, n) batch of inverses of the current matrices M
      vec: (nconf, n) new contents of row e of each M

    Returns:
      ratio: (nconf,) array, det(M_new) / det(M) = vec . inv[:, :, e]
      invnew: (nconf, n, n) inverse of M_new (a new array; inv is not modified)
    """
    col_e = inv[:, :, e]
    ratio = np.einsum("ij,ij->i", vec, col_e)
    row_times_inv = np.einsum("ek,ekj->ej", vec, inv)
    correction = np.einsum("ki,kj->kij", col_e, row_times_inv)
    invnew = inv - correction / ratio[:, np.newaxis, np.newaxis]
    # Column e of the updated inverse is simply the old column scaled by 1/ratio.
    invnew[:, :, e] = col_e / ratio[:, np.newaxis]
    return ratio, invnew


# Leading derivative components kept by _aostack_mol before the Laplacian is
# appended: the value only ("laplacian"), or value plus three first
# derivatives ("gradient_laplacian").
_gldict = {
    "laplacian": slice(None, 1),
    "gradient_laplacian": slice(0, 4),
}


def _aostack_mol(ao, gl):
    """Stack selected AO derivative components with the AO Laplacian.

    ao: AO values from a second-derivative evaluation, indexed by derivative
        component along axis 0 (components 4, 7, 9 are summed as the
        diagonal second derivatives, i.e. xx, yy, zz in pyscf's ordering).
    gl: key of _gldict choosing which leading components to keep.

    Returns the kept components followed by one extra component holding the
    Laplacian, concatenated along axis 0.
    """
    kept = ao[_gldict[gl]]
    diagonal_second_derivs = ao[[4, 7, 9]]
    laplacian = diagonal_second_derivs.sum(axis=0, keepdims=True)
    return np.concatenate([kept, laplacian], axis=0)


def _aostack_pbc(ao, gl):
    """Periodic version of _aostack_mol: apply it to each k-point's AO array.

    ao: sequence with one AO array per k-point.
    Returns a list with the stacked result for each k-point, in order.
    """
    stacked = []
    for ao_k in ao:
        stacked.append(_aostack_mol(ao_k, gl))
    return stacked


def get_kinds(cell, mf, kpts, tol=1e-6):
    """Given a list of kpts, return inds such that mf.kpts[inds] is a list of kpts equivalent to the input list

    Two k-points are equivalent when they differ by a reciprocal lattice vector
    of ``cell``: their difference, expressed in fractional coordinates and
    wrapped into [-0.5, 0.5), must have norm below ``tol``.

    Indices are returned ordered by input k-point. An input k-point with no
    match (or several matches) contributes no (or several) indices, so callers
    should check the length of the result against len(kpts).
    """
    lattice = cell.lattice_vectors()
    # Every (input, mf) pair: shape (len(kpts), len(mf.kpts), 3).
    pair_diffs = mf.kpts[None] - kpts[:, None]
    fractional = np.dot(pair_diffs, lattice.T) / (2 * np.pi)
    wrapped = np.mod(fractional + 0.5, 1) - 0.5
    is_equivalent = np.linalg.norm(wrapped, axis=-1) < tol
    _, mf_indices = np.nonzero(is_equivalent)
    return mf_indices


class PySCFSlater:
    """Slater determinant (one per spin) built from pyscf mean-field orbitals.

    A wave function object has a state defined by a reference configuration of
    electrons. The functions recompute() and updateinternals() change the state
    of the object, and the rest compute and return values from that state.

    Molecules and periodic systems are both supported: an input object with an
    ``a`` attribute is treated as a periodic supercell.
    """

    def __init__(self, mol, mf, twist=None):
        """
        Inputs:
          mol: pyscf molecule, or for periodic systems the supercell object
            returned by get_supercell(cell, S)
          mf: scf object. For periodic systems this is the primitive-cell
            calculation, which must include k points that fold onto the gamma
            point of the supercell
          twist: (3,) array, twisted boundary condition in fractional
            coordinates, i.e. as coefficients of the reciprocal lattice vectors
            of the supercell. Integer values are equivalent to zero. Only used
            for periodic systems.
        """
        self.parameters = {}
        # Tolerance (in machine epsilons) passed to np.real_if_close when
        # deciding whether k-point MO coefficients can be stored as real.
        self.real_tol = 1e4
        self._coefflookup = ("mo_coeff_alpha", "mo_coeff_beta")

        is_pbc = hasattr(mol, "a")
        if is_pbc:
            self._init_pbc(mol, mf, twist)
        else:
            self._init_mol(mol, mf)
        self.pbc_str = "PBC" if is_pbc else ""
        self._aostack = _aostack_pbc if is_pbc else _aostack_mol

        self.dtype = complex if self.iscomplex else float
        if self.iscomplex:
            self.get_phase = lambda x: x / np.abs(x)
            self.get_wrapphase = lambda x: np.exp(1j * x)
        else:
            # Real wave function: phases are signs, wrap phases are +/-1.
            self.get_phase = np.sign
            self.get_wrapphase = lambda x: (-1) ** np.round(x / np.pi)

    def _init_mol(self, mol, mf):
        """Store the occupied MO coefficients of a molecular calculation."""
        for s, lookup in enumerate(self._coefflookup):
            if len(mf.mo_occ.shape) == 2:
                # Unrestricted: mo_occ / mo_coeff carry a leading spin index.
                occupied = np.asarray(mf.mo_occ[s] > 0.9)
                self.parameters[lookup] = mf.mo_coeff[s][:, occupied]
            else:
                # Restricted(-open-shell): alpha keeps orbitals with occupation
                # > 0.9, beta only those > 1.1 (doubly occupied), assuming
                # integer occupations.
                minocc = (0.9, 1.1)[s]
                occupied = np.asarray(mf.mo_occ > minocc)
                self.parameters[lookup] = mf.mo_coeff[:, occupied]
        self._nelec = tuple(mol.nelec)
        self._mol = mol
        self.iscomplex = any(np.iscomplexobj(p) for p in self.parameters.values())
        self.evaluate_orbitals = self._evaluate_orbitals_mol
        self.evaluate_mos = self._evaluate_mos_mol

    def _init_pbc(self, cell, mf, twist):
        """Store occupied k-point MO coefficients for a periodic supercell.

        Raises ValueError if the supercell k-points (plus twist) do not each
        match exactly one k-point of ``mf``.
        """
        from pyscf.pbc import scf
        from pyqmc.supercell import get_supercell_kpts

        # Make sure supercell has attributes original_cell, S and scale
        for attribute in ["original_cell", "S", "scale"]:
            if not hasattr(cell, attribute):
                print('Warning: supercell is missing attribute "%s"' % attribute)
                print("setting original_cell=supercell and S=np.eye(3)")
                cell.original_cell = cell
                cell.S = np.eye(3)
                cell.scale = 1
        self.supercell = cell
        self._cell = cell.original_cell

        # Define kpts
        if twist is None:
            twist = np.zeros(3)
        else:
            # Fractional (supercell reciprocal lattice) -> Cartesian. Use
            # lattice_vectors() (Bohr) so the twist has the same units as kpts.
            lattice = cell.lattice_vectors()
            twist = np.dot(np.linalg.inv(lattice), np.mod(twist, 1.0)) * 2 * np.pi
        supercell_kpts = get_supercell_kpts(cell) + twist
        self.kinds = get_kinds(self._cell, mf, supercell_kpts)
        self._kpts = mf.kpts[self.kinds]
        if len(self.kinds) != len(supercell_kpts):
            raise ValueError(
                "Could not match each supercell k-point to exactly one scf "
                "k-point (found %d matches for %d k-points).\n"
                "supercell kpts: %s\nscf kpts: %s"
                % (len(self.kinds), len(supercell_kpts), supercell_kpts, mf.kpts)
            )
        self.nk = len(self._kpts)

        # Define parameters
        # mo_coeff indices are (spin, kpt, basis, mo) for unrestricted and
        # (kpt, basis, mo) for restricted calculations.
        unrestricted = len(mf.mo_coeff[0][0].shape) == 2
        self.param_split = {}
        for s, lookup in enumerate(self._coefflookup):
            mclist = []
            for kind in self.kinds:
                if unrestricted:
                    mca = mf.mo_coeff[s][kind][:, np.asarray(mf.mo_occ[s][kind] > 0.9)]
                else:
                    minocc = (0.9, 1.1)[s]
                    mca = mf.mo_coeff[kind][:, np.asarray(mf.mo_occ[kind] > minocc)]
                mca = np.real_if_close(mca, tol=self.real_tol)
                # presumably normalizes the Bloch sum over the nk k-points
                mclist.append(mca / np.sqrt(self.nk))
            # Cumulative column counts: boundaries of each k-point's MOs.
            self.param_split[lookup] = np.cumsum([m.shape[1] for m in mclist])
            self.parameters[lookup] = np.concatenate(mclist, axis=-1)

        # Must be evaluated after self.parameters is filled in; coefficients
        # that real_if_close could not make real also force a complex wf.
        self.iscomplex = bool(
            any(np.iscomplexobj(p) for p in self.parameters.values())
            or np.linalg.norm(self._kpts) > 1e-12
        )

        # Define nelec
        if isinstance(mf, scf.kuhf.KUHF):
            # Then indices are (spin, kpt, basis, mo)
            self._nelec = [
                int(sum(np.sum(o[k]) for k in self.kinds)) for o in mf.mo_occ
            ]
        elif isinstance(mf, scf.khf.KRHF):
            # Then indices are (kpt, basis, mo); count per k-point since the
            # number of MOs may differ between k-points.
            self._nelec = [
                int(sum(np.count_nonzero(mf.mo_occ[k] > t) for k in self.kinds))
                for t in (0.9, 1.1)
            ]
        else:
            print("Warning: PySCFSlater not expecting scf object of type", type(mf))
            scale = self.supercell.scale
            self._nelec = [int(np.round(n * scale)) for n in self._cell.nelec]
        self._nelec = tuple(self._nelec)

        self.evaluate_orbitals = self._evaluate_orbitals_pbc
        self.evaluate_mos = self._evaluate_mos_pbc

    def _evaluate_orbitals_mol(self, configs, mask=None, eval_str="GTOval_sph"):
        """Evaluate AOs at all (or the masked) configuration positions.

        Coordinates are flattened to (-1, ndim) before calling mol.eval_gto.
        """
        mycoords = configs.configs if mask is None else configs.configs[mask]
        mycoords = mycoords.reshape((-1, mycoords.shape[-1]))
        return self._mol.eval_gto(eval_str, mycoords)

    def _evaluate_mos_mol(self, ao, s):
        """Contract AO values with the occupied MO coefficients of spin s."""
        return ao.dot(self.parameters[self._coefflookup[s]])

    def _evaluate_orbitals_pbc(self, configs, mask=None, eval_str="GTOval_sph"):
        """Evaluate Bloch AOs at each k-point for all (or the masked) positions.

        Positions are wrapped into the primitive cell and the AOs multiplied by
        the phase exp(i k.R) of the lattice translation R removed by wrapping
        (both the primitive wrap and the stored supercell wrap).
        Returns a list with one AO array per k-point.
        """
        mycoords = configs.configs
        configswrap = configs.wrap
        if mask is not None:
            mycoords = mycoords[mask]
            configswrap = configswrap[mask]
        mycoords = mycoords.reshape((-1, mycoords.shape[-1]))
        # wrap supercell positions into primitive cell
        prim_coords, prim_wrap = pbc.enforce_pbc(self._cell.lattice_vectors(), mycoords)
        configswrap = configswrap.reshape(prim_wrap.shape)
        # Total translation in units of primitive lattice vectors.
        wrap = prim_wrap + np.dot(configswrap, self.supercell.S)
        kdotR = np.linalg.multi_dot(
            (self._kpts, self._cell.lattice_vectors().T, wrap.T)
        )
        wrap_phase = self.get_wrapphase(kdotR)
        # evaluate AOs for all electron positions
        ao = self._cell.eval_gto("PBC" + eval_str, prim_coords, kpts=self._kpts)
        ao = [ao[k] * wrap_phase[k][:, np.newaxis] for k in range(self.nk)]
        return ao

    def _evaluate_mos_pbc(self, aos, s):
        """
        Evaluate MOs for spin s given aos (one AO array per k-point); each
        k-point's AOs are contracted with that k-point's block of coefficients
        and the results concatenated along the MO axis.
        """
        c = self._coefflookup[s]
        p = np.split(self.parameters[c], self.param_split[c], axis=-1)
        mo = [ao.dot(p[k]) for k, ao in enumerate(aos)]
        return np.concatenate(mo, axis=-1)

    def recompute(self, configs):
        """This computes the value from scratch. Returns the logarithm of the wave function as
        (phase,logdet). If the wf is real, phase will be +/- 1.

        Stores the AO values, the per-spin (phase, log|det|) and the inverse
        MO matrices used by the update and test functions.
        """
        nconf, nelec, ndim = configs.configs.shape
        aos = self.evaluate_orbitals(configs)
        nk = getattr(self, "nk", 1)  # molecules have a single "k-point"
        aos = np.reshape(aos, (nk, nconf, nelec, -1))
        self._aovals = aos
        self._dets = []
        self._inverse = []
        for s in [0, 1]:
            # Electrons 0..nup-1 are spin up, the rest spin down.
            i0, i1 = s * self._nelec[0], self._nelec[0] + s * self._nelec[1]
            ne = self._nelec[s]
            mo = self.evaluate_mos(aos[:, :, i0:i1], s).reshape(nconf, ne, ne)
            phase, mag = np.linalg.slogdet(mo)
            self._dets.append((phase, mag))
            self._inverse.append(np.linalg.inv(mo))

        return self.value()

    def updateinternals(self, e, epos, mask=None):
        """Update the stored state after electron e moved to epos.

        mask: Boolean array over configurations; only configurations where it
        is True (e.g. accepted moves) are updated. The other configurations
        keep their AO values, inverses and determinants unchanged.
        """
        s = int(e >= self._nelec[0])
        if mask is None:
            mask = np.ones(epos.configs.shape[0], dtype=bool)
        mask = np.asarray(mask)
        nmask = np.count_nonzero(mask)
        if nmask == 0:
            return
        eeff = e - s * self._nelec[0]
        aos = self.evaluate_orbitals(epos, mask=mask)
        # Only overwrite stored AOs of updated configurations; the others must
        # keep the AOs at the electron's old position (used by pgradient).
        self._aovals[:, mask, e, :] = np.asarray(aos)  # (kpt, config, ao)
        mo = self.evaluate_mos(aos, s).reshape(nmask, -1)
        ratio, newinverse = sherman_morrison_row(
            eeff, self._inverse[s][mask, :, :], mo
        )
        self._inverse[s][mask, :, :] = newinverse
        self._updateval(ratio, s, mask)

    def _updateval(self, ratio, s, mask):
        """Fold determinant ratios into the stored (phase, log|det|) of spin s."""
        self._dets[s][0][mask] *= self.get_phase(ratio)
        self._dets[s][1][mask] += np.log(np.abs(ratio))

    ### not state-changing functions

    def value(self):
        """Return logarithm of the wave function as noted in recompute()"""
        return self._dets[0][0] * self._dets[1][0], self._dets[0][1] + self._dets[1][1]

    def _testrow(self, e, vec, mask=None, spin=None):
        """vec is a nconfig,nmo vector which replaces row e

        Returns the determinant ratio for the replacement. ``e`` may be an
        array of electrons of a single spin when ``spin`` is given.
        """
        s = int(e >= self._nelec[0]) if spin is None else spin
        elec = e - s * self._nelec[0]
        if mask is None:
            return np.einsum("i...j,ij...->i...", vec, self._inverse[s][:, :, elec])

        return np.einsum("i...j,ij...->i...", vec, self._inverse[s][mask][:, :, elec])

    def _testcol(self, i, s, vec):
        """vec is a nconfig,nmo vector which replaces column i

        Returns the determinant ratio of spin s for the replacement.
        """
        return np.einsum("ij...,ij->i...", vec, self._inverse[s][:, i, :])

    def testvalue(self, e, epos, mask=None):
        """ return the ratio between the current wave function and the wave function if 
        electron e's position is replaced by epos

        The result has shape (nmask,) + epos.configs.shape[1:-1], where nmask
        is the number of selected configurations.
        """
        s = int(e >= self._nelec[0])
        nmask = epos.configs.shape[0] if mask is None else np.sum(mask)
        if nmask == 0:
            # Same trailing shape as a non-empty result.
            return np.zeros((0,) + epos.configs.shape[1:-1], dtype=self.dtype)
        aos = self.evaluate_orbitals(epos, mask)
        mo = self.evaluate_mos(aos, s)
        mo = mo.reshape(nmask, *epos.configs.shape[1:-1], self._nelec[s])
        return self._testrow(e, mo, mask)

    def testvalue_many(self, e, epos, mask=None):
        """ return the ratio between the current wave function and the wave function if 
        an electron's position is replaced by epos for each electron

        e: array of electron indices. Returns an array of shape
        (nmask, len(e)), one row per selected configuration.
        """
        e = np.asarray(e)
        s = (e >= self._nelec[0]).astype(int)
        nmask = epos.configs.shape[0] if mask is None else np.sum(mask)
        if nmask == 0:
            return np.zeros((0, e.shape[0]), dtype=self.dtype)

        aos = self.evaluate_orbitals(epos, mask)
        # Sized by the selected configurations, matching _testrow's output.
        ratios = np.zeros((nmask, e.shape[0]), dtype=self.dtype)
        for spin in [0, 1]:
            ind = s == spin
            mo = self.evaluate_mos(aos, spin)
            mo = mo.reshape(nmask, *epos.configs.shape[1:-1], self._nelec[spin])
            ratios[:, ind] = self._testrow(e[ind], mo, mask=mask, spin=spin)
        return ratios

    def gradient(self, e, epos):
        """ Compute the gradient of the log wave function 
        Note that this can be called even if the internals have not been updated for electron e,
        if epos differs from the current position of electron e.

        Returns the derivative rows of the first-derivative AO evaluation
        divided by the value row.
        """
        s = int(e >= self._nelec[0])
        aograd = self.evaluate_orbitals(epos, eval_str="GTOval_sph_deriv1")
        mograd = self.evaluate_mos(aograd, s)
        ratios = np.asarray([self._testrow(e, x) for x in mograd])
        return ratios[1:] / ratios[:1]

    def laplacian(self, e, epos):
        """Return (Laplacian of psi) / psi for electron e placed at epos."""
        s = int(e >= self._nelec[0])
        ao = self.evaluate_orbitals(epos, eval_str="GTOval_sph_deriv2")
        mo = self.evaluate_mos(self._aostack(ao, "laplacian"), s)
        ratios = np.asarray([self._testrow(e, x) for x in mo])
        return ratios[1] / ratios[0]

    def gradient_laplacian(self, e, epos):
        """Return (gradient of log psi, Laplacian psi / psi) for electron e at epos."""
        s = int(e >= self._nelec[0])
        ao = self.evaluate_orbitals(epos, eval_str="GTOval_sph_deriv2")
        mo = self.evaluate_mos(self._aostack(ao, "gradient_laplacian"), s)
        ratios = np.asarray([self._testrow(e, x) for x in mo])
        return ratios[1:-1] / ratios[:1], ratios[-1] / ratios[0]

    def pgradient(self):
        """Derivatives of log psi with respect to the MO coefficients.

        Returns a dict mapping each key of self.parameters to an array of shape
        (nconf,) + parameters[key].shape, i.e. (nconf, nao, nmo). Entry
        [:, mu, i] is the determinant ratio for replacing MO column i by AO mu,
        computed from the stored AO values and inverses.
        """
        d = {}
        for parm in self.parameters:
            s = int("beta" in parm)
            # Get AOs for our spin channel only
            i0, i1 = s * self._nelec[0], self._nelec[0] + s * self._nelec[1]
            ao = self._aovals[:, :, i0:i1, :]  # (kpt, config, electron, ao)
            pgrad_shape = (ao.shape[-3],) + self.parameters[parm].shape
            pgrad = np.zeros(pgrad_shape, dtype=self.dtype)  # (nconf, nao, nmo)
            # Compute derivatives w.r.t. MO coefficients
            if ao.shape[0] > 1:  # multiple kpts
                # k-point index owning each MO column
                split_sizes = np.diff([0] + list(self.param_split[parm]))
                k = np.repeat(np.arange(self.nk), split_sizes)
                for i in range(self._nelec[s]):  # MO loop
                    pgrad[:, :, i] = self._testcol(i, s, ao[k[i]])
            else:
                ao = ao[0]
                for i in range(self._nelec[s]):  # MO loop
                    pgrad[:, :, i] = self._testcol(i, s, ao)
            d[parm] = np.asarray(pgrad)
        return d