# Copyright 2014-2018 The PySCF Developers. All Rights Reserved.
#
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# Unless required by applicable law or agreed to in writing, software
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import unittest
import numpy
import numpy as np

##################################################
#
# port from ao2mo/eris.py
#
##################################################
from pyscf import lib
from pyscf.pbc import lib as pbclib
from pyscf.pbc.dft import gen_grid
from pyscf.pbc.dft import numint
from pyscf.pbc import tools
from pyscf.pbc.lib import kpts_helper

#einsum = np.einsum
einsum = lib.einsum

"""
(ij|kl) = \int dr1 dr2 i*(r1) j(r1) v(r12) k*(r2) l(r2)
= (ij|G) v(G) (G|kl)

i*(r) j(r) = 1/N \sum_G e^{iGr}  (G|ij)
= 1/N \sum_G e^{-iGr} (ij|G)

"forward" FFT:
(G|ij) = \sum_r e^{-iGr} i*(r) j(r) = fft[ i*(r) j(r) ]
"inverse" FFT:
(ij|G) = \sum_r e^{iGr} i*(r) j(r) = N * ifft[ i*(r) j(r) ]
= conj[ \sum_r e^{-iGr} j*(r) i(r) ]
"""

def general(cell, mo_coeffs, kpts=None, compact=0):
'''pyscf-style wrapper to get MO 2-el integrals.'''
assert len(mo_coeffs) == 4
if kpts is not None:
assert len(kpts) == 4
return get_mo_eri(cell, mo_coeffs, kpts)

def get_mo_eri(cell, mo_coeffs, kpts=None):
'''Convenience function to return MO 2-el integrals.'''
mo_coeff12 = mo_coeffs[:2]
mo_coeff34 = mo_coeffs[2:]
if kpts is None:
kpts12 = kpts34 = q = None
else:
kpts12 = kpts[:2]
kpts34 = kpts[2:]
q = kpts12[0] - kpts12[1]
#q = kpts34[1] - kpts34[0]
if q is None:
q = np.zeros(3)

mo_pairs12_kG = get_mo_pairs_G(cell, mo_coeff12, kpts12)
mo_pairs34_invkG = get_mo_pairs_invG(cell, mo_coeff34, kpts34, q)
return assemble_eri(cell, mo_pairs12_kG, mo_pairs34_invkG, q)

def get_mo_pairs_G(cell, mo_coeffs, kpts=None, q=None):
'''Calculate forward (G|ij) FFT of all MO pairs.

TODO: - Implement simplifications for real orbitals.

Args:
mo_coeff: length-2 list of (nao,nmo) ndarrays
The two sets of MO coefficients to use in calculating the
product |ij).

Returns:
mo_pairs_G : (ngrids, nmoi*nmoj) ndarray
The FFT of the real-space MO pairs.
'''
coords = gen_grid.gen_uniform_grids(cell)
if kpts is None:
q = np.zeros(3)
aoR = numint.eval_ao(cell, coords)
ngrids = aoR.shape[0]

if np.array_equal(mo_coeffs[0], mo_coeffs[1]):
nmoi = nmoj = mo_coeffs[0].shape[1]
moiR = mojR = einsum('ri,ia->ra', aoR, mo_coeffs[0])
else:
nmoi = mo_coeffs[0].shape[1]
nmoj = mo_coeffs[1].shape[1]
moiR = einsum('ri,ia->ra', aoR, mo_coeffs[0])
mojR = einsum('ri,ia->ra', aoR, mo_coeffs[1])

else:
if q is None:
q = kpts[1]-kpts[0]
aoR_ki = numint.eval_ao(cell, coords, kpt=kpts[0])
aoR_kj = numint.eval_ao(cell, coords, kpt=kpts[1])
ngrids = aoR_ki.shape[0]

nmoi = mo_coeffs[0].shape[1]
nmoj = mo_coeffs[1].shape[1]
moiR = einsum('ri,ia->ra', aoR_ki, mo_coeffs[0])
mojR = einsum('ri,ia->ra', aoR_kj, mo_coeffs[1])

#mo_pairs_R = einsum('ri,rj->rij', np.conj(moiR), mojR)
mo_pairs_G = np.zeros([ngrids,nmoi*nmoj], np.complex128)

fac = np.exp(-1j*np.dot(coords, q))
for i in range(nmoi):
for j in range(nmoj):
mo_pairs_R_ij = np.conj(moiR[:,i])*mojR[:,j]
mo_pairs_G[:,i*nmoj+j] = tools.fftk(mo_pairs_R_ij, cell.mesh, fac)

return mo_pairs_G

def get_mo_pairs_invG(cell, mo_coeffs, kpts=None, q=None):
'''Calculate "inverse" (ij|G) FFT of all MO pairs.

TODO: - Implement simplifications for real orbitals.

Args:
mo_coeff: length-2 list of (nao,nmo) ndarrays
The two sets of MO coefficients to use in calculating the
product |ij).

Returns:
mo_pairs_invG : (ngrids, nmoi*nmoj) ndarray
The inverse FFTs of the real-space MO pairs.
'''
coords = gen_grid.gen_uniform_grids(cell)
if kpts is None:
q = np.zeros(3)
aoR = numint.eval_ao(cell, coords)
ngrids = aoR.shape[0]

if np.array_equal(mo_coeffs[0], mo_coeffs[1]):
nmoi = nmoj = mo_coeffs[0].shape[1]
moiR = mojR = einsum('ri,ia->ra', aoR, mo_coeffs[0])
else:
nmoi = mo_coeffs[0].shape[1]
nmoj = mo_coeffs[1].shape[1]
moiR = einsum('ri,ia->ra', aoR, mo_coeffs[0])
mojR = einsum('ri,ia->ra', aoR, mo_coeffs[1])

else:
if q is None:
q = kpts[1]-kpts[0]
aoR_ki = numint.eval_ao(cell, coords, kpt=kpts[0])
aoR_kj = numint.eval_ao(cell, coords, kpt=kpts[1])
ngrids = aoR_ki.shape[0]

nmoi = mo_coeffs[0].shape[1]
nmoj = mo_coeffs[1].shape[1]
moiR = einsum('ri,ia->ra', aoR_ki, mo_coeffs[0])
mojR = einsum('ri,ia->ra', aoR_kj, mo_coeffs[1])

#mo_pairs_R = einsum('ri,rj->rij', np.conj(moiR), mojR)
mo_pairs_invG = np.zeros([ngrids,nmoi*nmoj], np.complex128)

fac = np.exp(1j*np.dot(coords, q))
for i in range(nmoi):
for j in range(nmoj):
mo_pairs_R_ij = np.conj(moiR[:,i])*mojR[:,j]
mo_pairs_invG[:,i*nmoj+j] = np.conj(tools.fftk(np.conj(mo_pairs_R_ij), cell.mesh, fac))

return mo_pairs_invG

def get_mo_pairs_G_old(cell, mo_coeffs, kpts=None, q=None):
'''Calculate forward (G|ij) and "inverse" (ij|G) FFT of all MO pairs.

TODO: - Implement simplifications for real orbitals.

Args:
mo_coeff: length-2 list of (nao,nmo) ndarrays
The two sets of MO coefficients to use in calculating the
product |ij).

Returns:
mo_pairs_G, mo_pairs_invG : (ngrids, nmoi*nmoj) ndarray
The FFTs of the real-space MO pairs.
'''
coords = gen_grid.gen_uniform_grids(cell)
if kpts is None:
q = np.zeros(3)
aoR = numint.eval_ao(cell, coords)
ngrids = aoR.shape[0]

if np.array_equal(mo_coeffs[0], mo_coeffs[1]):
nmoi = nmoj = mo_coeffs[0].shape[1]
moiR = mojR = einsum('ri,ia->ra', aoR, mo_coeffs[0])
else:
nmoi = mo_coeffs[0].shape[1]
nmoj = mo_coeffs[1].shape[1]
moiR = einsum('ri,ia->ra', aoR, mo_coeffs[0])
mojR = einsum('ri,ia->ra', aoR, mo_coeffs[1])

else:
if q is None:
q = kpts[1]-kpts[0]
aoR_ki = numint.eval_ao(cell, coords, kpt=kpts[0])
aoR_kj = numint.eval_ao(cell, coords, kpt=kpts[1])
ngrids = aoR_ki.shape[0]

nmoi = mo_coeffs[0].shape[1]
nmoj = mo_coeffs[1].shape[1]
moiR = einsum('ri,ia->ra', aoR_ki, mo_coeffs[0])
mojR = einsum('ri,ia->ra', aoR_kj, mo_coeffs[1])

mo_pairs_R = np.einsum('ri,rj->rij', np.conj(moiR), mojR)
mo_pairs_G = np.zeros([ngrids,nmoi*nmoj], np.complex128)
mo_pairs_invG = np.zeros([ngrids,nmoi*nmoj], np.complex128)

fac = np.exp(-1j*np.dot(coords, q))
for i in range(nmoi):
for j in range(nmoj):
mo_pairs_G[:,i*nmoj+j] = tools.fftk(mo_pairs_R[:,i,j], cell.mesh, fac)
mo_pairs_invG[:,i*nmoj+j] = np.conj(tools.fftk(np.conj(mo_pairs_R[:,i,j]), cell.mesh,
fac.conj()))

return mo_pairs_G, mo_pairs_invG

def assemble_eri(cell, orb_pair_invG1, orb_pair_G2, q=None):
'''Assemble 4-index electron repulsion integrals.

Returns:
(nmo1*nmo2, nmo3*nmo4) ndarray

'''
if q is None:
q = np.zeros(3)

coulqG = tools.get_coulG(cell, -1.0*q)
ngrids = orb_pair_invG1.shape[0]
Jorb_pair_G2 = np.einsum('g,gn->gn',coulqG,orb_pair_G2)*(cell.vol/ngrids**2)
eri = np.dot(orb_pair_invG1.T, Jorb_pair_G2)
return eri

def get_ao_pairs_G(cell, kpt=np.zeros(3)):
'''Calculate forward (G|ij) and "inverse" (ij|G) FFT of all AO pairs.

Args:
cell : instance of :class:Cell

Returns:
ao_pairs_G, ao_pairs_invG : (ngrids, nao*(nao+1)/2) ndarray
The FFTs of the real-space AO pairs.

'''
coords = gen_grid.gen_uniform_grids(cell)
aoR = numint.eval_ao(cell, coords, kpt) # shape = (coords, nao)
ngrids, nao = aoR.shape
gamma_point = abs(kpt).sum() < 1e-9
if gamma_point:
npair = nao*(nao+1)//2
ao_pairs_G = np.empty([ngrids, npair], np.complex128)

ij = 0
for i in range(nao):
for j in range(i+1):
ao_ij_R = np.conj(aoR[:,i]) * aoR[:,j]
ao_pairs_G[:,ij] = tools.fft(ao_ij_R, cell.mesh)
#ao_pairs_invG[:,ij] = ngrids*tools.ifft(ao_ij_R, cell.mesh)
ij += 1
ao_pairs_invG = ao_pairs_G.conj()
else:
ao_pairs_G = np.zeros([ngrids, nao,nao], np.complex128)
for i in range(nao):
for j in range(nao):
ao_ij_R = np.conj(aoR[:,i]) * aoR[:,j]
ao_pairs_G[:,i,j] = tools.fft(ao_ij_R, cell.mesh)
ao_pairs_invG = ao_pairs_G.transpose(0,2,1).conj().reshape(-1,nao**2)
ao_pairs_G = ao_pairs_G.reshape(-1,nao**2)
return ao_pairs_G, ao_pairs_invG

def get_ao_eri(cell, kpt=np.zeros(3)):
'''Convenience function to return AO 2-el integrals.'''

ao_pairs_G, ao_pairs_invG = get_ao_pairs_G(cell, kpt)
eri = assemble_eri(cell, ao_pairs_invG, ao_pairs_G)
if abs(kpt).sum() < 1e-9:
eri = eri.real
return eri

##################################################
#
# ao2mo/eris.py end
#
##################################################

##################################################
#
# port from scf/hf.py
#
##################################################
from pyscf.pbc import scf as pbcscf

def get_j(cell, dm, hermi=1, vhfopt=None, kpt=np.zeros(3), kpts_band=None):
dm = np.asarray(dm)
nao = dm.shape[-1]

coords = gen_grid.gen_uniform_grids(cell)
if kpts_band is None:
kpt1 = kpt2 = kpt
aoR_k1 = aoR_k2 = numint.eval_ao(cell, coords, kpt)
else:
kpt1 = kpts_band
kpt2 = kpt
aoR_k1 = numint.eval_ao(cell, coords, kpt1)
aoR_k2 = numint.eval_ao(cell, coords, kpt2)
ngrids, nao = aoR_k1.shape

def contract(dm):
vjR_k2 = get_vjR(cell, dm, aoR_k2)
vj = (cell.vol/ngrids) * np.dot(aoR_k1.T.conj(), vjR_k2.reshape(-1,1)*aoR_k1)
return vj

if dm.ndim == 2:
vj = contract(dm)
else:
vj = lib.asarray([contract(x) for x in dm.reshape(-1,nao,nao)])
return vj.reshape(dm.shape)

def get_jk(mf, cell, dm, hermi=1, vhfopt=None, kpt=np.zeros(3), kpts_band=None):
dm = np.asarray(dm)
nao = dm.shape[-1]

coords = gen_grid.gen_uniform_grids(cell)
if kpts_band is None:
kpt1 = kpt2 = kpt
aoR_k1 = aoR_k2 = numint.eval_ao(cell, coords, kpt)
else:
kpt1 = kpts_band
kpt2 = kpt
aoR_k1 = numint.eval_ao(cell, coords, kpt1)
aoR_k2 = numint.eval_ao(cell, coords, kpt2)

vkR_k1k2 = get_vkR(mf, cell, aoR_k1, aoR_k2, kpt1, kpt2)

ngrids, nao = aoR_k1.shape
def contract(dm):
vjR_k2 = get_vjR(cell, dm, aoR_k2)
vj = (cell.vol/ngrids) * np.dot(aoR_k1.T.conj(), vjR_k2.reshape(-1,1)*aoR_k1)

#:vk = (cell.vol/ngrids) * np.einsum('rs,Rp,Rqs,Rr->pq', dm, aoR_k1.conj(),
#:                                vkR_k1k2, aoR_k2)
aoR_dm_k2 = np.dot(aoR_k2, dm)
tmp_Rq = np.einsum('Rqs,Rs->Rq', vkR_k1k2, aoR_dm_k2)
vk = (cell.vol/ngrids) * np.dot(aoR_k1.T.conj(), tmp_Rq)
return vj, vk

if dm.ndim == 2:
vj, vk = contract(dm)
else:
jk = [contract(x) for x in dm.reshape(-1,nao,nao)]
vj = lib.asarray([x[0] for x in jk])
vk = lib.asarray([x[1] for x in jk])
return vj.reshape(dm.shape), vk.reshape(dm.shape)

def get_vjR(cell, dm, aoR):
coulG = tools.get_coulG(cell)

rhoR = numint.eval_rho(cell, aoR, dm)
rhoG = tools.fft(rhoR, cell.mesh)

vG = coulG*rhoG
vR = tools.ifft(vG, cell.mesh)
if rhoR.dtype == np.double:
vR = vR.real
return vR

def get_vkR(mf, cell, aoR_k1, aoR_k2, kpt1, kpt2):
'''Get the real-space 2-index "exchange" potential V_{i,k1; j,k2}(r)
where {i,k1} = exp^{i k1 r) |i> , {j,k2} = exp^{-i k2 r) <j|
'''
coords = gen_grid.gen_uniform_grids(cell)
ngrids, nao = aoR_k1.shape

expmikr = np.exp(-1j*np.dot(kpt1-kpt2,coords.T))
coulG = tools.get_coulG(cell, kpt1-kpt2, exx=True, mf=mf)
def prod(ij):
i, j = divmod(ij, nao)
rhoR = aoR_k1[:,i] * aoR_k2[:,j].conj()
rhoG = tools.fftk(rhoR, cell.mesh, expmikr)
vG = coulG*rhoG
vR = tools.ifftk(vG, cell.mesh, expmikr.conj())
return vR

if aoR_k1.dtype == np.double and aoR_k2.dtype == np.double:
vR = numpy.asarray([prod(ij).real for ij in range(nao**2)])
else:
vR = numpy.asarray([prod(ij) for ij in range(nao**2)])
return vR.reshape(nao,nao,-1).transpose(2,0,1)

def get_j_kpts(mf, cell, dm_kpts, kpts, kpts_band=None):
coords = gen_grid.gen_uniform_grids(cell)
nkpts = len(kpts)
ngrids = len(coords)
dm_kpts = np.asarray(dm_kpts)
nao = dm_kpts.shape[-1]

ni = numint.KNumInt(kpts)
aoR_kpts = ni.eval_ao(cell, coords, kpts)
if kpts_band is not None:
aoR_kband = numint.eval_ao(cell, coords, kpts_band)

dms = dm_kpts.reshape(-1,nkpts,nao,nao)
nset = dms.shape[0]

vjR = [get_vjR(cell, dms[i], aoR_kpts) for i in range(nset)]
if kpts_band is not None:
vj_kpts = [cell.vol/ngrids * lib.dot(aoR_kband.T.conj()*vjR[i], aoR_kband)
for i in range(nset)]
if dm_kpts.ndim == 3:  # One set of dm_kpts for KRHF
vj_kpts = vj_kpts[0]
return lib.asarray(vj_kpts)
else:
vj_kpts = []
for i in range(nset):
vj = [cell.vol/ngrids * lib.dot(aoR_k.T.conj()*vjR[i], aoR_k)
for aoR_k in aoR_kpts]
vj_kpts.append(lib.asarray(vj))
return lib.asarray(vj_kpts).reshape(dm_kpts.shape)

def get_jk_kpts(mf, cell, dm_kpts, kpts, kpts_band=None):
coords = gen_grid.gen_uniform_grids(cell)
nkpts = len(kpts)
ngrids = len(coords)
dm_kpts = np.asarray(dm_kpts)
nao = dm_kpts.shape[-1]

dms = dm_kpts.reshape(-1,nkpts,nao,nao)
nset = dms.shape[0]

ni = numint.KNumInt(kpts)
aoR_kpts = ni.eval_ao(cell, coords, kpts)
if kpts_band is not None:
aoR_kband = numint.eval_ao(cell, coords, kpts_band)

# J
vjR = [get_vjR_kpts(cell, dms[i], aoR_kpts) for i in range(nset)]
if kpts_band is not None:
vj_kpts = [cell.vol/ngrids * lib.dot(aoR_kband.T.conj()*vjR[i], aoR_kband)
for i in range(nset)]
else:
vj_kpts = []
for i in range(nset):
vj = [cell.vol/ngrids * lib.dot(aoR_k.T.conj()*vjR[i], aoR_k)
for aoR_k in aoR_kpts]
vj_kpts.append(lib.asarray(vj))
vj_kpts = lib.asarray(vj_kpts)
vjR = None

# K
weight = 1./nkpts * (cell.vol/ngrids)
vk_kpts = np.zeros_like(vj_kpts)
if kpts_band is not None:
for k2, kpt2 in enumerate(kpts):
aoR_dms = [lib.dot(aoR_kpts[k2], dms[i,k2]) for i in range(nset)]
vkR_k1k2 = get_vkR(mf, cell, aoR_kband, aoR_kpts[k2],
kpts_band, kpt2)
#:vk_kpts = 1./nkpts * (cell.vol/ngrids) * np.einsum('rs,Rp,Rqs,Rr->pq',
#:            dm_kpts[k2], aoR_kband.conj(),
#:            vkR_k1k2, aoR_kpts[k2])
for i in range(nset):
tmp_Rq = np.einsum('Rqs,Rs->Rq', vkR_k1k2, aoR_dms[i])
vk_kpts[i] += weight * lib.dot(aoR_kband.T.conj(), tmp_Rq)
vkR_k1k2 = None
if dm_kpts.ndim == 3:
vj_kpts = vj_kpts[0]
vk_kpts = vk_kpts[0]
return lib.asarray(vj_kpts), lib.asarray(vk_kpts)
else:
for k2, kpt2 in enumerate(kpts):
aoR_dms = [lib.dot(aoR_kpts[k2], dms[i,k2]) for i in range(nset)]
for k1, kpt1 in enumerate(kpts):
vkR_k1k2 = get_vkR(mf, cell, aoR_kpts[k1], aoR_kpts[k2],
kpt1, kpt2)
for i in range(nset):
tmp_Rq = np.einsum('Rqs,Rs->Rq', vkR_k1k2, aoR_dms[i])
vk_kpts[i,k1] += weight * lib.dot(aoR_kpts[k1].T.conj(), tmp_Rq)
vkR_k1k2 = None
return vj_kpts.reshape(dm_kpts.shape), vk_kpts.reshape(dm_kpts.shape)

def get_vjR_kpts(cell, dm_kpts, aoR_kpts):
nkpts = len(aoR_kpts)
coulG = tools.get_coulG(cell)

rhoR = 0
for k in range(nkpts):
rhoR += 1./nkpts*numint.eval_rho(cell, aoR_kpts[k], dm_kpts[k])
rhoG = tools.fft(rhoR, cell.mesh)

vG = coulG*rhoG
vR = tools.ifft(vG, cell.mesh)
if rhoR.dtype == np.double:
vR = vR.real
return vR

##################################################
#
# scf/hf.py end
#
##################################################

def get_nuc(cell, kpt=np.zeros(3)):
'''Get the bare periodic nuc-el AO matrix, with G=0 removed.

See Martin (12.16)-(12.21).
'''
coords = gen_grid.gen_uniform_grids(cell)
aoR = numint.eval_ao(cell, coords, kpt)

chargs = cell.atom_charges()
SI = cell.get_SI()
coulG = tools.get_coulG(cell)
vneG = -np.dot(chargs,SI) * coulG
vneR = tools.ifft(vneG, cell.mesh).real

vne = np.dot(aoR.T.conj(), vneR.reshape(-1,1)*aoR)
return vne

from pyscf.pbc import gto as pgto
import pyscf.pbc.dft as pdft
from pyscf.pbc.df import fft, aft

cell = pgto.Cell()
cell.atom = 'He 1. .5 .5; C .1 1.3 2.1'
cell.basis = {'He': [(0, (2.5, 1)), (0, (1., 1))],
'C' :'gth-szv',}
cell.a = np.eye(3) * 2.5
cell.mesh = [21] * 3
cell.build()
np.random.seed(1)
kpts = np.random.random((4,3))
kpts[3] = kpts[0]-kpts[1]+kpts[2]
kpt0 = np.zeros(3)

cell1 = pgto.Cell()
cell1.atom = 'He 1. .5 .5; He .1 1.3 2.1'
cell1.basis = {'He': [(0, (2.5, 1)), (0, (1., 1))]}
cell1.a = np.eye(3) * 2.5
cell1.mesh = [21] * 3
cell1.build()

cell2 = pgto.Cell()
cell2.atom = '''
He   1.3    .2       .3
He    .1    .1      1.1 '''
cell2.basis = {'He': [[0, [0.8, 1]],
[1, [0.6, 1]]]}
cell2.mesh = [17]*3
cell2.a = numpy.array(([2.0,  .9, 0. ],
[0.1, 1.9, 0.4],
[0.8, 0  , 2.1]))
cell2.build()
kpts1 = np.random.random((4,3))
kpts1[3] = kpts1[0]-kpts1[1]+kpts1[2] + cell2.reciprocal_vectors().T.dot(np.ones(3))

mf0 = pbcscf.RHF(cell)
mf0.exxdiv = None

def finger(a):
w = np.cos(np.arange(a.size))
return np.dot(w, a.ravel())

def tearDownModule():
global cell, cell1, cell2, kpts, kpt0, kpts1, mf0
del cell, cell1, cell2, kpts, kpt0, kpts1, mf0

class KnownValues(unittest.TestCase):
def test_get_nuc(self):
v0 = get_nuc(cell, kpts[0])
v1 = fft.FFTDF(cell).get_nuc(kpts)
self.assertTrue(np.allclose(v0, v1[0], atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(v1[0]), (-5.7646608099493841+0.19126294430138713j), 8)

v0 = get_nuc(cell, kpts[1])
self.assertTrue(np.allclose(v0, v1[1], atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(v1[1]), (-5.6567258309199193+0.86813371243952175j), 8)
self.assertAlmostEqual(finger(v1[2]), (-6.1528952645454895+0.09517054428060109j), 8)
self.assertAlmostEqual(finger(v1[3]), (-5.7445962879770942+0.24611951427601772j), 8)

def test_get_pp(self):
v0 = pgto.pseudo.get_pp(cell, kpts[0])
v1 = fft.FFTDF(cell).get_pp(kpts)
self.assertTrue(np.allclose(v0, v1[0], atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(v1[0]), (-5.6240249083785869+0.22094834302524968j), 8)

v0 = pgto.pseudo.get_pp(cell, kpts[1])
self.assertTrue(np.allclose(v0, v1[1], atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(v1[1]), (-5.5387702576467603+1.0439333717227581j) , 8)
self.assertAlmostEqual(finger(v1[2]), (-6.0530899866313366+0.2817289667029651j), 8)
self.assertAlmostEqual(finger(v1[3]), (-5.6011543542444446+0.27597306418805201j), 8)

def test_get_jk(self):
df = fft.FFTDF(cell)
dm = mf0.get_init_guess()
vj0, vk0 = get_jk(mf0, cell, dm, kpt=kpts[0])
vj1, vk1 = df.get_jk(dm, kpts=kpts[0], exxdiv=None)
self.assertTrue(vj1.dtype == numpy.complex128)
self.assertTrue(vk1.dtype == numpy.complex128)
self.assertTrue(np.allclose(vj0, vj1, atol=1e-9, rtol=1e-9))
self.assertTrue(np.allclose(vk0, vk1, atol=1e-9, rtol=1e-9))

ej1 = numpy.einsum('ij,ji->', vj1, dm)
ek1 = numpy.einsum('ij,ji->', vk1, dm)
self.assertAlmostEqual(ej1, 2.3002596914518700*(6/6.82991739766009)**2, 9)
self.assertAlmostEqual(ek1, 3.3165691757797346*(6/6.82991739766009)**2, 9)

dm = mf0.get_init_guess()
vj0, vk0 = get_jk(mf0, cell, dm)
vj1, vk1 = df.get_jk(dm, exxdiv=None)
self.assertTrue(vj1.dtype == numpy.float64)
self.assertTrue(vk1.dtype == numpy.float64)
self.assertTrue(np.allclose(vj0, vj1, atol=1e-9, rtol=1e-9))
self.assertTrue(np.allclose(vk0, vk1, atol=1e-9, rtol=1e-9))

ej1 = numpy.einsum('ij,ji->', vj1, dm)
ek1 = numpy.einsum('ij,ji->', vk1, dm)
self.assertAlmostEqual(ej1, 2.4673139106639925*(6/6.82991739766009)**2, 9)
self.assertAlmostEqual(ek1, 3.6886674521354221*(6/6.82991739766009)**2, 9)

def test_get_jk_kpts(self):
df = fft.FFTDF(cell)
dm = mf0.get_init_guess()
nkpts = len(kpts)
dms = [dm] * nkpts
vj0, vk0 = get_jk_kpts(mf0, cell, dms, kpts=kpts)
vj1, vk1 = df.get_jk(dms, kpts=kpts, exxdiv=None)
self.assertTrue(vj1.dtype == numpy.complex128)
self.assertTrue(vk1.dtype == numpy.complex128)
self.assertTrue(np.allclose(vj0, vj1, atol=1e-9, rtol=1e-9))
self.assertTrue(np.allclose(vk0, vk1, atol=1e-9, rtol=1e-9))

ej1 = numpy.einsum('xij,xji->', vj1, dms) / len(kpts)
ek1 = numpy.einsum('xij,xji->', vk1, dms) / len(kpts)
self.assertAlmostEqual(ej1, 2.3163352969873445*(6/6.82991739766009)**2, 9)
self.assertAlmostEqual(ek1, 7.7311228144548600*(6/6.82991739766009)**2, 9)

numpy.random.seed(1)
kpts_band = numpy.random.random((2,3))
vj1, vk1 = df.get_jk(dms, kpts=kpts, kpts_band=kpts_band, exxdiv=None)
self.assertAlmostEqual(lib.finger(vj1), 6/6.82991739766009*(3.437188138446714+0.1360466492092307j), 9)
self.assertAlmostEqual(lib.finger(vk1), 6/6.82991739766009*(7.479986541097368+1.1980593415201204j), 9)

nao = dm.shape[0]
mo_coeff = numpy.random.random((nkpts,nao,nao))
mo_occ = numpy.array(numpy.random.random((nkpts,nao))>.6, dtype=numpy.double)
dms = numpy.einsum('kpi,ki,kqi->kpq', mo_coeff, mo_occ, mo_coeff)
dms = lib.tag_array(lib.asarray(dms), mo_coeff=mo_coeff, mo_occ=mo_occ)
vk1 = df.get_jk(dms, kpts=kpts, kpts_band=kpts_band, exxdiv=None)[1]
self.assertAlmostEqual(lib.finger(vk1), 10.239828255099447+2.1190549216896182j, 9)

def test_get_j_non_hermitian(self):
kpt = kpts[0]
numpy.random.seed(2)
nao = cell2.nao
dm = numpy.random.random((nao,nao))
mydf = fft.FFTDF(cell2)
v1 = mydf.get_jk(dm, hermi=0, kpts=kpts[1], with_k=False)[0]
eri = mydf.get_eri([kpts[1]]*4).reshape(nao,nao,nao,nao)
ref = numpy.einsum('ijkl,ji->kl', eri, dm)
self.assertAlmostEqual(abs(ref - v1).max(), 0, 12)
self.assertTrue(abs(ref-ref.T.conj()).max() > 1e-5)

def test_get_ao_eri(self):
df = fft.FFTDF(cell)
eri0 = get_ao_eri(cell)
eri1 = df.get_ao_eri(compact=True)
self.assertTrue(np.allclose(eri0, eri1, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri1), 0.80425358275734926, 8)

eri0 = get_ao_eri(cell, kpts[0])
eri1 = df.get_ao_eri(kpts[0])
self.assertTrue(np.allclose(eri0, eri1, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri1), (2.9346374584901898-0.20479054936744959j), 8)

eri4 = df.get_ao_eri(kpts)
self.assertAlmostEqual(finger(eri4), (0.33709288394542991-0.94185725001175313j), 8)

def test_get_eri_gamma(self):
odf = aft.AFTDF(cell1)
ref = odf.get_eri(compact=True)
df = fft.FFTDF(cell1)
eri0000 = df.get_eri(compact=True)
self.assertTrue(eri0000.dtype == numpy.double)
self.assertTrue(np.allclose(eri0000, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri0000), 0.23714016293926865, 9)

ref = odf.get_eri((kpts[0],kpts[0],kpts[0],kpts[0]))
eri1111 = df.get_eri((kpts[0],kpts[0],kpts[0],kpts[0]))
self.assertTrue(np.allclose(eri1111, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri1111), (1.2410388899583582-5.2370501878355006e-06j), 9)

eri1111 = df.get_eri((kpts[0]+1e-8,kpts[0]+1e-8,kpts[0],kpts[0]))
self.assertTrue(np.allclose(eri1111, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri1111), (1.2410388899583582-5.2370501878355006e-06j), 9)

def test_get_eri_0011(self):
odf = aft.AFTDF(cell1)
df = fft.FFTDF(cell1)
ref = odf.get_eri((kpts[0],kpts[0],kpts[1],kpts[1]))
eri0011 = df.get_eri((kpts[0],kpts[0],kpts[1],kpts[1]))
self.assertTrue(np.allclose(eri0011, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri0011), (1.2410162858084512+0.00074485383749912936j), 9)

ref = get_mo_eri(cell1, [numpy.eye(cell1.nao_nr())]*4, (kpts[0],kpts[0],kpts[1],kpts[1]))
eri0011 = df.get_eri((kpts[0],kpts[0],kpts[1],kpts[1]))
self.assertTrue(np.allclose(eri0011, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri0011), (1.2410162860852818+0.00074485383748954838j), 9)

def test_get_eri_0110(self):
odf = aft.AFTDF(cell1)
df = fft.FFTDF(cell1)
ref = odf.get_eri((kpts[0],kpts[1],kpts[1],kpts[0]))
eri0110 = df.get_eri((kpts[0],kpts[1],kpts[1],kpts[0]))
self.assertTrue(np.allclose(eri0110, ref, atol=1e-9, rtol=1e-9))
eri0110 = df.get_eri((kpts[0]+1e-8,kpts[1]+1e-8,kpts[1],kpts[0]))
self.assertTrue(np.allclose(eri0110, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri0110), (1.2928399254827956-0.011820590601969154j), 9)

ref = get_mo_eri(cell1, [numpy.eye(cell1.nao_nr())]*4, (kpts[0],kpts[1],kpts[1],kpts[0]))
eri0110 = df.get_eri((kpts[0],kpts[1],kpts[1],kpts[0]))
self.assertTrue(np.allclose(eri0110, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri0110), (1.2928399254827956-0.011820590601969154j), 9)
eri0110 = df.get_eri((kpts[0]+1e-8,kpts[1]+1e-8,kpts[1],kpts[0]))
self.assertTrue(np.allclose(eri0110, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri0110), (1.2928399254827956-0.011820590601969154j), 9)

def test_get_eri_0123(self):
odf = aft.AFTDF(cell1)
df = fft.FFTDF(cell1)
ref = odf.get_eri(kpts)
eri1111 = df.get_eri(kpts)
self.assertTrue(np.allclose(eri1111, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri1111), (1.2917759427391706-0.013340252488069412j), 9)

ref = get_mo_eri(cell1, [numpy.eye(cell1.nao_nr())]*4, kpts)
eri1111 = df.get_eri(kpts)
self.assertTrue(np.allclose(eri1111, ref, atol=1e-9, rtol=1e-9))
self.assertAlmostEqual(finger(eri1111), (1.2917759427391706-0.013340252488069412j), 9)

def test_get_mo_eri(self):
df = fft.FFTDF(cell)
nao = cell.nao_nr()
numpy.random.seed(5)
mo =(numpy.random.random((nao,nao)) +
numpy.random.random((nao,nao))*1j)
eri_mo0 = get_mo_eri(cell, (mo,)*4, kpts)
eri_mo1 = df.get_mo_eri((mo,)*4, kpts)
self.assertTrue(np.allclose(eri_mo1, eri_mo0, atol=1e-9, rtol=1e-9))

kpts_t = (kpts[2],kpts[3],kpts[0],kpts[1])
eri_mo2 = get_mo_eri(cell, (mo,)*4, kpts_t)
eri_mo2 = eri_mo2.reshape((nao,)*4).transpose(2,3,0,1).reshape(nao**2,-1)
self.assertTrue(np.allclose(eri_mo2, eri_mo0, atol=1e-9, rtol=1e-9))

eri_mo0 = get_mo_eri(cell, (mo,)*4, (kpts[0],)*4)
eri_mo1 = df.get_mo_eri((mo,)*4, (kpts[0],)*4)
self.assertTrue(np.allclose(eri_mo1, eri_mo0, atol=1e-9, rtol=1e-9))

eri_mo0 = get_mo_eri(cell, (mo,)*4, (kpts[0],kpts[1],kpts[1],kpts[0],))
eri_mo1 = df.get_mo_eri((mo,)*4, (kpts[0],kpts[1],kpts[1],kpts[0],))
self.assertTrue(np.allclose(eri_mo1, eri_mo0, atol=1e-9, rtol=1e-9))

eri_mo0 = get_mo_eri(cell, (mo,)*4, (kpt0,kpt0,kpts[0],kpts[0],))
eri_mo1 = df.get_mo_eri((mo,)*4, (kpt0,kpt0,kpts[0],kpts[0],))
self.assertTrue(np.allclose(eri_mo1, eri_mo0, atol=1e-9, rtol=1e-9))

eri_mo0 = get_mo_eri(cell, (mo,)*4, (kpts[0],kpts[0],kpt0,kpt0,))
eri_mo1 = df.get_mo_eri((mo,)*4, (kpts[0],kpts[0],kpt0,kpt0,))
self.assertTrue(np.allclose(eri_mo1, eri_mo0, atol=1e-9, rtol=1e-9))

mo1 = mo[:,:nao//2+1]
eri_mo0 = get_mo_eri(cell, (mo1,mo,mo,mo1), (kpts[0],)*4)
eri_mo1 = df.get_mo_eri((mo1,mo,mo,mo1), (kpts[0],)*4)
self.assertTrue(np.allclose(eri_mo1, eri_mo0, atol=1e-9, rtol=1e-9))

eri_mo0 = get_mo_eri(cell, (mo1,mo,mo1,mo), (kpts[0],kpts[1],kpts[1],kpts[0],))
eri_mo1 = df.get_mo_eri((mo1,mo,mo1,mo), (kpts[0],kpts[1],kpts[1],kpts[0],))
self.assertTrue(np.allclose(eri_mo1, eri_mo0, atol=1e-9, rtol=1e-9))

def test_get_mo_eri1(self):
df = fft.FFTDF(cell2)
nao = cell2.nao_nr()
numpy.random.seed(5)
mos =(numpy.random.random((4,nao,nao)) +
numpy.random.random((4,nao,nao))*1j)
eri_mo0 = get_mo_eri(cell2, mos, kpts1)
eri_mo1 = df.get_mo_eri(mos, kpts1)
self.assertTrue(np.allclose(eri_mo1, eri_mo0, atol=1e-9, rtol=1e-9))

def test_ao2mo_7d(self):
L = 3.
n = 6
cell = pgto.Cell()
cell.a = numpy.diag([L,L,L])
cell.mesh = [n,n,n]
cell.atom = '''He    2.    2.2      2.
He    1.2   1.       1.'''
cell.basis = {'He': [[0, (1.2, 1)], [1, (0.6, 1)]]}
cell.verbose = 0
cell.build(0,0)

kpts = cell.make_kpts([1,3,1])
nkpts = len(kpts)
nao = cell.nao_nr()
numpy.random.seed(1)
mo =(numpy.random.random((nkpts,nao,nao)) +
numpy.random.random((nkpts,nao,nao))*1j)

with_df = fft.FFTDF(cell, kpts)
out = with_df.ao2mo_7d(mo, kpts)
ref = numpy.empty_like(out)

kconserv = kpts_helper.get_kconserv(cell, kpts)
for ki, kj, kk in kpts_helper.loop_kkk(nkpts):
kl = kconserv[ki, kj, kk]
tmp = with_df.ao2mo((mo[ki], mo[kj], mo[kk], mo[kl]), kpts[[ki,kj,kk,kl]])
ref[ki,kj,kk] = tmp.reshape([nao]*4)

self.assertAlmostEqual(abs(out-ref).max(), 0, 12)

def test_get_jk_with_casscf(self):
from pyscf import mcscf
pcell = cell2.copy()
pcell.verbose = 0
pcell.mesh = [8]*3
mf = pbcscf.RHF(pcell)
mf.exxdiv = None
ehf = mf.kernel()

mc = mcscf.CASSCF(mf, 1, 2).run()
self.assertAlmostEqual(mc.e_tot, ehf, 9)

mc = mcscf.CASSCF(mf, 2, 0).run()
self.assertAlmostEqual(mc.e_tot, ehf, 9)

if __name__ == '__main__':
print("Full Tests for fft JK and ao2mo etc")
unittest.main()

#|| ======================================================================
#|| FAIL: test_get_jk (__main__.KnownValues)
#|| ----------------------------------------------------------------------
#|| Traceback (most recent call last):
#||   File "df/test/test_fft.py", line 635, in test_get_jk
#||     self.assertAlmostEqual(ej1, 2.3002596914518700*6./6.82991739766009, 9)
#|| AssertionError: (1.7752048157393747-2.5392293537889846e-19j) != 2.0207503759034617 within 9 places
#||
#|| ======================================================================
#|| FAIL: test_get_jk_kpts (__main__.KnownValues)
#|| ----------------------------------------------------------------------
#|| Traceback (most recent call last):
#||   File "df/test/test_fft.py", line 665, in test_get_jk_kpts
#||     self.assertAlmostEqual(ej1, 2.3163352969873445*6./6.82991739766009, 9)
#|| AssertionError: (1.7876110203371138+9.959029383250026e-19j) != 2.0348726013414873 within 9 places
#||
#|| ======================================================================
#|| FAIL: test_get_jk_with_casscf (__main__.KnownValues)
#|| ----------------------------------------------------------------------
#|| Traceback (most recent call last):
#||   File "df/test/test_fft.py", line 861, in test_get_jk_with_casscf
#||     mc = mcscf.CASSCF(mf, 1, 2).run()
#||   File "/home/sunqm/workspace/program/pyscf/pyscf/lib/misc.py", line 472, in run
#||     self.kernel(*args)
#||   File "/home/sunqm/workspace/program/pyscf/pyscf/mcscf/mc1step.py", line 796, in kernel
#||     ci0=ci0, callback=callback, verbose=self.verbose)
#||   File "/home/sunqm/workspace/program/pyscf/pyscf/mcscf/mc1step.py", line 346, in kernel
#||     e_tot, e_cas, fcivec = casscf.casci(mo, ci0, eris, log, locals())
#||   File "/home/sunqm/workspace/program/pyscf/pyscf/mcscf/mc1step.py", line 816, in casci
#||     e_tot, e_cas, fcivec = casci.kernel(fcasci, mo_coeff, ci0, log)
#||   File "/home/sunqm/workspace/program/pyscf/pyscf/mcscf/casci.py", line 493, in kernel
#||     ecore=energy_core)
#||   File "/home/sunqm/workspace/program/pyscf/pyscf/fci/direct_spin1.py", line 738, in kernel
#||     davidson_only, pspace_size, ecore=ecore, **kwargs)
#||   File "/home/sunqm/workspace/program/pyscf/pyscf/fci/direct_spin1.py", line 448, in kernel_ms1
#||     assert(0 < nelec[0] < norb and 0 < nelec[1] < norb)
#|| AssertionError
#||
#|| ----------------------------------------------------------------------
#|| Ran 14 tests in 9.418s
#||
#|| FAILED (failures=3)
#|| [Finished in 10 seconds with code 1]