# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 3. Use proper functions provided by PySCF
#   * Switch between df.incore and df.outcore according to available system memory
#   *   (Koh: Is there an equivalent function in outcore? Which one, incore or
#   *   outcore, should be used when more memory is needed?)
#   * Use get_veff of the scf object instead of get_vxc
#   *   (Koh: get_vxc cannot generate correct J and K matrices from a complex
#   *   density matrix)

import numpy as np
import scipy, time
import scipy.linalg
from pyscf import gto, df
from pyscf import lib
from pyscf.lib import logger
from pyscf.scf import diis
import sys

# Femtoseconds per atomic unit of time: multiplying the propagation clock
# (``tnow``, kept in atomic units) by this factor gives femtoseconds for logging.
FSPERAU = 0.0241888

def transmat(M,U,inv = 1):
    """Similarity-transform ``M`` with ``U``.

    Args:
        M: 2D array (real or complex) to transform.
        U: 2D transformation matrix.
        inv: 1 for ``U^H M U`` (default), -1 for ``U M U^H``.

    Returns:
        The transformed matrix.

    Raises:
        ValueError: if ``inv`` is neither 1 nor -1 (previously this fell
            through and raised an UnboundLocalError).
    """
    if inv == 1:
        # U^H * M * U
        return np.dot(np.dot(U.T.conj(), M), U)
    if inv == -1:
        # U * M * U^H
        return np.dot(np.dot(U, M), U.T.conj())
    raise ValueError("transmat: inv must be 1 or -1, got %r" % (inv,))

def trdot(A,B):
    """Return the trace of the matrix product ``A . B``."""
    product = np.dot(A, B)
    return product.trace()

def matrixpower(A,p,PrintCondition=False):
    """ Raise a Hermitian Matrix to a possibly fractional power.

    The matrix is decomposed with an SVD; singular values below 1e-14 are
    clamped to 1e-14 before being raised to ``p`` so that negative powers
    stay finite.

    Args:
        A: 2D array to raise to a power.
        p: exponent (may be fractional or negative).
        PrintCondition: if True, print the smallest singular value.

    Returns:
        ``u . diag(s**p) . v`` built from the SVD of ``A``.
    """
    left, sing, right = np.linalg.svd(A)
    if PrintCondition:
        print("matrixpower: Minimal Eigenvalue =", np.min(sing))
    floor = np.power(10.0, -14.0)
    # SVD singular values are non-negative, so a plain maximum is the clamp.
    sing = np.maximum(sing, floor)
    scaled = np.diag(np.power(sing, p))
    return np.dot(left, np.dot(scaled, right))

class RTTDSCF(lib.StreamObject):
    '''
    RT-TDSCF base object.
    Other types of propagations may inherit from this.
    Calling this class starts the propagation

    Attributes:
        verbose: int
            Print level.  Default value equals to :class:`ks.verbose`
        conv_tol: float
            converge threshold.  Default value equals to :class:`ks.conv_tol`
        auxbas: str
            auxiliary basis for 2c/3c eri. Default is weigend
        prm: str
            string object with |variable    value| on each line

    Saved results

        output: str
            name of the file with result of propagation
    '''

    def __init__(self,ks,prm=None,output = "log.dat", auxbas = "weigend"):
        self.stdout = sys.stdout
        self.verbose = ks.verbose
        self.enuc = ks.energy_nuc()
        self.conv_tol = ks.conv_tol
        self.auxbas = auxbas
        self.hyb = ks._numint.hybrid_coeff(ks.xc, spin=(ks.mol.spin>0)+1)
        self.adiis = None
        self.ks  = ks
        self.eri3c = None
        self.eri2c = None
        self.s = ks.mol.intor_symmetric('int1e_ovlp')
        # Symmetric (Lowdin) orthogonalizer X = S^(-1/2).
        self.x = matrixpower(self.s,-1./2.)
        self._keys = set(self.__dict__.keys())

        # Warn before the (potentially long) propagation rather than after it.
        logger.warn(self, 'RT-TDSCF is an experimental feature. It is '
                    'still in testing.\nFeatures and APIs may be changed '
                    'in the future.')

        fmat, c_am, v_lm, rho = self.initialcondition(prm)
        start = time.time()
        self.prop(fmat, c_am, v_lm, rho, output)
        end = time.time()
        logger.info(self,"Propagation time: %f", end-start)

    def auxmol_set(self, mol, auxbas = "weigend"):
        '''
        Build the auxiliary molecule and the density-fitting integrals
        (eri2c, eri3c), storing them on ``self``.

        Args:
            mol: Mole class
                Default is ks.mol

        Kwargs:
            auxbas: str
                auxiliary basis for 2c/3c eri. Default is weigend

        Returns:
            eri3c: float
                3 center eri. shape: (AO,AO,AUX)
            eri2c: float
                2 center eri. shape: (AUX,AUX)
        '''
        auxmol = gto.Mole()
        auxmol.atom = mol.atom
        # Same coordinate unit, charge and spin as ``mol`` so the auxiliary
        # molecule describes the same geometry and electron count.
        auxmol.unit = mol.unit
        auxmol.charge = mol.charge
        auxmol.spin = mol.spin
        auxmol.basis = auxbas
        # The Mole must be built before its _atm/_bas/_env tables are usable.
        auxmol.build(dump_input=False)
        self.auxmol = auxmol
        # Suffix-free intor name: PySCF appends _sph/_cart according to mol.cart.
        eri3c = df.incore.aux_e2(mol, auxmol, intor="int3c2e", aosym="s1",
                                 comp=1)
        eri2c = df.incore.fill_2c2e(mol,auxmol)
        self.eri3c = eri3c.copy()
        self.eri2c = eri2c.copy()
        return eri3c, eri2c

    def fockbuild(self,dm_lao,it = -1):
        '''
        Updates Fock matrix

        Args:
            dm_lao: float or complex
                Lowdin AO density matrix.
            it: int
                iterator for SCF DIIS; DIIS extrapolation is applied only
                when ``self.adiis`` is set and ``it > 0``.

        Returns:
            fmat: float or complex
                Fock matrix in Lowdin AO basis
            jmat: float or complex
                Coulomb matrix in AO basis
            kmat: float or complex
                Exact Exchange in AO basis (None for non-hybrid TDDFT)

        Raises:
            ValueError: if params["Model"] is neither "TDHF" nor "TDDFT".
        '''
        model = self.params["Model"]
        if model == "TDHF":
            # Total AO density (2 * half-occupied density).
            Pt = 2.0*transmat(dm_lao,self.x,-1)
            jmat,kmat = self.get_jk(Pt)
            # Symmetrize J and K so the Fock matrix is Hermitian.
            veff = 0.5*(jmat+jmat.T.conj()) - 0.5*(0.5*(kmat + kmat.T.conj()))
        elif model == "TDDFT":
            Pt = 2 * transmat(dm_lao,self.x,-1)
            jmat = self.J = self.get_j(Pt)
            veff = self.J.astype(complex)
            vxc, excsum, kmat = self.get_vxc(Pt)
            veff += vxc
        else:
            raise ValueError("Unknown Model: %s" % model)

        fock_ao = self.h + veff
        if self.adiis and it > 0:
            fock_ao = self.adiis.update(self.s, Pt, fock_ao)
        return transmat(fock_ao,self.x), jmat, kmat

    def get_vxc(self,dm):
        '''
        Update exchange matrices and energy

        Args:
            dm: float or complex
                AO density matrix.

        Returns:
            vxc: float or complex
                exchange-correlation matrix in AO basis (includes the
                -0.5*hyb*K exact-exchange part for hybrids)
            excsum: float
                exchange-correlation energy (also stored in ``self.exc``)
            kmat: float or complex or None
                Exact Exchange in AO basis; None when hyb <= 0.01
        '''
        nelec, excsum, vxc = self.ks._numint.nr_vxc(self.ks.mol,
                                                   self.ks.grids, self.ks.xc, dm)
        self.exc = excsum
        vxc  = vxc.astype(complex)
        if self.hyb > 0.01:
            kmat = self.get_k(dm)
            vxc += -0.5 * self.hyb * kmat
        else:
            kmat = None
        return vxc, excsum, kmat

    def get_jk(self, dm):
        '''
        Update Coulomb and Exact Exchange Matrix

        Args:
            dm: float or complex
                AO density matrix.

        Returns:
            jmat: float or complex
                Coulomb matrix in AO basis
            kmat: float or complex
                Exact Exchange in AO basis
        '''
        jmat = self.get_j(dm)
        kmat = self.get_k(dm)
        return jmat, kmat

    def get_j(self,dm):
        '''
        Update Coulomb Matrix using the density-fitting integrals.

        Args:
            dm: float or complex
                AO density matrix.

        Returns:
            jmat: float or complex
                Coulomb matrix in AO basis
        '''
        rho = np.einsum("ijp,ij->p", self.eri3c, dm)
        rho = np.linalg.solve(self.eri2c, rho)
        jmat = np.einsum("p,ijp->ij", rho, self.eri3c)
        return jmat

    def get_k(self,dm):
        '''
        Update Exact Exchange Matrix using the density-fitting integrals.

        Args:
            dm: float or complex
                AO density matrix.

        Returns:
            kmat: float or complex
                Exact Exchange in AO basis
        '''
        naux = self.auxmol.nao_nr()
        nao = self.ks.mol.nao_nr()
        kpj = np.einsum("ijp,jk->ikp", self.eri3c, dm)
        pik = np.linalg.solve(self.eri2c, kpj.reshape(-1,naux).T.conj())
        kmat = np.einsum("pik,kjp->ij", pik.reshape(naux,nao,nao), self.eri3c)
        return kmat

    def initialcondition(self,prm):
        '''
        Prepare the variables/Matrices needed for propagation
        The SCF is done here to make matrices that are not accessable from pyscf.scf

        Args:
            prm: str
                string object with |variable    value| on each line

        Returns:
            fmat: float or complex
                Fock matrix in Lowdin AO basis
            c_am: float
                Transformation Matrix |AO><MO|
            v_lm: float
                Transformation Matrix |LAO><MO|
            rho: float or complex
                Initial MO density matrix.
        '''
        from pyscf.rt import tdfields
        self.auxmol_set(self.ks.mol, auxbas = self.auxbas)
        self.params = dict()
        logger.log(self,"""
            |  Realtime TDSCF module          |
            | J. Parkhill, T. Nguyen          |
            | J. Koh, J. Herr,  K. Yao        |
            | Refs: 10.1021/acs.jctc.5b00262  |
            |       10.1063/1.4916822         |
        """)
        # params must be populated before fockbuild/energy read params["Model"].
        self.readparams(prm)
        n_ao = self.ks.mol.nao_nr()
        n_occ = int(sum(self.ks.mo_occ)/2)
        logger.log(self,"n_ao: %d        n_occ: %d", n_ao, n_occ)
        fmat, c_am, v_lm = self.initfockbuild() # updates self.C
        rho = 0.5*np.diag(self.ks.mo_occ).astype(complex)
        self.field = tdfields.FIELDS(self, self.params)
        self.field.initializeexpectation(rho, c_am)
        return fmat, c_am, v_lm, rho

    def readparams(self,prm):
        '''
        Set Defaults, Read the file and fill the params dictionary

        Args:
            prm: str
                string object with |variable    value| on each line.
                MaxIter/ApplyImpulse/ApplyCw/StatusEvery are parsed as int,
                Model/Method are upper-cased strings, everything else float.
        '''
        self.params["Model"] = "TDDFT"
        self.params["Method"] = "MMUT"

        self.params["dt"] =  0.02
        self.params["MaxIter"] = 15000

        self.params["ExDir"] = 1.0
        self.params["EyDir"] = 1.0
        self.params["EzDir"] = 1.0
        self.params["FieldAmplitude"] = 0.01
        self.params["FieldFreq"] = 0.9202
        self.params["Tau"] = 0.07
        self.params["tOn"] = 7.0*self.params["Tau"]
        self.params["ApplyImpulse"] = 1
        self.params["ApplyCw"] = 0

        self.params["StatusEvery"] = 5000

        int_keys = ("MaxIter", "ApplyImpulse", "ApplyCw", "StatusEvery")
        str_keys = ("Model", "Method")
        # Here they should be read from disk.
        if prm is not None:
            for line in prm.splitlines():
                s = line.split()
                if len(s) > 1:
                    if s[0] in int_keys:
                        self.params[s[0]] = int(s[1])
                    elif s[0] in str_keys:
                        self.params[s[0]] = s[1].upper()
                    else:
                        self.params[s[0]] = float(s[1])

        logger.log(self,"         Parameters")
        logger.log(self,"Model: " + self.params["Model"].upper())
        logger.log(self,"Method: "+ self.params["Method"].upper())
        logger.log(self,"dt: %.2f", self.params["dt"])
        logger.log(self,"MaxIter: %d", self.params["MaxIter"])
        logger.log(self,"ExDir: %.2f", self.params["ExDir"])
        logger.log(self,"EyDir: %.2f", self.params["EyDir"])
        logger.log(self,"EzDir: %.2f", self.params["EzDir"])
        logger.log(self,"FieldAmplitude: %.4f", self.params["FieldAmplitude"])
        logger.log(self,"FieldFreq: %.4f", self.params["FieldFreq"])
        logger.log(self,"Tau: %.2f", self.params["Tau"])
        logger.log(self,"tOn: %.2f", self.params["tOn"])
        logger.log(self,"ApplyImpulse: %d", self.params["ApplyImpulse"])
        logger.log(self,"ApplyCw: %d", self.params["ApplyCw"])
        logger.log(self,"StatusEvery: %d", self.params["StatusEvery"])

    def initfockbuild(self):
        '''
        Using Roothan's equation to build a Initial Fock matrix and
        Transformation Matrices

        Returns:
            fmat: float or complex
                Fock matrix in Lowdin AO basis
            c_am: float
                Transformation Matrix |AO><MO|
            v_lm: float
                Transformation Matrix |LAO><MO|

        Raises:
            RuntimeError: if the SCF does not converge within ks.max_cycle.
        '''
        start = time.time()
        err = 100
        it = 0
        self.h = self.ks.get_hcore()
        # S X = S^(1/2): maps the AO guess density into the Lowdin basis.
        sx = np.dot(self.s, self.x)
        dm_lao = 0.5*transmat(self.ks.get_init_guess(self.ks.mol,
                                                     self.ks.init_guess), sx).astype(complex)

        if isinstance(self.ks.diis, lib.diis.DIIS):
            self.adiis = self.ks.diis
        elif self.ks.diis:
            self.adiis = diis.SCF_DIIS(self.ks, self.ks.diis_file)
            self.adiis.space = self.ks.diis_space
            self.adiis.rollback = self.ks.diis_space_rollback
        else:
            self.adiis = None

        fmat, jmat, kmat = self.fockbuild(dm_lao)
        # energy() already includes the nuclear repulsion energy.
        etot = self.energy(dm_lao,fmat, jmat, kmat)

        # MO-basis density (half occupations); constant across iterations.
        rho = 0.5*np.diag(self.ks.mo_occ).astype(complex)
        while err > self.conv_tol:
            # Diagonalize F in the lowdin basis
            eigs, v_lm = np.linalg.eig(fmat)
            idx = eigs.argsort()
            v_lm = v_lm[:,idx].copy()
            # Fill up the density in the MO basis and then Transform back
            dm_lao = transmat(rho,v_lm,-1)
            etot_old = etot
            etot = self.energy(dm_lao,fmat, jmat, kmat)
            fmat, jmat, kmat = self.fockbuild(dm_lao,it)
            err = abs(etot-etot_old)
            logger.debug(self, "Ne: %f", np.trace(rho))
            logger.debug(self, "Iteration: %d         Energy: %.11f      "
                         "            Error = %.11f", it, etot, err)
            it += 1
            if it > self.ks.max_cycle:
                logger.log(self, "Max cycle of SCF reached: %d\n Exiting TDSCF. Please raise ks.max_cycle", it)
                # Stop instead of looping forever on a non-converging SCF.
                raise RuntimeError("RT-TDSCF initial SCF did not converge "
                                   "within ks.max_cycle = %d iterations"
                                   % self.ks.max_cycle)
        c_am = np.dot(self.x,v_lm)
        logger.log(self, "Ne: %f", np.trace(rho))
        logger.log(self, "Converged Energy: %f", etot)
        end = time.time()
        logger.info(self, "Initial Fock Built time: %f", end-start)
        return fmat, c_am, v_lm

    def split_rk4_step_mmut(self, w, v , oldrho , tnow, dt ,IsOn):
        '''
        Propagate an MO density by ``dt`` under a fixed Hamiltonian.

        Builds U = v diag(exp(-0.5j*w*dt)) v^H and applies it twice
        (U rho U^H, then again), i.e. two half steps of length dt/2.

        Args:
            w: eigenvalues of the (field-augmented) MO Fock matrix
            v: corresponding eigenvectors
            oldrho: MO density matrix to propagate
            tnow: current time in A.U. (unused until TCL terms exist)
            dt: time step in A.U.
            IsOn: field flag from applyfield (unused until TCL terms exist)

        Returns:
            newrho: propagated MO density matrix
        '''
        Ud = np.exp(w*(-0.5j)*dt)
        U = transmat(np.diag(Ud),v,-1)
        RhoHalfStepped = transmat(oldrho,U,-1)
        # A TCL correction (an RK4 step applied to RhoHalfStepped between the
        # two half steps) would go here; it is not implemented.
        newrho = transmat(RhoHalfStepped,U,-1)
        return newrho

    def tddftstep(self,fmat, c_am, v_lm, rho, rhom12, tnow):
        '''
        Take dt step in propagation
        updates matrices and rho to next timestep

        Args:
            fmat: float or complex
                Fock matrix in Lowdin AO basis
            c_am: float or complex
                Transformation Matrix |AO><MO|
            v_lm: float or complex
                Transformation Matrix |LAO><MO|
            rho: complex
                MO density matrix.
            rhom12: complex
                MO density matrix half a step behind ``rho``.
            tnow: float
                current time in A.U.

        Returns:
            n_rho: complex
                MO density matrix.
            n_rhom12: complex
            n_c_am: complex
                Transformation Matrix |AO><MO|
            n_v_lm: complex
                Transformation Matrix |LAO><MO|
            n_fmat: complex
                Fock matrix in Lowdin AO basis
            n_jmat: complex
                Coulomb matrix in AO basis
            n_kmat: complex
                Exact Exchange in AO basis

        Raises:
            Exception: if params["Method"] is not "MMUT".
        '''
        if self.params["Method"] != "MMUT":
            raise Exception("Unknown Method...")

        n_fmat, n_jmat, n_kmat = self.fockbuild(transmat(rho,v_lm,-1))
        fmat_prev = transmat(np.conj(n_fmat), v_lm)
        eigs, rot = np.linalg.eig(fmat_prev)
        idx = eigs.argsort()
        # Reorder eigenvalues together with eigenvectors so that
        # diag(eigs) matches the rotated basis.
        eigs = eigs[idx]
        rot = rot[:,idx].copy()
        # The MO basis changes to v_lm . rot; the half-step density that is
        # propagated below must be expressed in that new basis too.
        rhom12 = transmat(rhom12, rot)
        n_v_lm = np.dot(v_lm , rot)
        n_c_am = np.dot(self.x , n_v_lm)
        fmat_mo = np.diag(eigs).astype(complex)
        fmatfield, IsOn = self.field.applyfield(fmat_mo,n_c_am,tnow)
        w,v = scipy.linalg.eig(fmatfield)
        NewRhoM12 = self.split_rk4_step_mmut(w, v, rhom12, tnow,
                                             self.params["dt"], IsOn)
        NewRho = self.split_rk4_step_mmut(w, v, NewRhoM12, tnow,
                                          self.params["dt"]/2.0, IsOn)
        # Re-hermitize to suppress numerical drift.
        n_rho = 0.5*(NewRho+(NewRho.T.conj()))
        n_rhom12 = 0.5*(NewRhoM12+(NewRhoM12.T.conj()))
        return n_rho, n_rhom12, n_c_am, n_v_lm, n_fmat, n_jmat, n_kmat

    def dipole(self, rho, c_am):
        '''
        Args:
            c_am: float or complex
                Transformation Matrix |AO><MO|
            rho: complex
                MO density matrix.

        Returns:
            dipole: float
                xyz component of dipole of a molecule. [x y z]
        '''
        return self.field.expectation(rho, c_am)

    def energy(self,dm_lao,fmat,jmat,kmat):
        '''
        Args:
            dm_lao: complex
                Density in LAO basis.
            fmat: complex
                Fock matrix in Lowdin AO basis
            jmat: complex
                Coulomb matrix in AO basis
            kmat: complex
                Exact Exchange in AO basis

        Returns:
            e_tot: float
                Total Energy of a system (nuclear repulsion included)

        Raises:
            ValueError: if params["Model"] is neither "TDHF" nor "TDDFT".
        '''
        if self.params["Model"] == "TDHF":
            hlao = transmat(self.h,self.x)
            e_tot = (self.enuc+np.trace(np.dot(dm_lao,hlao+fmat))).real
            return e_tot
        elif self.params["Model"] == "TDDFT":
            dm = transmat(dm_lao,self.x,-1)
            exc = self.exc
            if self.hyb > 0.01:
                exc -= 0.5 * self.hyb * trdot(dm,kmat)
            # if not using auxmol
            eh = trdot(dm,2*self.h)
            ej = trdot(dm,jmat)
            e_tot = (eh + ej + exc + self.enuc).real
            return e_tot
        raise ValueError("Unknown Model: %s" % self.params["Model"])

    def loginstant(self, rho, c_am, v_lm, fmat, jmat, kmat, tnow, it):
        '''
        time is logged in atomic units.

        Args:
            rho: complex
                MO density matrix.
            c_am: complex
                Transformation Matrix |AO><MO|
            v_lm: complex
                Transformation Matrix |LAO><MO|
            fmat: complex
                Fock matrix in Lowdin AO basis
            jmat: complex
                Coulomb matrix in AO basis
            kmat: complex
                Exact Exchange in AO basis
            tnow: float
                Current time in propagation in A.U.
            it: int
                Number of iteration of propagation

        Returns:
            tore: str
                |t, dipole(x,y,z), energy|
        '''
        np.set_printoptions(precision = 7)
        # Evaluate once and reuse for both the returned record and the log.
        # NOTE(review): assumes field.expectation has no side effects.
        dip = self.dipole(rho, c_am).real
        etot = self.energy(transmat(rho,v_lm,-1),fmat, jmat, kmat)
        tore = str(tnow)+" "+str(dip).rstrip("]").lstrip("[")+" "+str(etot)

        if it%self.params["StatusEvery"] ==0 or it == self.params["MaxIter"]-1:
            # rho holds half occupations (0.5*mo_occ initially), so
            # 2*Tr(rho) is the electron count.
            logger.log(self, "t: %f fs    Energy: %f a.u.   Total Density: %f",
                       tnow*FSPERAU, etot, 2*np.trace(rho).real)
            logger.log(self, "Dipole moment(X, Y, Z, au): %8.5f, %8.5f, %8.5f",
                       dip[0], dip[1], dip[2])
        return tore

    def prop(self, fmat, c_am, v_lm, rho, output):
        '''
        The main tdscf propagation loop.

        Args:
            fmat: complex
                Fock matrix in Lowdin AO basis
            c_am: complex
                Transformation Matrix |AO><MO|
            v_lm: complex
                Transformation Matrix |LAO><MO|
            rho: complex
                MO density matrix.
            output: str
                name of the file with result of propagation (appended to)

        Saved results:
            f: file
                output file with |t, dipole(x,y,z), energy|
        '''
        tnow = 0
        rhom12 = rho.copy()
        logger.log(self,"\n\nPropagation Begins")
        start = time.time()
        # Context manager guarantees the output file is closed, even on error.
        with open(output,"a") as f:
            for it in range(self.params["MaxIter"]):
                rho, rhom12, c_am, v_lm, fmat, jmat, kmat = self.tddftstep(
                    fmat, c_am, v_lm, rho, rhom12, tnow)
                f.write(self.loginstant(rho, c_am, v_lm, fmat, jmat, kmat,
                                        tnow, it)+"\n")
                tnow = tnow + self.params["dt"]
                if it%self.params["StatusEvery"] ==0 or \
                   it == self.params["MaxIter"]-1:
                    end = time.time()
                    # Wall-clock hours per simulated picosecond.
                    logger.log(self, "%f hr/ps",
                               (end - start)/(60*60*tnow * FSPERAU * 0.001))
