import numpy as np
from numpy.testing import assert_allclose, assert_raises
from nose import SkipTest

try:
    # Python 3
    from urllib.error import URLError
    ConnectionError = ConnectionResetError
except ImportError:
    # Python 2
    from urllib2 import URLError
    from socket import error as ConnectionError

from .. import RRLyraeTemplateModeler, RRLyraeTemplateModelerMultiband
from ...datasets import fetch_rrlyrae_templates, fetch_rrlyrae
from scipy.interpolate import UnivariateSpline


def test_basic_template_model():
    # Build a noiseless light curve from one downloaded template and check
    # that the single-band modeler reproduces it, recovers the generating
    # parameters, and reaches a chi2 of (numerically) zero.
    template_id = 25
    period = 0.63
    # theta = [additive offset, multiplicative scale, phase shift], as used
    # in the expression that generates ``mag`` below.
    theta = [17, 0.5, 0.3]

    try:
        templates = fetch_rrlyrae_templates()
    except (URLError, ConnectionError):
        raise SkipTest("No internet connection: "
                       "data download test skipped")

    # Reference curve: degree-5 spline with s=0 through the template points.
    phase, y = templates.get_template(templates.ids[template_id])
    reference = UnivariateSpline(phase, y, s=0, k=5)

    # Fixed seed keeps the sampled times reproducible.
    t = np.random.RandomState(0).rand(20)
    folded_phase = (t / period - theta[2]) % 1
    mag = theta[0] + theta[1] * reference(folded_phase)

    modeler = RRLyraeTemplateModeler('ugriz')
    # NOTE(review): the third argument is presumably a uniform uncertainty
    # (dy) -- confirm against the modeler's ``fit`` signature.
    modeler.fit(t, mag, 1)

    # The modeler's own template evaluation must match the synthetic data.
    # NOTE(review): this assumes index ``template_id`` in the modeler refers
    # to the same template as ``templates.ids[template_id]`` -- confirm.
    assert_allclose(modeler._model(t, theta, period, template_id), mag)

    # Optimization must recover theta both with and without the gradient.
    for use_gradient in (True, False):
        theta_fit = modeler._optimize(period, template_id, use_gradient)
        assert_allclose(theta, theta_fit, rtol=1E-4)

    # chi2 is checked once, for the last fit of the loop above
    # (use_gradient=False).
    assert_allclose(modeler._chi2(theta_fit, period, template_id), 0,
                    atol=1E-8)


def test_multiband_fit():
    # The multiband modeler's per-band predictions are compared against
    # single-band modelers, each fit only to the observations of its band.
    #
    # TODO: this is a long test.
    # We could artificially limit the number of templates to make it faster
    try:
        rrlyrae = fetch_rrlyrae()
    except (URLError, ConnectionError):
        raise SkipTest("No internet connection: "
                       "data download test skipped")

    first_id = rrlyrae.ids[0]

    # Keep only every 10th observation of each returned array.
    t, y, dy, filts = (column[::10]
                       for column in rrlyrae.get_lightcurve(first_id))
    period = rrlyrae.get_metadata(first_id)['P']

    bands = 'ugriz'
    tfit = np.linspace(0, 5 * period, 99)
    # Column of band labels, shape (5, 1): one row per band.
    filts_fit = np.array(list(bands))[:, None]

    multiband = RRLyraeTemplateModelerMultiband().fit(t, y, dy, filts)
    yfit_all = multiband.predict(tfit, filts_fit, period)

    def predict_single_band(band):
        # Fit a single-band modeler to this band's points and predict on tfit.
        in_band = (filts == band)
        single = RRLyraeTemplateModeler(band)
        single.fit(t[in_band], y[in_band], dy[in_band])
        return single.predict(tfit, period)

    yfit_band = [predict_single_band(band) for band in bands]

    assert_allclose(yfit_all, yfit_band)


def test_bad_args():
    # Constructing the modeler with filts='abc' must raise ValueError.
    # NOTE(review): the network-error guard suggests that construction can
    # trigger a template download -- confirm against the modeler class.
    try:
        with assert_raises(ValueError):
            RRLyraeTemplateModeler(filts='abc')
    except (URLError, ConnectionError):
        raise SkipTest("No internet connection: "
                       "data download test skipped")