import sys

import pytest
import numpy as np
import matplotlib.pyplot as plt
from numpy.testing import assert_almost_equal, assert_array_equal
from astropy import units as u
from astropy.time import Time
from astropy.stats.bls import BoxLeastSquares

from ..lightcurve import LightCurve
from ..periodogram import Periodogram
from ..utils import LightkurveWarning


def _noisy_lightcurve(npts=1000, time=None, **kwargs):
    """Build a LightCurve of Gaussian noise (mean 1, sigma 0.1), flux_err 0.1.

    Parameters
    ----------
    npts : int
        Number of samples to draw.
    time : array-like, optional
        Time stamps; defaults to ``np.arange(npts)``.
    **kwargs
        Extra keyword arguments forwarded to ``LightCurve``
        (e.g. ``flux_unit``).

    Returns
    -------
    LightCurve
        The light curve; exactly one ``np.random.normal`` call is made, so
        callers that seed the global RNG beforehand get reproducible flux.
    """
    if time is None:
        time = np.arange(npts)
    return LightCurve(time=time,
                      flux=np.random.normal(1, 0.1, npts),
                      flux_err=np.zeros(npts) + 0.1,
                      **kwargs)


def test_periodogram_basics():
    """Sanity check to verify that periodogram plotting works"""
    lc = _noisy_lightcurve().normalize()
    pg = lc.to_periodogram()
    # Plot in both the default view and the period view, closing each figure.
    pg.plot()
    plt.close()
    pg.plot(view='period')
    plt.close()
    # These only need to run without raising.
    pg.show_properties()
    pg.to_table()
    str(pg)


def test_periodogram_normalization():
    """Tests the normalization options"""
    lc = _noisy_lightcurve(flux_unit='electron/second')

    # Test amplitude normalization and correct units
    pg = lc.to_periodogram(normalization='amplitude')
    assert pg.power.unit == u.electron / u.second
    pg = lc.normalize(unit='ppm').to_periodogram(normalization='amplitude')
    assert pg.power.unit == u.cds.ppm

    # Test PSD normalization and correct units
    pg = lc.to_periodogram(freq_unit=u.microhertz, normalization='psd')
    assert pg.power.unit == (u.electron / u.second)**2 / u.microhertz
    pg = lc.normalize(unit='ppm').to_periodogram(freq_unit=u.microhertz,
                                                 normalization='psd')
    assert pg.power.unit == u.cds.ppm**2 / u.microhertz


def test_periodogram_warnings():
    """Check the power units obtained from a ppm-normalized light curve.

    NOTE(review): despite its name, this test does not assert that any
    warning is raised; it only verifies the resulting units.
    """
    lc = _noisy_lightcurve().normalize(unit='ppm')

    # Amplitude normalization keeps the flux unit
    pg = lc.to_periodogram(normalization='amplitude')
    assert pg.power.unit == u.cds.ppm

    # PSD normalization yields flux unit squared per frequency unit
    pg = lc.to_periodogram(freq_unit=u.microhertz, normalization='psd')
    assert pg.power.unit == u.cds.ppm**2 / u.microhertz


def test_periodogram_units():
    """Tests whether periodogram has correct units"""
    # Fake, noisy data
    lc = _noisy_lightcurve(flux_unit='electron/second')
    p = lc.to_periodogram(normalization='amplitude')

    # Has units
    assert hasattr(p.frequency, 'unit')

    # Has the correct units
    assert p.frequency.unit == 1. / u.day
    assert p.power.unit == u.electron / u.second
    assert p.period.unit == u.day
    assert p.frequency_at_max_power.unit == 1. / u.day
    assert p.max_power.unit == u.electron / u.second


def test_periodogram_can_find_periods():
    """Periodogram should recover the correct period"""
    # Light curve that is noisy
    lc = _noisy_lightcurve()
    # Add a ~100 day period signal (10 full cycles across the time span)
    lc.flux += np.sin((lc.time.value / float(lc.time.value.max())) * 20 * np.pi)
    lc = lc.normalize()
    p = lc.to_periodogram(normalization='amplitude')
    assert np.isclose(p.period_at_max_power.value, 100, rtol=1e-3)


def test_periodogram_slicing():
    """Tests whether periodograms can be sliced and used in arithmetic"""
    # Fake, noisy data
    lc = _noisy_lightcurve().normalize()
    p = lc.to_periodogram()
    assert len(p[0:200].frequency) == 200

    # Test divide
    orig = p.power.sum()
    p /= 2
    assert np.sum(p.power) == orig / 2

    # Test multiplication
    p *= 0
    assert np.sum(p.power) == 0

    # Test addition
    p += 100
    assert np.all(p.power.value >= 100)

    # Test subtraction
    p -= 100
    assert np.sum(p.power) == 0


def test_assign_periods():
    """Test if you can assign periods and frequencies."""
    lc = _noisy_lightcurve().normalize()

    periods = np.arange(1, 100) * u.day
    p = lc.to_periodogram(period=periods)
    # Get around the floating point error
    assert np.isclose(np.sum(periods - p.period).value, 0, rtol=1e-14)

    frequency = np.arange(1, 100) * u.Hz
    p = lc.to_periodogram(frequency=frequency)
    assert np.isclose(np.sum(frequency - p.frequency).value, 0, rtol=1e-14)


def test_bin():
    """Test if you can bin the periodogram."""
    lc = _noisy_lightcurve().normalize()
    p = lc.to_periodogram()
    expected_len = len(p.frequency) // 10
    assert len(p.bin(binsize=10, method='mean').frequency) == expected_len
    assert len(p.bin(binsize=10, method='median').frequency) == expected_len


def test_smooth():
    """Test if you can smooth the periodogram and check any pitfalls"""
    np.random.seed(42)
    lc = _noisy_lightcurve().normalize()
    p = lc.to_periodogram(normalization='psd', freq_unit=u.microhertz)

    # Test boxkernel and logmedian methods
    assert all(p.smooth(method='boxkernel').frequency == p.frequency)
    assert all(p.smooth(method='logmedian').frequency == p.frequency)

    # Check output units
    assert p.smooth().power.unit == p.power.unit

    # Check logmedian smooth that the mean of the smoothed power should
    # be consistent with the mean of the power
    mean_power = np.mean(p.power.value)
    assert np.isclose(np.mean(p.smooth(method='logmedian').power.value),
                      mean_power, atol=0.05 * mean_power)

    # Can't pass filter_width below 0.
    with pytest.raises(ValueError):
        p.smooth(method='boxkernel', filter_width=-5.)

    # Can't pass a filter_width in the wrong units
    with pytest.raises(ValueError) as err:
        p.smooth(method='boxkernel', filter_width=5. * u.day)
    assert err.value.args[0] == 'the `filter_width` parameter must have frequency units.'

    # Can't (yet) use a periodogram with a non-evenly spaced frequencies.
    # Built under its own name and outside the `raises` block so that only
    # `smooth()` can satisfy the expectation and `p` is left untouched.
    uneven_pg = lc.to_periodogram(period=np.arange(1, 100))
    with pytest.raises(ValueError):
        uneven_pg.smooth()

    # Check logmedian doesn't work if I give the filter width units
    with pytest.raises(ValueError):
        p.smooth(method='logmedian', filter_width=5. * u.day)


def test_flatten():
    """Test flattening a white-noise periodogram, including `return_trend`."""
    npts = 10000
    np.random.seed(12069424)
    lc = _noisy_lightcurve(npts=npts).normalize()
    p = lc.to_periodogram(normalization='psd', freq_unit=1 / u.day)

    # Check method returns equal frequency
    assert all(p.flatten(method='logmedian').frequency == p.frequency)
    assert all(p.flatten(method='boxkernel').frequency == p.frequency)

    # Check logmedian flatten of white noise returns mean of ~unity
    assert np.isclose(np.mean(p.flatten(method='logmedian').power.value),
                      1.0, atol=0.05)

    # Check return trend works
    s, b = p.flatten(return_trend=True)
    assert all(b.power == p.smooth(method='logmedian', filter_width=0.01).power)
    assert all(s.power == p.flatten().power)

    # The flattened result must still be printable and plottable.
    str(s)
    s.plot()
    plt.close()


def test_index():
    """Test if you can mask out periodogram"""
    lc = _noisy_lightcurve().normalize()
    p = lc.to_periodogram()
    mask = (p.frequency > 0.1 * (1 / u.day)) & (p.frequency < 0.2 * (1 / u.day))
    assert len(p[mask].frequency) == mask.sum()


def test_bls(caplog):
    """Test that BLS periodogram works and gives reasonable errors"""
    lc = _noisy_lightcurve(time=np.linspace(0, 10, 1000))

    # should be able to make a periodogram
    p = lc.to_periodogram(method='bls')
    keys = ['period', 'power', 'duration', 'transit_time', 'depth', 'snr']
    available = dir(p)
    assert all(key in available for key in keys)
    p.plot()
    plt.close()

    # we should be able to specify some keywords
    lc.to_periodogram(method='bls', minimum_period=0.2, duration=0.1,
                      maximum_period=0.5)

    # Ridiculous BLS spectra should break.
    with pytest.raises(ValueError) as err:
        lc.to_periodogram(method='bls', frequency_factor=0.00001)
    assert err.value.args[0] == ('`period` contains over 72000001 points.'
                                 'Periodogram is too large to evaluate. '
                                 'Consider setting `frequency_factor` to a higher value.')

    # Some errors should occur: calling without arguments logs warnings
    p.compute_stats()
    assert all(record.levelname == 'WARNING' for record in caplog.records)
    assert len(caplog.records) == 3
    assert 'No period specified.' in caplog.text

    # No more errors once period, duration and transit time are given
    stats = p.compute_stats(1, 0.1, 0)
    assert len(caplog.records) == 3
    assert isinstance(stats, dict)

    # Some errors should occur
    p.get_transit_model()
    assert all(record.levelname == 'WARNING' for record in caplog.records)
    assert len(caplog.records) == 6
    assert 'No period specified.' in caplog.text

    model = p.get_transit_model(1, 0.1, 0)
    # No more errors
    assert len(caplog.records) == 6
    # Model is LC
    assert isinstance(model, LightCurve)
    # Model is otherwise identical to LC
    assert np.in1d(model.time, lc.time).all()
    assert np.in1d(lc.time, model.time).all()

    mask = p.get_transit_mask(1, 0.1, 0)
    assert isinstance(mask, np.ndarray)
    assert isinstance(mask[0], np.bool_)
    assert mask.sum() > (~mask).sum()

    assert isinstance(p.period_at_max_power, u.Quantity)
    assert isinstance(p.duration_at_max_power, u.Quantity)
    assert isinstance(p.transit_time_at_max_power, Time)
    assert isinstance(p.depth_at_max_power, u.Quantity)


def test_bls_period_recovery():
    """Can BLS Periodogram recover the period of a synthetic light curve?"""
    # Planet parameters
    period = 2.0
    transit_time = 0.5
    duration = 0.1
    depth = 0.2
    flux_err = 0.01

    # Create the synthetic light curve
    time = np.arange(0, 100, 0.1)
    flux = np.ones_like(time)
    transit_mask = (np.abs((time - transit_time + 0.5 * period) % period
                           - 0.5 * period) < 0.5 * duration)
    flux[transit_mask] = 1.0 - depth
    flux += flux_err * np.random.randn(len(time))
    synthetic_lc = LightCurve(time=time, flux=flux)

    # Can BLS recover the period?
    bls_period = synthetic_lc.to_periodogram("bls").period_at_max_power
    assert_almost_equal(bls_period.value, period, decimal=2)

    # Does it work if we inject a sneaky NaN?
    synthetic_lc.flux[10] = np.nan
    bls_period = synthetic_lc.to_periodogram("bls").period_at_max_power
    assert_almost_equal(bls_period.value, period, decimal=2)

    # Does it work if all errors are NaNs?
    # This is a regression test for issue #428
    synthetic_lc.flux_err = np.array([np.nan] * len(time))
    # Recompute: without this the assertion would only re-check the value
    # obtained before the errors were replaced by NaNs.
    bls_period = synthetic_lc.to_periodogram("bls").period_at_max_power
    assert_almost_equal(bls_period.value, period, decimal=2)


def test_error_messages():
    """Test periodogram raises reasonable errors"""
    # Fake, noisy data
    lc = _noisy_lightcurve()

    # Can't specify period range and frequency range
    with pytest.raises(ValueError):
        lc.to_periodogram(maximum_frequency=0.1, minimum_period=10)

    # Can't have a minimum frequency > maximum frequency
    with pytest.raises(ValueError) as err:
        lc.to_periodogram(maximum_frequency=0.1, minimum_frequency=10)
    assert err.value.args[0] == 'minimum_frequency cannot be larger than maximum_frequency'

    # Can't have a minimum period > maximum period
    with pytest.raises(ValueError) as err:
        lc.to_periodogram(maximum_period=0.1, minimum_period=10)
    assert err.value.args[0] == 'minimum_period cannot be larger than maximum_period'

    # Can't specify periods and frequencies
    with pytest.raises(ValueError):
        lc.to_periodogram(frequency=np.arange(10), period=np.arange(10))

    # Don't accept NaNs. The setup stays outside the `raises` block so that
    # only `to_periodogram()` can satisfy the expectation.
    lc_with_nans = lc.copy()
    lc_with_nans.flux[0] = np.nan
    with pytest.raises(ValueError) as err:
        lc_with_nans.to_periodogram()
    assert 'Lightcurve contains NaN values.' in err.value.args[0]

    # No unitless periodograms (frequency)
    with pytest.raises(ValueError) as err:
        Periodogram([0], [1])
    assert err.value.args[0] == 'frequency must be an `astropy.units.Quantity` object.'

    # No unitless periodograms (power)
    with pytest.raises(ValueError) as err:
        Periodogram([0] * u.Hz, [1])
    assert err.value.args[0] == 'power must be an `astropy.units.Quantity` object.'

    # No single value periodograms
    with pytest.raises(ValueError) as err:
        Periodogram([0] * u.Hz, [1] * u.K)
    assert err.value.args[0] == 'frequency and power must have a length greater than 1.'

    # No uneven arrays
    with pytest.raises(ValueError) as err:
        Periodogram([0, 1, 2, 3] * u.Hz, [1, 1] * u.K)
    assert err.value.args[0] == 'frequency and power must have the same length.'

    # Bad frequency units
    with pytest.raises(ValueError) as err:
        Periodogram([0, 1, 2] * u.K, [1, 1, 1] * u.K)
    assert err.value.args[0] == 'Frequency must be in units of 1/time.'

    # Bad binning
    with pytest.raises(ValueError) as err:
        Periodogram([0, 1, 2] * u.Hz, [1, 1, 1] * u.K).bin(binsize=-2)
    assert err.value.args[0] == 'binsize must be larger than or equal to 1'

    # Bad binning method
    with pytest.raises(ValueError) as err:
        Periodogram([0, 1, 2] * u.Hz, [1, 1, 1] * u.K).bin(method='not-implemented')
    assert "method 'not-implemented' is not supported" in err.value.args[0]

    # Bad smooth method
    with pytest.raises(ValueError) as err:
        Periodogram([0, 1, 2] * u.Hz, [1, 1, 1] * u.K).smooth(method="not-implemented")
    assert "method 'not-implemented' is not supported" in err.value.args[0]


def test_bls_period():
    """Regression test for #514."""
    lc = LightCurve(time=[1, 2, 3], flux=[4, 5, 6])
    period = [1, 2, 3, 4, 5]
    pg = lc.to_periodogram(method="bls", period=period)
    assert_array_equal(pg.period.value, period)

    # NaNs should raise a nice error message
    with pytest.raises(ValueError) as err:
        lc.to_periodogram(method="bls", period=[1, 2, 3, np.nan, 4])
    assert "period" in err.value.args[0]