import unittest from tests.TestSupport.Info import resource_path import EXOSIMS import EXOSIMS.Prototypes.ZodiacalLight import EXOSIMS.ZodiacalLight from EXOSIMS import MissionSim import astropy.units as u import pkgutil from EXOSIMS.util.get_module import get_module import numpy as np import os, json from tests.TestSupport.Utilities import RedirectStreams import sys # Python 3 compatibility: if sys.version_info[0] > 2: from io import StringIO else: from StringIO import StringIO """ZodiacalLight module unit tests based on previous tests by Paul Nunez, JPL""" class TestZodiacalLight(unittest.TestCase): """ Global ZodiacalLight tests. Applied to all implementations, for overloaded methods only. Any implementation-specific methods, or to test specific new method functionality, separate tests are needed. """ def setUp(self): self.dev_null = open(os.devnull, 'w') self.script = resource_path('test-scripts/template_prototype_testing.json') with open(self.script) as f: self.spec = json.loads(f.read()) with RedirectStreams(stdout=self.dev_null): self.sim = MissionSim.MissionSim(self.script) self.TL = self.sim.TargetList self.nStars = self.TL.nStars self.star_index = np.array(range(0, self.nStars)) self.Obs = self.sim.Observatory self.mode = self.sim.OpticalSystem.observingModes[0] self.TK = self.sim.TimeKeeping assert self.nStars > 10, "Need at least 10 stars in the target list for the unit test." self.unit = 1./u.arcsec**2 modtype = getattr(EXOSIMS.Prototypes.ZodiacalLight.ZodiacalLight, '_modtype') pkg = EXOSIMS.ZodiacalLight self.allmods = [get_module(modtype)] for loader, module_name, is_pkg in pkgutil.walk_packages(pkg.__path__, pkg.__name__ + '.'): if not is_pkg: mod = get_module(module_name.split('.')[-1], modtype) self.assertTrue(mod._modtype is modtype, '_modtype mismatch for %s' % mod.__name__) self.allmods.append(mod) def test_fZ(self): """ Test that fZ returns correct type and units. """ for mod in self.allmods: if 'fZ' in mod.__dict__: obj = mod() fZs = obj.fZ(self.Obs, self.TL, self.star_index, self.TK.currentTimeAbs, self.mode) self.assertEqual(len(fZs), self.nStars,'fZ does not return same number of values as nStars for {}'.format(mod.__name__)) self.assertEqual(fZs.unit, self.unit, 'fZ does not return 1/arcsec**2 for {}'.format(mod.__name__)) def test_fEZ(self): """ Test that fEZ returns correct shape and units. """ exclude_mods=[] for mod in self.allmods: if mod.__name__ in exclude_mods: continue if 'fEZ' in mod.__dict__: obj = mod() # use 3 planets d = 10.*np.random.rand(3)*u.AU I = np.random.uniform(0.0, 180.0, 3)*u.deg fEZs = obj.fEZ(self.TL.MV[0], I, d) self.assertEqual(len(fEZs), 3, 'fEZ does not return same number of values as planets tested for {}'.format(mod.__name__)) self.assertEqual(fEZs.unit, self.unit, 'fEZ does not return 1/arcsec**2 for {}'.format(mod.__name__)) def test_generate_fZ(self): """ Test generate fZ method """ for mod in self.allmods: if 'generate_fZ' in mod.__dict__: with RedirectStreams(stdout=self.dev_null): obj = mod() #Check if File Exists and if it does, delete it #if os.path.isfile(self.sim.SurveySimulation.cachefname+'starkfZ'): # os.remove(self.sim.SurveySimulation.cachefname+'starkfZ') OS = self.sim.OpticalSystem allModes = OS.observingModes mode = list(filter(lambda mode: mode['detectionMode'] == True, allModes))[0] hashname = self.sim.SurveySimulation.cachefname self.sim.ZodiacalLight.fZ_startSaved = obj.generate_fZ(self.Obs, self.TL, self.TK, mode, hashname) self.assertEqual(self.sim.ZodiacalLight.fZ_startSaved.shape[0],self.nStars) #Should also check length of fZ_startSaved?? self.assertEqual(self.sim.ZodiacalLight.fZ_startSaved.shape[1],1000)#This was arbitrarily selected. def test_calcfZmax(self): """ Test calcfZmax method """ for mod in self.allmods: if 'calcfZmax' in mod.__dict__: with RedirectStreams(stdout=self.dev_null): obj = mod() #Check if File Exists and if it does, delete it if os.path.isfile(self.sim.SurveySimulation.cachefname+'fZmax'): os.remove(self.sim.SurveySimulation.cachefname+'fZmax') sInds = np.asarray([0]) currentTimeAbs = self.sim.TimeKeeping.currentTimeAbs OS = self.sim.OpticalSystem allModes = OS.observingModes mode = list(filter(lambda mode: mode['detectionMode'] == True, allModes))[0] hashname = self.sim.SurveySimulation.cachefname self.sim.ZodiacalLight.fZ_startSaved = obj.generate_fZ(self.Obs, self.TL, self.TK, mode, hashname) valfZmax = np.zeros(sInds.shape[0]) timefZmax = np.zeros(sInds.shape[0]) [valfZmax, timefZmax] = obj.calcfZmax(sInds, self.Obs, self.TL, self.TK, mode, hashname) self.assertTrue(len(valfZmax) == len(sInds)) self.assertTrue(len(timefZmax) == len(sInds)) self.assertTrue(valfZmax[0].unit == self.unit) self.assertTrue(timefZmax[0].format == currentTimeAbs.format) def test_calcfZmin(self): """ Test calcfZmin method """ for mod in self.allmods: if 'calcfZmin' in mod.__dict__: with RedirectStreams(stdout=self.dev_null): obj = mod() sInds = np.asarray([0]) currentTimeAbs = self.TK.currentTimeAbs OS = self.sim.OpticalSystem allModes = OS.observingModes mode = list(filter(lambda mode: mode['detectionMode'] == True, allModes))[0] hashname = self.sim.SurveySimulation.cachefname self.sim.ZodiacalLight.fZ_startSaved = obj.generate_fZ(self.Obs, self.TL, self.TK, mode, hashname) fZQuads = obj.calcfZmin(sInds, self.Obs, self.TL, self.TK, mode, hashname) [valfZmin, timefZmin] = obj.extractfZmin_fZQuads(fZQuads) self.assertTrue(len(valfZmin) == len(sInds)) self.assertTrue(len(timefZmin) == len(sInds)) self.assertTrue(valfZmin[0].unit == 1/u.arcsec**2) self.assertTrue(timefZmin[0].format == currentTimeAbs.format) def test_str(self): """ Test __str__ method, for full coverage and check that all modules have required attributes. """ atts_list = ['magZ', 'magEZ', 'varEZ', 'fZ0', 'fEZ0'] exclude_mods=[] for mod in self.allmods: if mod.__name__ in exclude_mods: continue with RedirectStreams(stdout=self.dev_null): obj = mod(**self.spec) original_stdout = sys.stdout sys.stdout = StringIO() # call __str__ method result = obj.__str__() # examine what was printed contents = sys.stdout.getvalue() self.assertEqual(type(contents), type('')) # attributes from ICD for att in atts_list: self.assertIn(att,contents,'{} missing for {}'.format(att,mod.__name__)) sys.stdout.close() # it also returns a string, which is not necessary self.assertEqual(type(result), type('')) # put stdout back sys.stdout = original_stdout