"""
utils.py - Simple utility functions used in pyquante2.
"""
import numpy as np
from math import factorial,lgamma
from itertools import combinations_with_replacement,combinations
from functools import reduce

def pairs(it):
    """Iterate over all 2-element combinations of `it`, each element also
    paired with itself (i <= j ordering, as combinations_with_replacement)."""
    pair_iter = combinations_with_replacement(it, 2)
    return pair_iter
def upairs(it):
    """Iterate over all 2-element combinations of `it` without pairing an
    element with itself (strict i < j ordering, as itertools.combinations)."""
    pair_iter = combinations(it, 2)
    return pair_iter

def fact2(n):
    """
    Double factorial n!! = n*(n-2)*(n-4)*..., stopping at 1 or 2.

    Returns 1 for n <= 0 (empty product).
    >>> fact2(0)
    1
    >>> fact2(1)
    1
    >>> fact2(3)
    3
    >>> fact2(8)
    384
    >>> fact2(-1)
    1
    """
    product = 1
    for factor in range(n, 0, -2):
        product *= factor
    return product

def norm2(a):
    """Return np.dot(a, a): the squared Euclidean norm of a real vector `a`."""
    squared = np.dot(a, a)
    return squared

def binomial(n,k):
    """
    Binomial coefficient C(n,k) = n!/(k!(n-k)!) using integer factorials.

    binomial(n,n) is 1; otherwise n > k is required (checked by assert).
    >>> binomial(5,2)
    10
    >>> binomial(10,5)
    252
    """
    if n != k:
        assert n>k, "Attempting to call binomial(%d,%d)" % (n,k)
        return factorial(n)//(factorial(k)*factorial(n-k))
    return 1

def Fgamma(m,x):
    """
    Boys function F_m(x) = gamma(m+1/2, x) / (2 x^(m+1/2)), where gamma is the
    lower incomplete gamma function computed by gamm_inc.

    Arguments below SMALL = 1e-12 (including 0 and negative values) are
    clamped to SMALL, so F_m(0) is returned as F_m(1e-12) ~= 1/(2m+1).

    >>> np.isclose(Fgamma(0,0),1.0)
    True
    >>> np.isclose(Fgamma(30,0),1/61.)
    True
    """
    SMALL=1e-12
    ITMAX=100   # same series limits as _gser
    EPS=3.e-7
    x = max(x,SMALL)
    a = m+0.5
    if x < a+1.0:
        # Series regime (the same test gamm_inc uses). With
        #   gamma(a,x) = exp(-x) * x**a * sum_{n>=0} x**n/(a*(a+1)*...*(a+n))
        # the x**a factor cancels against x**(-a), leaving
        #   F_m(x) = 0.5*exp(-x) * sum_{n>=0} x**n/(a*(a+1)*...*(a+n)).
        # Summing this directly avoids pow(x,-a) raising OverflowError
        # (m >= 26 at x = SMALL) and gamm_inc underflowing to 0 for large m.
        ap = a
        delt = total = 1./a
        for i in range(ITMAX):
            ap = ap+1.
            delt = delt*x/ap
            total = total+delt
            if abs(delt) < abs(total)*EPS: break
        else:
            print('a too large, ITMAX too small in Fgamma')
        return 0.5*np.exp(-x)*total
    # Continued-fraction regime: x >= a+1 >= 1.5, so pow(x,-a) cannot overflow.
    return 0.5*pow(x,-a)*gamm_inc(a,x)

# def gamm_inc_scipy(a,x):
#     """
#     Demonstration of how to replace the gamma calls with scipy.special functions.
#     By default, pyquante only requires numpy, but this may change as scipy
#     builds become more stable.
#     >>> np.isclose(gamm_inc_scipy(0.5,1),1.49365)
#     True
#     >>> np.isclose(gamm_inc_scipy(1.5,2),0.6545103)
#     True
#     >>> np.isclose(gamm_inc_scipy(2.5,1e-12),0)
#     True
#     """
#     from scipy.special import gamma,gammainc
#     return gamma(a)*gammainc(a,x)

def gamm_inc(a,x):
    """
    Incomplete gamma function gamma(a,x) (lower, not regularized), computed
    from the Numerical Recipes routine gammp.

    The regularized P(a,x) = gamma(a,x)/Gamma(a) comes from the power series
    (_gser) when x < a+1, and otherwise from 1 - Q(a,x) using the continued
    fraction (_gcf). The result is then scaled by Gamma(a) = exp(lgamma(a)).

    Requires a > 0 and x > 0 (checked by assert).
    >>> np.isclose(gamm_inc(0.5,1),1.49365)
    True
    >>> np.isclose(gamm_inc(1.5,2),0.6545103)
    True
    >>> np.isclose(gamm_inc(2.5,1e-12),0)
    True
    """
    # a == 0 is rejected here: lgamma(a) in the helpers is undefined at a = 0.
    assert (x > 0 and a > 0), "Invalid arguments in routine gamm_inc: a=%s, x=%s" % (a,x)

    if x < (a+1.0): #Use the series representation
        gam,gln = _gser(a,x)
    else: #Use continued fractions
        gamc,gln = _gcf(a,x)
        gam = 1-gamc
    return np.exp(gln)*gam

def _gser(a,x):
    """
    Regularized lower incomplete gamma P(a,x) = gamma(a,x)/Gamma(a) from its
    power series (Numerical Recipes routine gser):

        P(a,x) = exp(-x) * x**a / Gamma(a) * sum_{n>=0} x**n/(a*(a+1)*...*(a+n))

    Returns (P(a,x), lgamma(a)); P is 0 when x == 0. At most 100 terms are
    summed. Summation stops once a term falls below 3e-7 times the running
    total, and a message is printed if that never happens.
    """
    max_terms = 100
    rel_tol = 3.e-7

    log_gamma_a = lgamma(a)
    assert x >= 0, 'x < 0 in gser'
    if x == 0:
        return 0, log_gamma_a

    denom = a
    term = 1./a
    total = term
    converged = False
    for _ in range(max_terms):
        denom = denom + 1.
        term = term*x/denom
        total = total + term
        if abs(term) < abs(total)*rel_tol:
            converged = True
            break
    if not converged:
        print('a too large, ITMAX too small in gser')
    return total*np.exp(-x+a*np.log(x)-log_gamma_a), log_gamma_a

def _gcf(a,x):
    """
    Regularized upper incomplete gamma Q(a,x) = Gamma(a,x)/Gamma(a), evaluated
    from its continued fraction with the modified Lentz method (Numerical
    Recipes routine gcf).

    Returns (Q(a,x), lgamma(a)). gamm_inc calls this only when x >= a+1;
    otherwise it uses the series in _gser.
    """
    ITMAX=100     # maximum number of continued-fraction terms
    EPS=3.e-7     # stop when the Lentz update factor is within EPS of 1
    FPMIN=1.e-30  # floor that keeps the Lentz denominators away from zero

    gln=lgamma(a)
    # Initialise the recurrence. When called from gamm_inc, x >= a+1, so
    # b = x+1-a >= 2 and 1./b is safe.
    b=x+1.-a
    c=1./FPMIN
    d=1./b
    h=d
    for i in range(1,ITMAX+1):
        an=-i*(i-a)   # i-th partial numerator
        b=b+2.        # i-th partial denominator
        d=an*d+b
        if abs(d) < FPMIN: d=FPMIN
        c=b+an/c
        if abs(c) < FPMIN: c=FPMIN
        d=1./d
        delt=d*c
        h=h*delt
        if abs(delt-1.) < EPS: break
    else:
        # All ITMAX terms were used without meeting EPS; h is returned anyway.
        print('a too large, ITMAX too small in gcf')
    # Q(a,x) = exp(-x) * x**a / Gamma(a) * h, where h is the continued-fraction value
    gammcf=np.exp(-x+a*np.log(x)-gln)*h
    return gammcf,gln

def trace2(A,B):
    """
    Return trace(AB) = sum_ij A[i,j]*B[j,i] for matrices A and B, without
    forming the product AB.

    For symmetric B this equals np.sum(A*B); transposing B keeps the result
    correct for non-symmetric B as well.
    """
    return np.sum(A*B.T)

def dmat(c,nclosed,nopen=0):
    """
    Density matrix from the columns of c: the full outer product of the first
    nclosed columns, plus (when nopen > 0) half the outer product of the next
    nopen columns.
    """
    closed = c[:, :nclosed]
    density = np.dot(closed, closed.T)
    if nopen > 0:
        singly = c[:, nclosed:(nclosed+nopen)]
        density += 0.5*np.dot(singly, singly.T)
    return density

def symorth(S):
    """
    Symmetric (Lowdin) orthogonalization. With S = U diag(E) U^T from
    np.linalg.eigh, return U diag(1/sqrt(E)) U^T, i.e. S^(-1/2) for a real
    symmetric positive-definite S.
    """
    E,U = np.linalg.eigh(S)
    # Square roots are taken in E's own precision and then widened, so the
    # diagonal matrix is always float64 ('d').
    inv_root = np.diag(1.0/np.sqrt(E).astype('d'))
    return simx(inv_root,U,True)

def canorth(S):
    """
    Canonical orthogonalization: return U/sqrt(lambda), i.e. each eigenvector
    of S (columns of U from np.linalg.eigh) divided by the square root of its
    eigenvalue.
    """
    E,U = np.linalg.eigh(S)
    # Broadcasting divides column j by sqrt(E[j]).
    return U/np.sqrt(E)

def cholorth(S):
    """
    Cholesky orthogonalization. With S = L L^T (np.linalg.cholesky, L lower
    triangular), return X = (L^-1)^T, which satisfies X^T S X = I.
    """
    L = np.linalg.cholesky(S)
    L_inv = np.linalg.inv(L)
    return L_inv.T

def simx(A,B,transpose=False):
    """
    Similarity-style transform of A by B: return B^T A B, or B A B^T when
    `transpose` is true.
    """
    left, right = (B, B.T) if transpose else (B.T, B)
    return np.dot(left, np.dot(A, right))

def ao2mo(H,C):
    """Return C^T H C (simx without transpose); presumably AO -> MO, per the name."""
    return simx(H, C, transpose=False)
def mo2ao(H,C,S):
    """Return (SC) H (SC)^T with SC = S.C (simx with transpose=True);
    presumably the MO -> AO back-transform, per the name."""
    SC = np.dot(S, C)
    return simx(H, SC, transpose=True)

def geigh(H,S):
    """
    Solve the generalized eigensystem H c = E S c.

    The problem is reduced to an ordinary symmetric one with the Cholesky
    orthogonalizer X = cholorth(S). The routine diagonalizes X^T H X with
    np.linalg.eigh and maps the eigenvectors back through X.
    Returns (E, c).
    """
    X = cholorth(S)
    H_orth = simx(H, X)
    energies, coeffs_orth = np.linalg.eigh(H_orth)
    return energies, np.dot(X, coeffs_orth)

def parseline(line,format):
    """
    Split `line` on whitespace and convert words according to `format`.

    Character i of `format` says what to do with word i:
      x   ->  skip this word
      s   ->  keep the word as a string
      i,d ->  convert the word with int
      f   ->  convert the word with float
    Unrecognised format characters are skipped like 'x'. Words beyond the end
    of `format` are ignored.

    Returns None if nothing was converted, the bare value if exactly one word
    was converted, and otherwise a list of the converted values.

    >>> parseline('Hello, World','ss')
    ['Hello,', 'World']
    >>> parseline('1 2 3 4','xdd')
    [2, 3]
    >>> parseline('C1   0.0  0.0 0.0','sfff')
    ['C1', 0.0, 0.0, 0.0]
    >>> parseline('This line wont be parsed','xx') is None
    True
    """
    converters = {'x':None,'s':str,'f':float,'d':int,'i':int}
    words = line.split()
    values = [converters[code](words[pos])
              for pos, code in enumerate(format)
              if converters.get(code)]
    if not values:
        return None
    return values[0] if len(values) == 1 else values

def colorscale(mag, cmin, cmax):
    """
    Return a tuple of floats between 0 and 1 for R, G, and B.
    Adapted from a Python Cookbook recipe (9.11?).

    mag is normalised to x = (mag-cmin)/(cmax-cmin), with x = 0.5 when
    cmax == cmin. It is then mapped to (R, G, B):
    x <= 0 -> (0, 1, 1), x = 0.5 -> (1, 0, 1), x >= 1 -> (1, 1, 0).
    """
    # Normalize to 0-1. The span is converted to a Python float so that a
    # zero span raises ZeroDivisionError for NumPy scalar inputs too; NumPy
    # division by zero would instead return inf/nan with only a warning.
    try:
        x = float(mag-cmin)/float(cmax-cmin)
    except ZeroDivisionError:
        x = 0.5  # cmax == cmin
    blue = min((max((4*(0.75-x), 0.)), 1.))
    red = min((max((4*(x-0.25), 0.)), 1.))
    green = min((max((4*abs(x-0.5)-1., 0.)), 1.))
    return red, green, blue

#Todo: replace with np.isclose
#def isnear(a,b,tol=1e-6): return abs(a-b) < tol

if __name__ == '__main__':
    # Run this module's doctests when executed as a script.
    from doctest import testmod
    testmod()