#!/usr/bin/env python
# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Qiming Sun <osirpt.sun@gmail.com>
#

'''
Some helper functions
'''

import os, sys
import warnings
import imp
import tempfile
import functools
import itertools
import ctypes
import numpy
import h5py
from threading import Thread
from multiprocessing import Queue, Process
try:
    from concurrent.futures import ThreadPoolExecutor
except ImportError:
    ThreadPoolExecutor = None

from pyscf.lib import param
from pyscf import __config__

# h5py 2.2.x has a bug in threading mode: notify the user that the
# asynchronous I/O helper of this module is disabled for that version.
if h5py.version.version.startswith('2.2.'):
    sys.stderr.write('h5py-%s is found in your environment. '
                     'h5py-%s has bug in threading mode.\n'
                     'Async-IO is disabled.\n'
                     % (h5py.version.version, h5py.version.version))
# With h5py 3.x, select 'a' as the default mode for opening files.
if h5py.version.version.startswith('3.'):
    h5py.get_config().default_file_mode = 'a'

# ctypes pointer types
c_double_p = ctypes.POINTER(ctypes.c_double)
c_int_p = ctypes.POINTER(ctypes.c_int)
c_null_ptr = ctypes.POINTER(ctypes.c_void_p)

def load_library(libname):
    '''Load the shared library ``libname`` stored in the directory of this
    module.

    Args:
        libname : str
            Name of the library without the platform specific extension.

    Returns:
        The ctypes library object.

    Raises:
        OSError : if the library cannot be loaded, or (numpy 1.6 only) if the
            platform is not recognized.
    '''
    _loaderpath = os.path.dirname(__file__)
    # numpy 1.6 has bug in ctypeslib.load_library, see numpy/distutils/misc_util.py
    # Compare major.minor exactly.  A substring test ('1.6' in version) would
    # also select this workaround for numpy 1.16.x, 1.26.x, 2.1.6, ...
    if numpy.__version__.split('.')[:2] == ['1', '6']:
        if sys.platform.startswith(('linux', 'gnukfreebsd')):
            so_ext = '.so'
        elif sys.platform.startswith('darwin'):
            so_ext = '.dylib'
        elif sys.platform.startswith('win'):
            so_ext = '.dll'
        else:
            raise OSError('Unknown platform')
        libname_so = libname + so_ext
        return ctypes.CDLL(os.path.join(_loaderpath, libname_so))
    else:
        return numpy.ctypeslib.load_library(libname, _loaderpath)

#FIXME: the standard resource module gives wrong number when objects are released
#see http://fa.bianp.net/blog/2013/different-ways-to-get-memory-consumption-or-lessons-learned-from-memory_profiler/#fn:1
#or use slow functions as memory_profiler._get_memory did
try:
    CLOCK_TICKS = os.sysconf("SC_CLK_TCK")
    PAGESIZE = os.sysconf("SC_PAGE_SIZE")
except (AttributeError, ValueError, OSError):
    # os.sysconf (or one of the names) is not available on this platform.
    # Do not break the import of the module for that.
    CLOCK_TICKS = PAGESIZE = None

def current_memory():
    '''Return the size of used memory and allocated virtual memory (in MB)

    The numbers are read from /proc/<pid>/statm.  On platforms other than
    linux, or if the page size is unknown, (0, 0) is returned.

    Returns:
        A tuple (rss, vms): resident memory and virtual memory in MB
        (1 MB = 1e6 bytes).
    '''
    #import resource
    #return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1000
    if PAGESIZE is not None and sys.platform.startswith('linux'):
        with open("/proc/%s/statm" % os.getpid()) as f:
            # The first two fields of statm are counted in pages
            vms, rss = [int(x)*PAGESIZE for x in f.readline().split()[:2]]
            return rss/1e6, vms/1e6
    else:
        return 0, 0

def num_threads(n=None):
    '''Set or query the number of OMP threads.

    If ``n`` is given, the number of OMP threads is set to ``int(n)`` and the
    value reported by the underlying library is returned.  A warning is
    issued if that value is 0 (OpenMP is not available).  Without argument,
    the function returns the total number of available OMP threads.

    It's recommended to call this function to set OMP threads than
    "os.environ['OMP_NUM_THREADS'] = int(n)". This is because environment
    variables like OMP_NUM_THREADS were read when a module was imported. They
    cannot be reset through os.environ after the module was loaded.

    Examples:

    >>> from pyscf import lib
    >>> print(lib.num_threads())
    8
    >>> lib.num_threads(4)
    4
    >>> print(lib.num_threads())
    4
    '''
    from pyscf.lib.numpy_helper import _np_helper

    if n is None:
        # Query mode
        _np_helper.get_omp_threads.restype = ctypes.c_int
        return _np_helper.get_omp_threads()

    _np_helper.set_omp_threads.restype = ctypes.c_int
    nset = _np_helper.set_omp_threads(ctypes.c_int(int(n)))
    if nset == 0:
        warnings.warn('OpenMP is not available. '
                      'Setting omp_threads to %s has no effects.' % n)
    return nset

class with_omp_threads(object):
    '''Context manager which temporarily changes the number of OpenMP
    threads.  The previous number of threads is restored when the program
    leaves the context.  Nothing is changed if nthreads is None or smaller
    than 1.

    Args:
        nthreads : int

    Examples:

    >>> from pyscf import lib
    >>> print(lib.num_threads())
    8
    >>> with lib.with_omp_threads(2):
    ...     print(lib.num_threads())
    2
    >>> print(lib.num_threads())
    8
    '''
    def __init__(self, nthreads=None):
        # sys_threads holds the thread count to restore on exit
        self.sys_threads = None
        self.nthreads = nthreads

    def __enter__(self):
        requested = self.nthreads
        if requested is None or not (requested >= 1):
            return self
        self.sys_threads = num_threads()
        num_threads(requested)
        return self

    def __exit__(self, type, value, traceback):
        previous = self.sys_threads
        if previous is not None:
            num_threads(previous)


def c_int_arr(m):
    '''Copy m into a new ctypes c_int array, flattened in C order.'''
    flat = numpy.array(m).flatten(order='C')
    array_type = ctypes.c_int * flat.size
    # Return the ctypes array itself.  Returning flat.ctypes.data_as(c_int_p)
    # does not work because the numpy array is destructed before return.
    return array_type(*flat)
def f_int_arr(m):
    '''Copy m into a new ctypes c_int array, flattened in Fortran order.'''
    flat = numpy.array(m).flatten(order='F')
    array_type = ctypes.c_int * flat.size
    return array_type(*flat)
def c_double_arr(m):
    '''Copy m into a new ctypes c_double array, flattened in C order.'''
    flat = numpy.array(m).flatten(order='C')
    array_type = ctypes.c_double * flat.size
    return array_type(*flat)
def f_double_arr(m):
    '''Copy m into a new ctypes c_double array, flattened in Fortran order.'''
    flat = numpy.array(m).flatten(order='F')
    array_type = ctypes.c_double * flat.size
    return array_type(*flat)


def member(test, x, lst):
    '''Whether test(x, item) holds for at least one item of lst.'''
    return any(test(x, item) for item in lst)

def remove_dup(test, lst, from_end=False):
    '''Remove the duplicated elements of lst.

    If test is None, a set of the elements is returned.  Otherwise two
    elements are treated as duplicates when test(a, b) holds, and a list with
    the first element of each group of duplicates is returned.  With
    from_end=True the list is scanned in reversed order.
    '''
    if test is None:
        return set(lst)

    candidates = list(reversed(lst)) if from_end else lst
    uniq = []
    for item in candidates:
        if member(test, item, uniq):
            continue
        uniq.append(item)
    return uniq

def remove_if(test, lst):
    '''Return a list of the elements of lst for which test(x) is false.'''
    kept = []
    for item in lst:
        if test(item):
            continue
        kept.append(item)
    return kept

def find_if(test, lst):
    '''Return the first element of lst for which test(x) holds.

    Raises:
        ValueError : if no element matches.
    '''
    for candidate in lst:
        if not test(candidate):
            continue
        return candidate
    raise ValueError('No element of the given list matches the test condition.')

def arg_first_match(test, lst):
    '''Return the index of the first element of lst for which test(x) holds.

    Raises:
        ValueError : if no element matches.
    '''
    position = 0
    for item in lst:
        if test(item):
            return position
        position += 1
    raise ValueError('No element of the given list matches the test condition.')

def _balanced_partition(cum, ntasks):
    segsize = float(cum[-1]) / ntasks
    bounds = numpy.arange(ntasks+1) * segsize
    displs = abs(bounds[:,None] - cum).argmin(axis=1)
    return displs

def _blocksize_partition(cum, blocksize):
    n = len(cum) - 1
    displs = [0]
    if n == 0:
        return displs

    p0 = 0
    for i in range(1, n):
        if cum[i+1]-cum[p0] > blocksize:
            displs.append(i)
            p0 = i
    displs.append(n)
    return displs

def flatten(lst):
    '''Flatten one level of nested lists, i.e. concatenate
    x[0] + x[1] + x[2] + ...

    Examples:

    >>> flatten([[0, 2], [1], [[9, 8, 7]]])
    [0, 2, 1, [9, 8, 7]]
    '''
    return [item for sub in lst for item in sub]

def prange(start, end, step):
    '''Split the number sequence between "start" and "end" into fragments of
    uniform length "step" (the last fragment may be shorter).  The boundary
    (start, end) of each fragment is yielded.  Nothing is yielded if
    start is not smaller than end.

    Examples:

    >>> for p0, p1 in lib.prange(0, 8, 2):
    ...    print(p0, p1)
    0 2
    2 4
    4 6
    6 8
    '''
    if not start < end:
        return
    for p0 in range(start, end, step):
        p1 = p0 + step
        yield p0, min(p1, end)

def prange_tril(start, stop, blocksize):
    '''Similar to :func:`prange`, yields start (p0) and end (p1) of each
    block.  The cost of a block is p1*(p1+1)/2-p0*(p0+1)/2.  A block is
    closed as soon as including one more index makes its cost larger than
    blocksize.

    Examples:

    >>> for p0, p1 in lib.prange_tril(0, 10, 25):
    ...     print(p0, p1)
    0 6
    6 9
    9 10
    '''
    if start >= stop:
        return []

    rows = numpy.arange(start, stop+1)
    offset = start*(start+1)//2
    # Cumulative cost measured from "start"
    cum_costs = rows*(rows+1)//2 - offset
    bounds = [p+start for p in _blocksize_partition(cum_costs, blocksize)]
    return zip(bounds[:-1], bounds[1:])


def index_tril_to_pair(ij):
    '''Given tril-index ij, compute the pair indices (i,j) which satisfy
    ij = i * (i+1) / 2 + j
    '''
    root = numpy.sqrt(2*ij+.25)
    # 1e-7 protects against rounding errors before the truncation to integer
    row = (root - .5 + 1e-7).astype(int)
    col = ij - row*(row+1)//2
    return row, col


def tril_product(*iterables, **kwds):
    '''Cartesian product in lower-triangular form for multiple indices

    The cartesian product of `iterables` is generated.  Only the tuples
    whose elements at the positions `tril_idx` are in lower-triangular form
    are yielded:

    .. math:: i[tril_idx[0]] >= i[tril_idx[1]] >= ... >= i[tril_idx[len(tril_idx)-1]]

    Args:
        *iterables: Variable length argument list of indices for the cartesian product
        **kwds: Arbitrary keyword arguments.  Acceptable keywords include:
            repeat (int): Number of times to repeat the iterables
            tril_idx (array_like): Indices to put into lower-triangular form.

    Yields:
        product (tuple): Tuple in lower-triangular form.

    Examples:
        Without `tril_idx`, the function is a plain cartesian product.

        >>> list(tril_product(range(2), repeat=2))
        [(0, 0), (0, 1), (1, 0), (1, 1)]

        The lower-triangular form can be required for some of the indices:

        >>> list(tril_product(range(2), repeat=3, tril_idx=[1,2]))
        [(0, 0, 0), (0, 1, 0), (0, 1, 1), (1, 0, 0), (1, 1, 0), (1, 1, 1)]

        Or for all indices, useful for iterating over the symmetry unique
        elements of occupied/virtual orbitals in a 3-particle operator:

        >>> list(tril_product(range(3), repeat=3, tril_idx=[0,1,2]))
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 1, 1), (2, 0, 0), (2, 1, 0), (2, 1, 1), (2, 2, 0), (2, 2, 1), (2, 2, 2)]
    '''
    repeat = kwds.get('repeat', 1)
    tril_idx = kwds.get('tril_idx', [])
    niterables = len(iterables) * repeat
    ntril_idx = len(tril_idx)

    assert ntril_idx <= niterables, 'Cant have a greater number of tril indices than iterables!'
    if ntril_idx > 0:
        assert numpy.max(tril_idx) < niterables, 'Tril index out of bounds for %d iterables! idx = %s' % \
                                                 (niterables, tril_idx)

    # Neighbouring positions (hi, lo) which have to satisfy tup[hi] >= tup[lo].
    # The list is empty if less than two tril indices are given, in which
    # case every tuple of the product is accepted.
    pairs = [(tril_idx[k], tril_idx[k+1]) for k in range(ntril_idx-1)]

    for tup in itertools.product(*iterables, repeat=repeat):
        if all(tup[hi] >= tup[lo] for hi, lo in pairs):
            yield tup

def square_mat_in_trilu_indices(n):
    '''Return a n x n symmetric index matrix.  Its elements are the indices
    of the unique elements in the corresponding tril vector
    [0 1 3 ... ]
    [1 2 4 ... ]
    [3 4 5 ... ]
    [...       ]
    '''
    rows, cols = numpy.tril_indices(n)
    packed = numpy.arange(n*(n+1)//2)
    tril2sq = numpy.zeros((n,n), dtype=int)
    # Fill the lower triangle, then mirror it to the upper triangle
    tril2sq[rows,cols] = packed
    tril2sq[cols,rows] = packed
    return tril2sq

class capture_stdout(object):
    '''redirect all stdout (c printf & python print) into a string

    Within the context, the file descriptor of sys.stdout is redirected to a
    temporary file (created in param.TMPDIR).  When leaving the context, the
    content of the temporary file is stored in the object, the temporary
    file is closed and the original stdout is restored.

    Examples:

    >>> import os
    >>> from pyscf import lib
    >>> with lib.capture_stdout() as out:
    ...     os.system('ls')
    >>> print(out.read())
    '''
    #TODO: handle stderr
    def __enter__(self):
        sys.stdout.flush()
        self._contents = None
        self.old_stdout_fileno = sys.stdout.fileno()
        # Keep a duplicate of the original stdout to restore it in __exit__
        self.bak_stdout_fd = os.dup(self.old_stdout_fileno)
        self.ftmp = tempfile.NamedTemporaryFile(dir=param.TMPDIR)
        os.dup2(self.ftmp.file.fileno(), self.old_stdout_fileno)
        return self
    def __exit__(self, type, value, traceback):
        sys.stdout.flush()
        try:
            self.ftmp.file.seek(0)
            self._contents = self.ftmp.file.read()
        finally:
            # Always restore stdout and release the descriptors, even if the
            # temporary file cannot be read.
            self.ftmp.close()
            os.dup2(self.bak_stdout_fd, self.old_stdout_fileno)
            os.close(self.bak_stdout_fd)
    def read(self):
        '''Return the captured output, as read from the temporary file.

        It can be called inside the context (returns what was captured so
        far) or after the context was closed.
        '''
        # Test against None rather than truthiness: empty output captured by
        # a closed context must not fall through to the closed temporary file.
        if self._contents is not None:
            return self._contents
        else:
            sys.stdout.flush()
            self.ftmp.file.seek(0)
            return self.ftmp.file.read()
ctypes_stdout = capture_stdout

class quite_run(object):
    '''capture all stdout (c printf & python print) but output nothing

    Within the context, the file descriptor of sys.stdout is redirected to
    os.devnull.  The original stdout is restored when leaving the context.

    Examples:

    >>> import os
    >>> from pyscf import lib
    >>> with lib.quite_run():
    ...     os.system('ls')
    '''
    def __enter__(self):
        sys.stdout.flush()
        #TODO: to handle the redirected stdout e.g. StringIO()
        self.old_stdout_fileno = sys.stdout.fileno()
        # Keep a duplicate of the original stdout to restore it in __exit__
        self.bak_stdout_fd = os.dup(self.old_stdout_fileno)
        self.fnull = open(os.devnull, 'wb')
        os.dup2(self.fnull.fileno(), self.old_stdout_fileno)
    def __exit__(self, type, value, traceback):
        sys.stdout.flush()
        os.dup2(self.bak_stdout_fd, self.old_stdout_fileno)
        self.fnull.close()
        # Release the duplicated descriptor.  Otherwise one file descriptor
        # is leaked each time the context is used.
        os.close(self.bak_stdout_fd)


# from pygeocoder
class omnimethod(object):
    '''Decorator which allows a method to be used as both static and
    instance method.

    In contrast to classmethod, when obj.function() is called, the first
    argument is obj in omnimethod rather than obj.__class__ in classmethod.
    When the function is accessed through the class, the first argument is
    None.
    '''
    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        # instance is None if the attribute is accessed through the class
        bound = functools.partial(self.func, instance)
        return bound


# Global switch of the attribute sanity check, read from __config__ (True if
# not configured).  StreamObject.check_sanity skips the check when it is false.
SANITY_CHECK = getattr(__config__, 'SANITY_CHECK', True)
class StreamObject(object):
    '''For most methods, there are three stream functions to pipe computing stream:

    1 ``.set`` function to update object attributes, eg
    ``mf = scf.RHF(mol).set(conv_tol=1e-5)`` is identical to proceed in two steps
    ``mf = scf.RHF(mol); mf.conv_tol=1e-5``

    2 ``.run`` function to execute the kernel function (the function arguments
    are passed to kernel function).  If keyword arguments are given, it will
    first call ``.set`` function to update object attributes then execute the
    kernel function.  Eg
    ``mf = scf.RHF(mol).run(dm_init, conv_tol=1e-5)`` is identical to three steps
    ``mf = scf.RHF(mol); mf.conv_tol=1e-5; mf.kernel(dm_init)``

    3 ``.apply`` function to apply the given function/class to the current object
    (function arguments and keyword arguments are passed to the given function).
    Eg
    ``mol.apply(scf.RHF).run().apply(mcscf.CASSCF, 6, 4, frozen=4)`` is identical to
    ``mf = scf.RHF(mol); mf.kernel(); mcscf.CASSCF(mf, 6, 4, frozen=4)``
    '''

    verbose = 0
    stdout = sys.stdout
    _keys = {'verbose', 'stdout'}

    def kernel(self, *args, **kwargs):
        '''
        Main driver of a method.  Every method should define the kernel
        function as the entry of the calculation.  The return value of
        kernel function is not strictly defined.  It can be anything related
        to the method (such as the energy, the wave-function, the DFT mesh
        grids etc.).  This default implementation does nothing.
        '''

    def pre_kernel(self, envs):
        '''
        A hook to be run before the main body of kernel function is executed.
        Internal variables are exposed to pre_kernel through the "envs"
        dictionary.  Return value of pre_kernel function is not required.
        '''

    def post_kernel(self, envs):
        '''
        A hook to be run after the main body of the kernel function.  Internal
        variables are exposed to post_kernel through the "envs" dictionary.
        Return value of post_kernel function is not required.
        '''

    def run(self, *args, **kwargs):
        '''
        Update the attributes of the current object with `kwargs`, then call
        its kernel function with `args`.  The return value of method run is
        the object itself.  This allows a series of functions/methods to be
        executed in pipe.
        '''
        self.set(**kwargs)
        self.kernel(*args)
        return self

    def set(self, *args, **kwargs):
        '''
        Update the attributes of the current object with the keyword
        arguments.  Positional arguments are ignored with a warning.  The
        return value of method set is the object itself.  This allows a
        series of functions/methods to be executed in pipe.
        '''
        if args:
            warnings.warn('method set() only supports keyword arguments.\n'
                          'Arguments %s are ignored.' % args)
        for name in kwargs:
            setattr(self, name, kwargs[name])
        return self

    # An alias to .set method
    __call__ = set

    def apply(self, fn, *args, **kwargs):
        '''
        Call fn with the current object as the first argument, followed by
        the rest arguments:  return fn(self, *args, **kwargs).  The return
        value of method apply is the return value of fn.
        '''
        return fn(self, *args, **kwargs)

    def check_sanity(self):
        '''
        Check input of class/object attributes, check whether a class method is
        overwritten.  It does not check the attributes which are prefixed with
        "_".  Nothing is checked if SANITY_CHECK is disabled, if verbose is
        not positive or if the object does not provide "_keys".  The return
        value of method check_sanity is the object itself.  This allows a
        series of functions/methods to be executed in pipe.
        '''
        enabled = SANITY_CHECK and self.verbose > 0  # logger.QUIET
        if enabled and getattr(self, '_keys', None):
            check_sanity(self, self._keys, self.stdout)
        return self

    def view(self, cls):
        '''New view of object with the same attributes.

        An instance of cls is created without calling cls.__init__.  The
        attributes of the current object are copied to it.
        '''
        viewed = cls.__new__(cls)
        viewed.__dict__.update(self.__dict__)
        return viewed

# Messages which were already printed by check_sanity
_warn_once_registry = {}
def check_sanity(obj, keysref, stdout=sys.stdout):
    '''Check misinput of class attributes, check whether a class method is
    overwritten.  It does not check the attributes which are prefixed with
    "_".

    The attributes of obj which are not listed in keysref are reported: as
    "overwritten" if the class of obj has an attribute of the same name,
    otherwise as unknown attribute.  Each message is written only once, to
    sys.stderr and to stdout (if stdout is not sys.stdout).  The return
    value is obj.
    '''
    objkeys = [key for key in obj.__dict__ if not key.startswith('_')]
    unknown = set(objkeys) - set(keysref)
    if not unknown:
        return obj

    class_attr = set(dir(obj.__class__))
    overwritten = unknown.intersection(class_attr)
    missing = unknown - class_attr

    messages = []
    if overwritten:
        messages.append('Overwritten attributes  %s  of %s\n' %
                        (' '.join(overwritten), obj.__class__))
    if missing:
        messages.append('%s does not have attributes  %s\n' %
                        (obj.__class__, ' '.join(missing)))

    for msg in messages:
        if msg in _warn_once_registry:
            continue
        _warn_once_registry[msg] = 1
        sys.stderr.write(msg)
        if stdout is not sys.stdout:
            stdout.write(msg)
    return obj

def with_doc(doc):
    '''Decorator factory which replaces the doc string of a function

        @with_doc(doc)
        def fn:
            ...

    is equivalent to

        fn.__doc__ = doc

    The decorated function itself is returned (no wrapper is created).
    '''
    def fn_with_doc(fn):
        setattr(fn, '__doc__', doc)
        return fn
    return fn_with_doc

def alias(fn, alias_name=None):
    '''
    The statement "fn1 = alias(fn)" in a class is equivalent to define the
    following method in the class:

    .. code-block:: python
        def fn1(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    Using alias function instead of fn1 = fn because some methods may be
    overloaded in the child class. The method is looked up by name on the
    object at call time, so that the overloaded methods are called when
    calling the aliased method.

    The doc string of the aliased method consists of a header (name of the
    original method and, in Python 3, its signature) followed by the doc
    string of fn.
    '''
    fname = fn.__name__

    def aliased_fn(self, *args, **kwargs):
        return getattr(self, fname)(*args, **kwargs)

    if alias_name is not None:
        aliased_fn.__name__ = alias_name

    doc_parts = ['An alias to method %s\n' % fname]
    if sys.version_info >= (3,):
        from inspect import signature
        sig = str(signature(fn))
        shown_name = '' if alias_name is None else alias_name
        doc_parts.append('Function Signature: %s%s\n' % (shown_name, sig))
    doc_parts.append('----------------------------------------\n\n')
    if fn.__doc__ is not None:
        doc_parts.append(fn.__doc__)

    aliased_fn.__doc__ = ''.join(doc_parts)
    return aliased_fn

def class_as_method(cls):
    '''
    The statement "fn1 = class_as_method(Class)" is equivalent to:

    .. code-block:: python
        def fn1(self, *args, **kwargs):
            return Class(self, *args, **kwargs)

    The doc string, the name and the module of the class are copied to the
    returned function.
    '''
    def fn(obj, *args, **kwargs):
        return cls(obj, *args, **kwargs)

    for attr in ('__doc__', '__name__', '__module__'):
        setattr(fn, attr, getattr(cls, attr))
    return fn

def overwrite_mro(obj, mro):
    '''A hacky function to overwrite the __mro__ attribute

    A temporary metaclass, whose mro method returns the given mro, is used to
    create a new class.  The new class takes the name and the bases of
    obj.__class__ and uses obj.__dict__ as its class namespace.  An instance
    of the new class, created without arguments, is returned.

    Args:
        obj : the object which provides class name, bases and attributes
        mro : the method resolution order returned by the mro method of the
            temporary metaclass

    Returns:
        A new instance of the newly created class.
    '''
    class HackMRO(type):
        pass
# Overwrite type.mro function so that Temp class can use the given mro
    HackMRO.mro = lambda self: mro
    #if sys.version_info < (3,):
    #    class Temp(obj.__class__):
    #        __metaclass__ = HackMRO
    #else:
    #    class Temp(obj.__class__, metaclass=HackMRO):
    #        pass
    # The attributes of the input object become class attributes of Temp
    Temp = HackMRO(obj.__class__.__name__, obj.__class__.__bases__, obj.__dict__)
    # The local name obj is rebound to an instance of the new class
    obj = Temp()
# Delete mro function otherwise all subclass of Temp are not able to
# resolve the right mro
    del(HackMRO.mro)
    return obj

def izip(*args):
    '''Lazy zip for python2 and python3: python2 izip == python3 zip'''
    if sys.version_info >= (3,):
        return zip(*args)
    return itertools.izip(*args)

class ProcessWithReturnValue(Process):
    '''A Process which hands the return value of target back to the parent.

    The child process sends the return value of target (or the message of
    the exception raised by target) through a queue.  Method join (alias
    get) waits for the process and returns the return value of target.

    Raises (in join/get):
        ProcessRuntimeError : if target raised an exception in the child
            process.
    '''
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None):
        self._q = Queue()
        self._e = None
        self._retval = None
        self._finished = False
        if kwargs is None:
            # Process.__init__ requires a mapping for kwargs
            kwargs = {}
        Process.__init__(self, group, target, name, args, kwargs)

    def run(self):
        # Executed in the child process.  Attributes assigned here are not
        # visible in the parent process.  The outcome (including a failure)
        # must be reported through the queue.
        try:
            retval = self._target(*self._args, **self._kwargs)
        except BaseException as e:
            # Send the message only, the exception may not be picklable
            self._q.put((False, '%s' % (e,)))
            raise
        else:
            self._q.put((True, retval))

    def join(self):
        if not self._finished:
            # Zero timeout: this only triggers the checks of Process.join
            # (e.g. the process was started) and returns immediately.
            Process.join(self, 0)
            # Read the queue BEFORE waiting for the process.  A process
            # which put items in a queue does not terminate until the
            # buffered items are flushed.  Joining first can deadlock.
            ok, payload = self._q.get()
            Process.join(self)
            if ok:
                self._retval = payload
            else:
                self._e = payload
            # Cache the outcome so that join/get can be called repeatedly
            self._finished = True

        if self._e is not None:
            raise ProcessRuntimeError('Error on process %s:\n%s' % (self, self._e))
        else:
            return self._retval
    get = join

class ProcessRuntimeError(RuntimeError):
    '''Raised by ProcessWithReturnValue.join if target failed in the child
    process.'''
    pass

class ThreadWithReturnValue(Thread):
    '''A Thread which keeps the return value of target.

    Method join (alias get) waits for the thread and returns the return
    value of target.  It can be called repeatedly.

    Raises (in join/get):
        ThreadRuntimeError : if target raised an exception in the thread.
    '''
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None):
        self._e = None
        self._retval = None
        def qwrap(*args, **kwargs):
            try:
                # Threads share the memory of the process.  The return value
                # is stored on the thread object directly instead of being
                # sent through a multiprocessing Queue, which pickles the
                # data and fails for huge or unpicklable return values.
                self._retval = target(*args, **kwargs)
            except BaseException as e:
                self._e = e
                raise e
        Thread.__init__(self, group, qwrap, name, args, kwargs)
    def join(self):
        Thread.join(self)
        if self._e is not None:
            raise ThreadRuntimeError('Error on thread %s:\n%s' % (self, self._e))
        else:
            return self._retval
    get = join

class ThreadWithTraceBack(Thread):
    '''A Thread which records the exception raised by target.  The
    exception is re-raised inside the thread and reported as
    ThreadRuntimeError by method join.
    '''
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None):
        self._e = None

        def run_and_record(*targs, **tkwargs):
            try:
                target(*targs, **tkwargs)
            except BaseException as err:
                # Keep the exception for join, then let it propagate
                self._e = err
                raise err

        Thread.__init__(self, group, run_and_record, name, args, kwargs)

    def join(self):
        Thread.join(self)
        if self._e is None:
            return
        raise ThreadRuntimeError('Error on thread %s:\n%s' % (self, self._e))

class ThreadRuntimeError(RuntimeError):
    '''Error raised when a function executed in a background thread failed.'''

def background_thread(func, *args, **kwargs):
    '''Apply the function in a background thread.  The thread is started
    and returned.  Its join/get method returns the result of the function.
    '''
    worker = ThreadWithReturnValue(target=func, args=args, kwargs=kwargs)
    worker.start()
    return worker

def background_process(func, *args, **kwargs):
    '''Apply the function in a background process.  The process is started
    and returned.  Its join/get method returns the result of the function.
    '''
    worker = ProcessWithReturnValue(target=func, args=args, kwargs=kwargs)
    worker.start()
    return worker

# Short names of background_thread
bg = background = bg_thread = background_thread
# Short names of background_process
bp = bg_process = background_process

# Default switch of the asynchronous mode in call_in_background, read from
# __config__ (True if not configured).
ASYNC_IO = getattr(__config__, 'ASYNC_IO', True)
class call_in_background(object):
    '''Within this macro, function(s) can be executed asynchronously (the
    given functions are executed in background).

    For each function, a second call of the asynchronous function waits for
    the previous call of the same function.  When leaving the context, all
    pending calls are waited for.

    Attributes:
        sync (bool): Whether to run in synchronized mode.  It can be set
            with the keyword argument sync.  The default value is
            "not ASYNC_IO".

    Raises:
        ThreadRuntimeError : if a function failed in background.  It is
            raised by the next call of the same asynchronous function or
            when leaving the context.

    Examples:

    >>> with call_in_background(fun) as async_fun:
    ...     async_fun(a, b)  # == fun(a, b)
    ...     do_something_else()

    >>> with call_in_background(fun1, fun2) as (afun1, afun2):
    ...     afun2(a, b)
    ...     do_something_else()
    ...     afun2(a, b)
    ...     do_something_else()
    ...     afun1(a, b)
    ...     do_something_else()
    '''

    def __init__(self, *fns, **kwargs):
        self.fns = fns
        self.executor = None
        # One handler (thread or future of the latest call) per function
        self.handlers = [None] * len(self.fns)
        self.sync = kwargs.get('sync', not ASYNC_IO)

    if h5py.version.version[:4] == '2.2.': # h5py-2.2.* has bug in threading mode
        # Disable back-ground mode
        def __enter__(self):
            if len(self.fns) == 1:
                return self.fns[0]
            else:
                return self.fns

    else:
        def __enter__(self):
            fns = self.fns
            handlers = self.handlers
            ntasks = len(self.fns)

            if self.sync or imp.lock_held():
# Some modules like nosetests, coverage etc
#   python -m unittest test_xxx.py  or  nosetests test_xxx.py
# hang when Python multi-threading was used in the import stage due to (Python
# import lock) bug in the threading module.  See also
# https://github.com/paramiko/paramiko/issues/104
# https://docs.python.org/2/library/threading.html#importing-in-threaded-code
# Disable the asynchoronous mode for safe importing
                def def_async_fn(i):
                    return fns[i]

            elif ThreadPoolExecutor is None: # async mode, old python
                def def_async_fn(i):
                    def async_fn(*args, **kwargs):
                        if self.handlers[i] is not None:
                            self.handlers[i].join()
                        self.handlers[i] = ThreadWithTraceBack(target=fns[i], args=args,
                                                               kwargs=kwargs)
                        self.handlers[i].start()
                        return self.handlers[i]
                    return async_fn

            else: # multiple executors in async mode, python 2.7.12 or newer
                executor = self.executor = ThreadPoolExecutor(max_workers=ntasks)
                def def_async_fn(i):
                    def async_fn(*args, **kwargs):
                        if handlers[i] is not None:
                            try:
                                handlers[i].result()
                            except Exception as e:
                                raise ThreadRuntimeError('Error on thread %s:\n%s'
                                                         % (self, e))
                        handlers[i] = executor.submit(fns[i], *args, **kwargs)
                        return handlers[i]
                    return async_fn

            if len(self.fns) == 1:
                return def_async_fn(0)
            else:
                return [def_async_fn(i) for i in range(ntasks)]

    def __exit__(self, type, value, traceback):
        try:
            for handler in self.handlers:
                if handler is not None:
                    try:
                        if ThreadPoolExecutor is None:
                            handler.join()
                        else:
                            handler.result()
                    except Exception as e:
                        raise ThreadRuntimeError('Error on thread %s:\n%s' % (self, e))
        finally:
            # Shut down the executor even if one of the background calls
            # failed.  Otherwise the executor (and its worker threads) is
            # leaked when the error above is raised.
            if self.executor is not None:
                self.executor.shutdown(wait=True)


class H5TmpFile(h5py.File):
    '''Create and return an HDF5 temporary file.

    Kwargs:
        filename : str or None
            If a string is given, an HDF5 file of the given filename will be
            created. The temporary file will exist even if the H5TmpFile
            object is released.  If nothing is specified, the HDF5 temporary
            file will be deleted when the H5TmpFile object is released.
        mode : str
            File mode passed to h5py.File.  The default is 'a'.

    Other arguments are passed to h5py.File.

    The return object is an h5py.File object. The file will be automatically
    deleted when it is closed or the object is released (unless filename is
    specified).

    Examples:

    >>> from pyscf import lib
    >>> ftmp = lib.H5TmpFile()
    '''
    def __init__(self, filename=None, mode='a', *args, **kwargs):
        if filename is None:
            # Only the name of the temporary file (created in param.TMPDIR)
            # is used.
            # NOTE(review): tmpfile is referenced by this local variable
            # only; presumably its release at the end of __init__ is what
            # removes the file from the disk -- confirm.
            tmpfile = tempfile.NamedTemporaryFile(dir=param.TMPDIR)
            filename = tmpfile.name
        h5py.File.__init__(self, filename, mode, *args, **kwargs)
#FIXME: Does GC flush/close the HDF5 file when releasing the resource?
# To make HDF5 file reusable, file has to be closed or flushed
    def __del__(self):
        # Best-effort close when the object is released.  The errors listed
        # below are ignored on purpose.
        try:
            self.close()
        except AttributeError:  # close not defined in old h5py
            pass
        except ValueError:  # if close() is called twice
            pass
        except ImportError:  # exit program before de-referring the object
            pass

def fingerprint(a):
    '''Fingerprint of numpy array: the dot product of the flattened array
    and cos(0), cos(1), cos(2), ...
    '''
    arr = numpy.asarray(a)
    weights = numpy.cos(numpy.arange(arr.size))
    return numpy.dot(weights, arr.ravel())
finger = fp = fingerprint


def ndpointer(*args, **kwargs):
    '''Same as numpy.ctypeslib.ndpointer, except that the returned type also
    accepts None: from_param(None) returns None.  Any other object is
    handled by the from_param of the numpy ndpointer type.
    '''
    base = numpy.ctypeslib.ndpointer(*args, **kwargs)

    def from_param(cls, obj):
        return obj if obj is None else base.from_param(obj)

    namespace = {'from_param': classmethod(from_param)}
    return type(base.__name__, (base,), namespace)


class SinglePointScanner:
    '''A tag to label the derived Scanner class'''
class GradScanner:
    '''Scanner built from the object g.  The attributes of g are copied to
    the scanner.  The attribute base is replaced by g.base.as_scanner().
    Attributes e_tot and converged are forwarded to base.
    '''
    def __init__(self, g):
        for key, val in g.__dict__.items():
            self.__dict__[key] = val
        self.base = g.base.as_scanner()

    @property
    def e_tot(self):
        return self.base.e_tot

    @e_tot.setter
    def e_tot(self, x):
        self.base.e_tot = x

    @property
    def converged(self):
        # Some base methods like MP2 does not have the attribute converged.
        # They are treated as converged.
        return getattr(self.base, 'converged', True)

class temporary_env(object):
    '''Within the context of this macro, the attributes of the object are
    temporarily updated. When the program goes out of the scope of the
    context, the original value of each attribute will be restored.
    Attributes which did not exist before entering the context are deleted.

    Args:
        obj : the object to be modified
        **kwargs : the attributes (name=value) to be set within the context

    Examples:

    >>> with temporary_env(lib.param, LIGHT_SPEED=15., BOHR=2.5):
    ...     print(lib.param.LIGHT_SPEED, lib.param.BOHR)
    15. 2.5
    >>> print(lib.param.LIGHT_SPEED, lib.param.BOHR)
    137.03599967994 0.52917721092
    '''
    # Unique marker for the attributes which do not exist in obj.  A string
    # marker would be confused with an attribute which has the same string
    # as value.
    _TO_DEL = object()

    def __init__(self, obj, **kwargs):
        self.obj = obj

        # Should I skip the keys which are not presented in obj?
        #keys = [key for key in kwargs.keys() if hasattr(obj, key)]
        #self.env_bak = [(key, getattr(obj, key, 'TO_DEL')) for key in keys]
        #self.env_new = [(key, kwargs[key]) for key in keys]

        to_del = self._TO_DEL
        self.env_bak = [(key, getattr(obj, key, to_del)) for key in kwargs]
        self.env_new = [(key, kwargs[key]) for key in kwargs]

    def __enter__(self):
        for k, v in self.env_new:
            setattr(self.obj, k, v)
        return self

    def __exit__(self, type, value, traceback):
        for k, v in self.env_bak:
            if v is self._TO_DEL:
                delattr(self.obj, k)
            else:
                setattr(self.obj, k, v)

class light_speed(temporary_env):
    '''Within the context of this macro, the environment variable LIGHT_SPEED
    (attribute of lib.param) can be customized.  The value c is returned
    when entering the context.

    Examples:

    >>> with light_speed(15.):
    ...     print(lib.param.LIGHT_SPEED)
    15.
    >>> print(lib.param.LIGHT_SPEED)
    137.03599967994
    '''
    def __init__(self, c):
        self.c = c
        temporary_env.__init__(self, param, LIGHT_SPEED=c)

    def __enter__(self):
        temporary_env.__enter__(self)
        speed = self.c
        return speed


if __name__ == '__main__':
    # Demo: print the boundaries and the cost of each block
    for p0, p1 in prange_tril(0, 90, 300):
        cost = p1*(p1+1)//2-p0*(p0+1)//2
        print(p0, p1, cost)