#!/usr/bin/env python
"""Basic tests of the Calc module on 2D and 3D example data."""
import datetime
import itertools
from os.path import isfile
import shutil
import unittest

import cftime
import numpy as np
import pytest
import xarray as xr

from aospy import Var
from aospy.calc import Calc, _add_metadata_as_attrs, _replace_pressure
from aospy.internal_names import ETA_STR
from aospy.utils.vertcoord import p_eta, dp_eta, p_level, dp_level
from .data.objects.examples import (
    example_proj, example_model, example_run, var_not_time_defined,
    condensation_rain, convection_rain, precip, sphum, globe, sahel, p, dp
)


def _test_output_attrs(calc, dtype_out):
    """Assert the saved output has the expected 'units' and 'description'.

    Parameters
    ----------
    calc : Calc
        Calculation whose output has already been written.
    dtype_out : str
        Key into ``calc.path_out`` selecting which output file to check.
    """
    expected_units = calc.var.units
    if calc.dtype_out_vert == 'vert_int':
        if expected_units != '':
            # Must match the format that ``test_attrs`` pins for
            # ``_add_metadata_as_attrs`` ("... kg m^-2)").
            expected_units = ("(vertical integral of {0}):"
                              " {0} kg m^-2)").format(expected_units)
        else:
            expected_units = ("(vertical integral of quantity"
                              " with unspecified units)")
    expected_description = calc.var.description

    with xr.open_dataset(calc.path_out[dtype_out]) as data:
        for arr in data.data_vars.values():
            assert expected_units == arr.attrs['units']
            assert expected_description == arr.attrs['description']


def _clean_test_direcs():
    """Remove the example project's output directories, if they exist."""
    for direc in [example_proj.direc_out, example_proj.tar_direc_out]:
        try:
            shutil.rmtree(direc)
        except OSError:
            # Best-effort cleanup: the directory may never have been created.
            pass


def _test_files_and_attrs(calc, dtype_out):
    """Assert the output file and tar file exist and the attrs are correct."""
    assert isfile(calc.path_out[dtype_out])
    assert isfile(calc.path_tar_out)
    _test_output_attrs(calc, dtype_out)


# Date ranges keyed by the kind of object used to specify them.
_2D_DATE_RANGES = {
    'datetime': (datetime.datetime(4, 1, 1), datetime.datetime(6, 12, 31)),
    'datetime64': (np.datetime64('0004-01-01'), np.datetime64('0006-12-31')),
    'cftime': (cftime.DatetimeNoLeap(4, 1, 1),
               cftime.DatetimeNoLeap(6, 12, 31)),
    'str': ('0004', '0006')
}
_3D_DATE_RANGES = {
    'datetime': (datetime.datetime(6, 1, 1), datetime.datetime(6, 1, 31)),
    'datetime64': (np.datetime64('0006-01-01'), np.datetime64('0006-01-31')),
    'cftime': (cftime.DatetimeNoLeap(6, 1, 1),
               cftime.DatetimeNoLeap(6, 1, 31)),
    'str': ('0006', '0006')
}
_2D_VARS = {'basic': condensation_rain, 'composite': precip}
_2D_DTYPE_OUT_VERT = {'None': None}
_2D_DTYPE_IN_VERT = {'None': None}
_3D_VARS = {'3D': sphum}
_3D_DTYPE_OUT_VERT = {'vert_int': 'vert_int',
                      'vert_av': 'vert_av'}
_3D_DTYPE_IN_VERT = {'sigma': 'sigma'}

# Every (date range, variable, input vertical type, output vertical type)
# combination, built separately for the 2D and the 3D settings.
_CASES = (
    list(itertools.product(_2D_DATE_RANGES.items(), _2D_VARS.items(),
                           _2D_DTYPE_IN_VERT.items(),
                           _2D_DTYPE_OUT_VERT.items())) +
    list(itertools.product(_3D_DATE_RANGES.items(), _3D_VARS.items(),
                           _3D_DTYPE_IN_VERT.items(),
                           _3D_DTYPE_OUT_VERT.items()))
)

# Maps a readable test id to the parameter values for that case.
_CALC_TESTS = {}
for ((date_type, date_range), (test_type, var),
     (vert_in_label, vert_in), (vert_out_label, vert_out)) in _CASES:
    _CALC_TESTS['{}-{}-{}-{}'.format(
        date_type, test_type, vert_in_label, vert_out_label)] = (
            date_range, var, vert_in, vert_out)


@pytest.fixture(params=list(_CALC_TESTS.values()), ids=list(_CALC_TESTS))
def test_params(request):
    """Yield Calc keyword arguments for one case, then clean up outputs."""
    date_range, var, vert_in, vert_out = request.param
    yield {
        'proj': example_proj,
        'model': example_model,
        'run': example_run,
        'var': var,
        'date_range': date_range,
        'intvl_in': 'monthly',
        'dtype_in_time': 'ts',
        'dtype_in_vert': vert_in,
        'dtype_out_vert': vert_out
    }
    _clean_test_direcs()


def test_annual_mean(test_params):
    calc = Calc(intvl_out='ann', dtype_out_time='av', **test_params)
    calc.compute()
    _test_files_and_attrs(calc, 'av')


def test_annual_ts(test_params):
    calc = Calc(intvl_out='ann', dtype_out_time='ts', **test_params)
    calc.compute()
    _test_files_and_attrs(calc, 'ts')


def test_seasonal_mean(test_params):
    calc = Calc(intvl_out='djf', dtype_out_time='av', **test_params)
    calc.compute()
    _test_files_and_attrs(calc, 'av')


def test_seasonal_ts(test_params):
    calc = Calc(intvl_out='djf', dtype_out_time='ts', **test_params)
    calc.compute()
    _test_files_and_attrs(calc, 'ts')


def test_monthly_mean(test_params):
    calc = Calc(intvl_out=1, dtype_out_time='av', **test_params)
    calc.compute()
    _test_files_and_attrs(calc, 'av')


def test_monthly_ts(test_params):
    calc = Calc(intvl_out=1, dtype_out_time='ts', **test_params)
    calc.compute()
    _test_files_and_attrs(calc, 'ts')


def test_simple_reg_av(test_params):
    calc = Calc(intvl_out='ann', dtype_out_time='reg.av', region=[globe],
                **test_params)
    calc.compute()
    _test_files_and_attrs(calc, 'reg.av')


def test_simple_reg_ts(test_params):
    calc = Calc(intvl_out='ann', dtype_out_time='reg.ts', region=[globe],
                **test_params)
    calc.compute()
    _test_files_and_attrs(calc, 'reg.ts')


@pytest.mark.filterwarnings('ignore:Mean of empty slice')
def test_complex_reg_av(test_params):
    calc = Calc(intvl_out='ann', dtype_out_time='reg.av', region=[sahel],
                **test_params)
    calc.compute()
    _test_files_and_attrs(calc, 'reg.av')


# Base Calc keyword arguments for a variable that is not time-defined.
# Tests must copy this dict rather than mutate it, so that state does not
# leak from one test into the next.
test_params_not_time_defined = {
    'proj': example_proj,
    'model': example_model,
    'run': example_run,
    'var': var_not_time_defined,
    'date_range': 'default',
    'intvl_in': 'monthly',
    'dtype_in_time': 'av',
    'intvl_out': 1,
}


@pytest.mark.parametrize('dtype_out_time', [None, []])
def test_calc_object_no_time_options(dtype_out_time):
    """Calc stores an absent or empty ``dtype_out_time`` as a tuple."""
    params = dict(test_params_not_time_defined,
                  dtype_out_time=dtype_out_time)
    calc = Calc(**params)
    if isinstance(dtype_out_time, list):
        assert calc.dtype_out_time == tuple(dtype_out_time)
    else:
        assert calc.dtype_out_time == tuple([dtype_out_time])


@pytest.mark.parametrize(
    'dtype_out_time',
    ['av', 'std', 'ts', 'reg.av', 'reg.std', 'reg.ts'])
def test_calc_object_string_time_options(dtype_out_time):
    """A single time-reduction string raises ValueError for this variable."""
    params = dict(test_params_not_time_defined,
                  dtype_out_time=dtype_out_time)
    with pytest.raises(ValueError):
        Calc(**params)


def test_calc_object_time_options():
    """Every ordered combination of time reductions raises ValueError."""
    time_options = ['av', 'std', 'ts', 'reg.av', 'reg.std', 'reg.ts']
    for n_options in range(1, len(time_options) + 1):
        for time_option in itertools.permutations(time_options, n_options):
            params = dict(test_params_not_time_defined,
                          dtype_out_time=time_option)
            with pytest.raises(ValueError):
                Calc(**params)


@pytest.mark.parametrize(
    ('units', 'description', 'dtype_out_vert', 'expected_units',
     'expected_description'),
    [('', '', None, '', ''),
     ('m', '', None, 'm', ''),
     ('', 'rain', None, '', 'rain'),
     ('m', 'rain', None, 'm', 'rain'),
     ('', '', 'vert_av', '', ''),
     ('m', '', 'vert_av', 'm', ''),
     ('', 'rain', 'vert_av', '', 'rain'),
     ('m', 'rain', 'vert_av', 'm', 'rain'),
     ('', '', 'vert_int',
      '(vertical integral of quantity with unspecified units)', ''),
     ('m', '', 'vert_int',
      '(vertical integral of m): m kg m^-2)', ''),
     ('', 'rain', 'vert_int',
      '(vertical integral of quantity with unspecified units)', 'rain'),
     ('m', 'rain', 'vert_int',
      '(vertical integral of m): m kg m^-2)', 'rain')])
def test_attrs(units, description, dtype_out_vert, expected_units,
               expected_description):
    """``_add_metadata_as_attrs`` sets attrs on DataArrays and Datasets."""
    da = xr.DataArray(None)
    ds = xr.Dataset({'bar': 'foo', 'boo': 'baz'})
    da = _add_metadata_as_attrs(da, units, description, dtype_out_vert)
    ds = _add_metadata_as_attrs(ds, units, description, dtype_out_vert)
    assert expected_units == da.attrs['units']
    assert expected_description == da.attrs['description']
    for arr in ds.data_vars.values():
        assert expected_units == arr.attrs['units']
        assert expected_description == arr.attrs['description']


@pytest.fixture()
def recursive_test_params():
    """Yield (basic, recursive) Calc keyword arguments, then clean up.

    The recursive variable is defined as ``precip - convection_rain`` and is
    compared against ``condensation_rain`` by the test that uses this
    fixture.
    """
    basic_params = {
        'proj': example_proj,
        'model': example_model,
        'run': example_run,
        'var': condensation_rain,
        'date_range': (datetime.datetime(4, 1, 1),
                       datetime.datetime(6, 12, 31)),
        'intvl_in': 'monthly',
        'dtype_in_time': 'ts'
    }
    recursive_params = basic_params.copy()

    recursive_condensation_rain = Var(
        name='recursive_condensation_rain',
        variables=(precip, convection_rain), func=lambda x, y: x - y,
        def_time=True)
    recursive_params['var'] = recursive_condensation_rain

    yield (basic_params, recursive_params)

    _clean_test_direcs()


def test_recursive_calculation(recursive_test_params):
    """A Var built from other Vars matches the directly computed result."""
    basic_params, recursive_params = recursive_test_params

    calc = Calc(intvl_out='ann', dtype_out_time='av', **basic_params)
    calc = calc.compute()
    # Load into memory and close the file: the fixture teardown deletes the
    # output directory, so no handle should be left open on it.
    with xr.open_dataset(calc.path_out['av']) as ds:
        expected = ds['condensation_rain'].load()
    _test_files_and_attrs(calc, 'av')

    calc = Calc(intvl_out='ann', dtype_out_time='av', **recursive_params)
    calc = calc.compute()
    with xr.open_dataset(calc.path_out['av']) as ds:
        result = ds['recursive_condensation_rain'].load()
    _test_files_and_attrs(calc, 'av')

    xr.testing.assert_equal(expected, result)


def test_compute_pressure():
    """Pressure can be computed from sigma-coordinate input data."""
    calc = Calc(
        intvl_out='ann',
        dtype_out_time='av',
        var=p,
        proj=example_proj,
        model=example_model,
        run=example_run,
        date_range=('0006', '0006'),
        intvl_in='monthly',
        dtype_in_time='ts',
        dtype_in_vert='sigma',
        dtype_out_vert=None
    )
    try:
        calc.compute()
        _test_files_and_attrs(calc, 'av')
    finally:
        # Clean up even on failure so later tests start from a clean slate.
        _clean_test_direcs()


def test_compute_pressure_thicknesses():
    """Pressure thickness can be computed from sigma-coordinate input data."""
    calc = Calc(
        intvl_out='ann',
        dtype_out_time='av',
        var=dp,
        proj=example_proj,
        model=example_model,
        run=example_run,
        date_range=('0006', '0006'),
        intvl_in='monthly',
        dtype_in_time='ts',
        dtype_in_vert='sigma',
        dtype_out_vert=None
    )
    try:
        calc.compute()
        _test_files_and_attrs(calc, 'av')
    finally:
        # Clean up even on failure so later tests start from a clean slate.
        _clean_test_direcs()


@pytest.mark.parametrize(
    ['dtype_in_vert', 'expected'],
    [(ETA_STR, [p_eta, dp_eta, condensation_rain, 5]),
     ('pressure', [p_level, dp_level, condensation_rain, 5])])
def test_replace_pressure(dtype_in_vert, expected):
    """``_replace_pressure`` swaps only the pressure-related arguments."""
    arguments = [p, dp, condensation_rain, 5]
    p_in, dp_in, cond, num = arguments
    p_expected, dp_expected, cond_expected, num_expected = expected

    # Sanity check: the inputs really differ from what is expected back.
    assert p_in.func != p_expected.func
    assert dp_in.func != dp_expected.func

    result = _replace_pressure(arguments, dtype_in_vert)
    p_out, dp_out, cond_out, num_out = result

    assert p_out.func == p_expected.func
    assert dp_out.func == dp_expected.func
    assert cond_out.func == cond_expected.func
    assert num_out == num_expected


if __name__ == '__main__':
    # NOTE(review): this module defines no unittest.TestCase classes, so
    # this entry point collects nothing; run the suite with pytest instead.
    unittest.main()