import os
import sys
from decimal import Decimal
from contextlib import contextmanager
import warnings

import numpy as np
from datetime import datetime, timedelta
from astropy import constants as c
from astropy import units as u

import radvel
from scipy.optimize import root

# Normalization constant for the semi-amplitude / Msini relations below:
# the RV semi-amplitude, in m/s, of a 1.0 Jupiter mass planet tugging on a
# 1.0 solar mass star on a 1.0 year orbital period.
K_0 = 28.4329

def load_module_from_file(module_name, module_path):
    """Load a python module from the path of the corresponding file.

    Args:
        module_name (str): namespace where the python module will be loaded,
            e.g. ``foo.bar``
        module_path (str): path of the python file containing the module

    Returns:
        A valid module object

    Raises:
        ImportError: when the module can't be loaded
        FileNotFoundError: when module_path doesn't exist
    """
    if sys.version_info >= (3, 5):
        import importlib.util
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        if spec is None or spec.loader is None:
            # No loader is available for this path (e.g. unknown extension).
            raise ImportError(
                "Could not load module '{}' from '{}'".format(module_name, module_path)
            )
        module = importlib.util.module_from_spec(spec)
        # module_from_spec only creates an empty module object; the file's
        # code does not run until the loader executes it.
        spec.loader.exec_module(module)
    elif sys.version_info[0] == 3:
        # Python 3.0 - 3.4
        import importlib.machinery
        loader = importlib.machinery.SourceFileLoader(module_name, module_path)
        module = loader.load_module()
    else:
        # Python 2
        import imp
        module = imp.load_source(module_name, module_path)

    return module

def initialize_posterior(config_file, decorr=False):
    """Initialize Posterior object

    Parse a setup file and initialize the RVModel, Likelihood, Posterior and priors.

    Args:
        config_file (string): path to config file
        decorr (bool): (optional) decorrelate RVs against columns defined in the decorr_vars list

    Returns:
        tuple: (object representation of config file, radvel.Posterior object)

    Raises:
        AssertionError: if the parameters are not in the fitting basis, or if
            an instrument listed in the setup file has no data
        Exception: if ``decorr`` is requested but the setup file defines no
            ``decorr_vars``
    """
    # NOTE(review): the expressions that touch ``P.data`` and the positional
    # arguments of the likelihood constructor were truncated in the file this
    # was restored from. They are reconstructed here assuming ``P.data`` is a
    # pandas DataFrame with 'tel', 'time', 'mnvel' and 'errvel' columns --
    # confirm against the setup-file format before relying on them.

    system_name = os.path.basename(config_file).split('.')[0]
    P = load_module_from_file(system_name, os.path.abspath(config_file))

    params = P.params
    assert str(params.basis) == "Basis Object <{}>".format(P.fitting_basis), """
Parameters in config file must be converted to fitting basis.
"""

    if decorr:
        try:
            decorr_vars = P.decorr_vars
        except AttributeError:
            raise Exception("--decorr option selected,\
 but decorr_vars is not found in your setup file.")
    else:
        decorr_vars = []

    # Iterate over a snapshot of the keys: entries are added and removed
    # inside the loop, which is not allowed while iterating the live view.
    for key in list(params.keys()):
        if key.startswith('logjit'):
            msg = """
Fitting log(jitter) is depreciated. Please convert your config
files to initialize 'jit' instead of 'logjit' parameters.
Converting 'logjit' to 'jit' for you now.
"""
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            newkey = key.replace('logjit', 'jit')
            params[newkey] = radvel.model.Parameter(value=np.exp(params[key].value), vary=params[key].vary)
            del params[key]

    iparams = radvel.basis._copy_params(params)

    # Make sure we don't have duplicate indices in the DataFrame, so the
    # group labels below can be used as positional row indices.
    P.data = P.data.reset_index(drop=True)

    # initialize RVmodel object
    mod = radvel.RVModel(params, time_base=P.time_base)

    # initialize Likelihood objects for each instrument
    telgrps = P.data.groupby('tel').groups
    likes = {}
    for inst in P.instnames:
        assert inst in telgrps.keys(), \
            "No data found for instrument '{}'.\nInstruments found in this dataset: {}".format(
                inst, list(telgrps.keys()))

        inst_data = P.data.iloc[telgrps[inst]]

        decorr_vectors = {}
        if decorr:
            for d in decorr_vars:
                decorr_vectors[d] = inst_data[d].values

        # A setup file that defines ``hnames`` requests a GP likelihood;
        # otherwise fall back to the plain RV likelihood.
        try:
            hnames = P.hnames[inst]
            liketype = radvel.likelihood.GPLikelihood
            try:
                kernel_name = P.kernel_name[inst]
                if kernel_name == "Celerite":
                    liketype = radvel.likelihood.CeleriteLikelihood
            except AttributeError:
                kernel_name = "QuasiPer"
        except AttributeError:
            liketype = radvel.likelihood.RVLikelihood
            kernel_name = None
            hnames = None

        likes[inst] = liketype(
            mod, inst_data.time, inst_data.mnvel,
            inst_data.errvel, hnames=hnames, suffix='_'+inst,
            kernel_name=kernel_name, decorr_vars=decorr_vars,
            decorr_vectors=decorr_vectors
        )
        likes[inst].params['gamma_'+inst] = iparams['gamma_'+inst]
        likes[inst].params['jit_'+inst] = iparams['jit_'+inst]

    like = radvel.likelihood.CompositeLikelihood(list(likes.values()))

    # Initialize Posterior object
    post = radvel.posterior.Posterior(like)
    post.priors = P.priors

    return P, post

def round_sig(x, sig=2):
    """Round by significant figures

    Args:
        x (float): number to be rounded
        sig (int): (optional) number of significant figures to retain

    Returns:
        float: x rounded to sig significant figures
    """
    if x == 0:
        # log10(0) is undefined; zero is already "rounded".
        return 0.0

    # Position of the leading digit relative to the decimal point.
    magnitude = int(np.floor(np.log10(abs(x))))
    return round(x, sig - magnitude - 1)

def sigfig(med, errlow, errhigh=None):
    """
    Format values with errors into an equal number of significant figures.

    Args:
        med (float): median value
        errlow (float): lower errorbar
        errhigh (float): (optional) upper errorbar; defaults to errlow

    Returns:
        tuple: (med,errlow,errhigh) rounded to the lowest number of significant figures
    """

    if errhigh is None:
        errhigh = errlow

    # Number of decimal places carried by the more precise of the two
    # errorbars, taken from their string representation.
    ndec = Decimal(str(errlow)).as_tuple().exponent
    high_exp = Decimal(str(errhigh)).as_tuple().exponent
    if abs(high_exp) > abs(ndec):
        ndec = high_exp

    # Errorbars written as "N.0" (or with no decimals) are treated as integers.
    integer_errors = (
        ndec == -1 and str(errhigh)[-1] == '0' and str(errlow)[-1] == '0'
    ) or ndec == 0

    if ndec < -1:
        tmpmed = round(med, abs(ndec))
        p = 0
        if med != 0:
            # Keep adding decimals until a non-zero median survives rounding.
            while tmpmed == 0:
                tmpmed = round(med, abs(ndec)+p)
                p += 1
            med = tmpmed
    elif integer_errors:
        errlow = int(round_sig(errlow))
        errhigh = int(round(errhigh))
        med = int(round(med))
    else:
        med = round(med, abs(ndec))

    return med, errlow, errhigh

def time_print(tdiff):
    """Print time

    Helper function to print time remaining in sensible units.

    Args:
        tdiff (float): time in seconds

    Returns:
        tuple: (float time, string units)
    """
    units = 'seconds'
    # Each step converts to the next larger unit only while the value
    # strictly exceeds the conversion factor.
    for factor, larger_unit in ((60, 'minutes'), (60, 'hours'), (24, 'days')):
        if not tdiff > factor:
            break
        tdiff /= factor
        units = larger_unit
    return tdiff, units

def timebin(time, meas, meas_err, binsize):
    """Bin in equal sized time bins

    This routine bins a set of times, measurements, and measurement errors
    into time bins.  All inputs and outputs should be floats or double.
    binsize should have the same units as the time array.
    (from Andrew Howard, ported to Python by BJ Fulton)

    Args:
        time (array): array of times
        meas (array): array of measurements to be combined
        meas_err (array): array of measurement uncertainties
        binsize (float): width of bins in same units as time array

    Returns:
        tuple: (bin centers, binned measurements, binned uncertainties),
            each a list with one entry per bin

    Raises:
        ValueError: if binsize is not positive, or a timestamp cannot be
            placed in a bin (e.g. NaN)
    """
    if not binsize > 0:
        raise ValueError("binsize must be positive")

    time = np.asarray(time)
    meas = np.asarray(meas)
    meas_err = np.asarray(meas_err)

    ind_order = np.argsort(time)
    time = time[ind_order]
    meas = meas[ind_order]
    meas_err = meas_err[ind_order]

    time_out, meas_out, meas_err_out = [], [], []
    ct = 0
    while ct < len(time):
        # Each bin starts at the first not-yet-binned timestamp.
        ind = np.where((time >= time[ct]) & (time < time[ct]+binsize))[0]
        num = len(ind)
        if num == 0:
            # Would otherwise loop forever without advancing.
            raise ValueError("could not assign time[{}] to a bin".format(ct))
        wt = (1./meas_err[ind])**2.     # weights based on errors
        wt = wt/np.sum(wt)              # normalized weights
        time_out.append(np.sum(wt*time[ind]))
        meas_out.append(np.sum(wt*meas[ind]))
        meas_err_out.append(1./np.sqrt(np.sum(1./(meas_err[ind])**2)))
        ct += num

    return time_out, meas_out, meas_err_out

def bintels(t, vel, err, telvec, binsize=1/2.):
    """Bin velocities by instrument

    Bin RV data with bins of width binsize in the units of t.
    Will not bin data from different telescopes together since there may
    be offsets between them.

    Args:
        t (array): array of timestamps
        vel (array): array of velocities
        err (array): array of velocity uncertainties
        telvec (array): array of strings corresponding to the instrument name for each velocity
        binsize (float): (optional) width of bin in units of t (default=1/2.)

    Returns:
        tuple: (bin centers, binned measurements, binned uncertainties, binned instrument codes)
    """
    uniq_tels, first_seen = np.unique(telvec, return_index=True)

    if len(uniq_tels) == 1:
        # Single instrument: bin everything together.
        t_bin, vel_bin, err_bin = timebin(t, vel, err, binsize=binsize)
        return t_bin, vel_bin, err_bin, telvec[0:len(t_bin)]

    # Visit instruments in order of first appearance in telvec rather than
    # in the alphabetical order np.unique returns.
    uniqsort = uniq_tels[np.argsort(first_seen)]

    rvtimes = np.array([])
    rvdat = np.array([])
    rverr = np.array([])
    newtelvec = np.array([])
    for tel in uniqsort:
        pos = np.where(telvec == tel)
        t_bin, vel_bin, err_bin = timebin(
            t[pos], vel[pos], err[pos], binsize=binsize
        )
        rvtimes = np.hstack((rvtimes, t_bin))
        rvdat = np.hstack((rvdat, vel_bin))
        rverr = np.hstack((rverr, err_bin))
        newtelvec = np.hstack((newtelvec, np.array([tel]*len(t_bin))))

    return rvtimes, rvdat, rverr, newtelvec

def fastbin(x, y, nbins=30):
    """Fast binning

    Fast binning function for equally spaced data

    Args:
        x (array): independent variable
        y (array): dependent variable
        nbins (int): number of bins

    Returns:
        tuple: (bin centers, binned measurements, binned uncertainties).
            Only bins holding at least 3 points and having a positive
            bin center are returned.
    """
    n, edges = np.histogram(x, bins=nbins)
    sy, _ = np.histogram(x, bins=nbins, weights=y)
    sy2, _ = np.histogram(x, bins=nbins, weights=y*y)

    # Empty bins divide by zero here; they are discarded by the count
    # filter below, so the warnings carry no information.
    with np.errstate(divide='ignore', invalid='ignore'):
        bindat = sy / n
        binerr = np.sqrt(sy2/n - bindat*bindat) / np.sqrt(n)
    bint = (edges[1:] + edges[:-1])/2.

    # Keep sufficiently populated bins only.
    pos = n >= 3
    bint = bint[pos]
    bindat = bindat[pos]
    binerr = binerr[pos]

    # Keep bins with a positive center only.
    pos = bint > 0
    bint = bint[pos]
    bindat = bindat[pos]
    binerr = binerr[pos]
    return bint, bindat, binerr

def t_to_phase(params, t, num_planet, cat=False):
    """Time to phase

    Convert JD to orbital phase

    Args:
        params (radvel.params.RVParameters): RV parameters object
        t (array): JD timestamps
        num_planet (int): Which planet's ephemeris to phase fold on
        cat (bool): Concatenate/double the output phase array to extend from 0 to 2

    Returns:
        array: orbital phase at each timestamp

    Raises:
        KeyError: if params holds neither a ``tc`` nor a ``tp`` entry for
            the requested planet
    """
    # Prefer 'tc<N>' as the phase-zero reference, fall back to 'tp<N>'.
    if ('tc%i' % num_planet) in params:
        timeparam = 'tc%i' % num_planet
    elif ('tp%i' % num_planet) in params:
        timeparam = 'tp%i' % num_planet
    else:
        raise KeyError(
            "params has neither 'tc%i' nor 'tp%i'" % (num_planet, num_planet)
        )

    P = params['per%i' % num_planet].value
    tc = params[timeparam].value
    # True division (not in-place) so integer inputs yield float phases.
    phase = np.mod(t - tc, P) / P
    if cat:
        phase = np.concatenate((phase, phase+1))
    return phase

def t_to_phase_vector(vector, t, num_planet, cat=False):
    """Time to phase, reading the ephemeris from a parameter vector

    Args:
        vector: object whose ``params.basis.v_to_synth(vector)`` yields the
            synthesis-basis rows used below
        t (array): timestamps
        num_planet (int): Which planet's ephemeris to phase fold on
        cat (bool): Concatenate/double the output phase array to extend from 0 to 2

    Returns:
        array: orbital phase at each timestamp
    """
    synth = vector.params.basis.v_to_synth(vector)

    # Rows are read in groups of five per planet: the first row of the group
    # is used as the period and the second as the reference time.
    # NOTE(review): layout inferred from the indexing only -- confirm.
    first_row = 5 * num_planet - 5
    period = synth[first_row][0]
    t_ref = synth[first_row + 1][0]

    folded = np.mod(t - t_ref, period)
    folded /= period
    if not cat:
        return folded
    return np.concatenate((folded, folded + 1))

@contextmanager
def working_directory(dir):
    """Do something in a directory

    Function to use with `with` statements. The previous working directory
    is restored on exit, even if the body raises.

    Args:
       dir (string): name of directory to work in

    Example:
        >>> with working_directory('/temp'):
        ...     pass  # do something within the /temp directory
    """
    cwd = os.getcwd()
    os.chdir(dir)
    try:
        yield
    finally:
        os.chdir(cwd)

def cmd_exists(cmd):
    """Check whether ``cmd`` is executable from any directory on PATH.

    Args:
        cmd (string): name of the command to look for

    Returns:
        bool: True if ``cmd`` joined onto some PATH entry is accessible
            with execute permission, False otherwise
    """
    search_dirs = os.environ["PATH"].split(os.pathsep)
    for directory in search_dirs:
        candidate = os.path.join(directory, cmd)
        if os.access(candidate, os.X_OK):
            return True
    return False

def date2jd(date):
    """
    Convert datetime object to JD

    Args:
        date (datetime.datetime): date to convert

    Returns:
        float: Julian date
    """
    # Offset from the epoch 2000-01-01 12:00:00, which is JD 2451545.0.
    jd_td = date - datetime(2000, 1, 1, 12, 0, 0)
    # Include the microsecond field so sub-second precision is not dropped.
    day_fraction = (jd_td.seconds + jd_td.microseconds / 1.0e6) / 86400.0
    jd = 2451545.0 + jd_td.days + day_fraction
    return jd

def jd2date(jd):
    """
    Convert JD to datetime.datetime object

    Args:
        jd (float): Julian date

    Returns:
        datetime.datetime: calendar date
    """
    # Work from the Modified Julian Date, whose zero point is
    # 1858-11-17 00:00:00.
    mjd = jd - 2400000.5
    td = timedelta(days=mjd)
    dt = datetime(1858, 11, 17, 0, 0, 0) + td

    return dt

def geterr(vec, angular=False):
    """
    Calculate the median and the distances from it to the 15.9 and
    84.1 percentile values for a given vector.

    Args:
        vec (array): vector, usually an MCMC chain for one parameter
        angular (bool [optional]): Is this an angular parameter?
            if True vec should be in radians. This will perform
            some checks to ensure proper boundary wrapping.

    Returns:
        tuple: (median, median - 15.9th percentile, 84.1th percentile - median)
    """
    # Accept pandas objects by unwrapping to the underlying array.
    try:
        vec = vec.values
    except AttributeError:
        pass

    if angular:
        # Work on a copy: the wrapping below would otherwise modify the
        # caller's chain in place.
        vec = np.array(vec, dtype=float)

        # Use the histogram mode to decide which side of the +/-180 degree
        # boundary the distribution sits on, then wrap the other side to it.
        val, edges = np.histogram(vec, bins=50)
        mode = edges[np.argmax(val)]
        if mode > np.radians(90):
            vec[vec < np.radians(0)] = vec[vec < np.radians(0)] + np.radians(360)
        if mode <= np.radians(-90):
            vec[vec >= np.radians(0)] = vec[vec >= np.radians(0)] - np.radians(360)

    med = np.median(vec)

    s = np.sort(vec)
    errlow = med - s[int(0.159*len(s))]
    errhigh = s[int(0.841*len(s))] - med

    return med, errlow, errhigh

def semi_amplitude(Msini, P, Mtotal, e, Msini_units='jupiter'):
    """Compute Doppler semi-amplitude

    Args:
        Msini (float): mass of planet [units given by Msini_units]
        P (float): Orbital period [days]
        Mtotal (float): Mass of star + mass of planet [Msun]
        e (float): eccentricity
        Msini_units (Optional[str]): Units of Msini {'earth','jupiter'}
            default: 'jupiter'

    Returns:
        Doppler semi-amplitude [m/s]

    Raises:
        Exception: if Msini_units is neither 'earth' nor 'jupiter'
    """

    # convert inputs to array so they work with units
    P = np.array(P)
    Msini = np.array(Msini)
    Mtotal = np.array(Mtotal)
    e = np.array(e)

    # K_0 is normalized to Jupiter masses and years.
    P = (P * u.d).to(u.year).value
    units = Msini_units.lower()
    if units == 'earth':
        Msini = (Msini * u.M_earth).to(u.M_jup).value
    elif units != 'jupiter':
        raise Exception("Msini_units must be 'earth', or 'jupiter'")

    K = K_0*(1 - e**2)**-0.5*Msini*P**(-1.0/3.0)*Mtotal**(-2.0 / 3.0)

    return K

def semi_major_axis(P, Mtotal):
    """Semi-major axis

    Kepler's third law

    Args:
        P (float): Orbital period [days]
        Mtotal (float): Mass [Msun]

    Returns:
        float or array: semi-major axis in AU
    """
    # convert inputs to array so they work with units
    P = np.array(P)
    Mtotal = np.array(Mtotal)

    # Evaluate in SI units, then express the result in AU.
    Mtotal = Mtotal*c.M_sun.value
    P = (P * u.d).to(u.second).value
    G = c.G.value
    a = ((P**2)*G*Mtotal/(4*(np.pi)**2))**(1/3.)
    a = a/c.au.value

    return a

def Msini(K, P, Mstar, e, Msini_units='earth'):
    """Calculate Msini

    Calculate Msini for a given K, P, stellar mass, and e

    Args:
        K (float or array): Doppler semi-amplitude [m/s]
        P (float or array): Orbital period [days]
        Mstar (float or array): Mass of star [Msun]
        e (float or array): eccentricity
        Msini_units (Optional[str]): Units of Msini {'earth','jupiter'}
            default: 'earth'

    Returns:
        float or array: Msini [units = Msini_units]

    Raises:
        Exception: if Msini_units is neither 'earth' nor 'jupiter'
        AssertionError: if array inputs differ in type or length when the
            full (non-approximate) calculation is needed
    """

    # convert inputs to array so they work with units
    P = np.array(P)
    Mstar = np.array(Mstar)
    K = np.array(K)
    e = np.array(e)
    G = c.G.value                # added gravitational constant
    Mjup = c.M_jup.value         # added Jupiter's mass
    Msun = c.M_sun.value         # added sun's mass
    Mstar = Mstar*Msun
    Mstar = np.array(Mstar)

    P_year = (P * u.d).to(u.year).value
    P = (P * u.d).to(u.second).value

    # First assume that Mp << Mstar
    Msini = K / K_0 * np.sqrt(1.0 - e ** 2.0) * (Mstar/Msun) ** (2.0 / 3.0) * P_year ** (1 / 3.0)

    # Use correct calculation if any elements are >10% of the stellar mass
    if (np.array(((Msini * u.Mjup).to(u.M_sun) / (Mstar/Msun)).value > 0.10)).any():
        warnings.warn("Mpsini << Mstar assumption broken, correcting Msini calculation.")

        a = K*(((2*(np.pi)*G)/P)**(-1/3.))*np.sqrt(1-(e**2))
        Msini = []
        if isinstance(P, float):
            n_elements = 1
        else:
            assert type(K) == type(P) == type(Mstar) == type(e), "All input data types must match."
            assert K.size == P.size == Mstar.size == e.size, "All input arrays must have the same length."
            n_elements = len(P)
        for i in range(n_elements):
            # Solve x = a * (Mstar + x)**(2/3) for the companion mass x.
            def func(x):
                try:
                    return x - a[i]*((Mstar[i]+x)**(2/3.))
                except IndexError:
                    # Scalar inputs cannot be indexed.
                    return x - a * ((Mstar + x) ** (2 / 3.))

            sol = root(func, Mjup)
            # NOTE(review): the line storing the solver result was lost from
            # the file this was restored from; taking the first root
            # component is assumed -- confirm.
            Msini.append(sol.x[0])

        Msini = np.array(Msini)
        Msini = Msini/Mjup

    units = Msini_units.lower()
    if units == 'earth':
        Msini = (Msini * u.M_jup).to(u.M_earth).value
    elif units != 'jupiter':
        raise Exception("Msini_units must be 'earth', or 'jupiter'")

    return Msini

def density(mass, radius, MR_units='earth'):
    """Compute density from mass and radius

    Args:
        mass (float): mass [MR_units]
        radius (float): radius [MR_units]
        MR_units (string): (optional) units of mass and radius. Must be 'earth', or 'jupiter' (default 'earth').

    Returns:
        float: density in g/cc

    Raises:
        Exception: if MR_units is neither 'earth' nor 'jupiter'
    """
    mass = np.array(mass)
    radius = np.array(radius)

    units = MR_units.lower()
    if units == 'earth':
        uradius = u.R_earth
        umass = u.M_earth
    elif units == 'jupiter':
        uradius = u.R_jup
        umass = u.M_jup
    else:
        raise Exception("MR_units must be 'earth', or 'jupiter'")

    # Volume of a sphere with the given radius.
    vol = 4. / 3. * np.pi * (radius * uradius) ** 3
    rho = ((mass * umass / vol).to(u.g / u.cm ** 3)).value
    return rho

def draw_models_from_chain(mod, chain, t, nsamples=50):
    """Draw Models from Chain

    Given an MCMC chain of parameters, draw representative parameters
    and synthesize models.

    Args:
        mod (radvel.RVmodel) : RV model
        chain (DataFrame): pandas DataFrame with different values from MCMC
        t (array): time range over which to synthesize models
        nsamples (int): number of draws

    Returns:
        array: 2D array with the different models as different rows
    """
    # Rows are drawn by index label (with replacement), so label-based
    # ``.loc`` is the equivalent of the removed ``.ix`` indexer here.
    chain_samples = chain.loc[np.random.choice(chain.index, nsamples)]
    models = []
    for i in chain_samples.index:
        params = np.array(chain.loc[i, mod.vary_parameters])
        params = mod.array_to_params(params)
        models.append(mod.model(params, t))
    models = np.vstack(models)
    return models