"""Tests for ``vivarium.interpolation``.

Many tests shuffle their input tables to check that interpolation does not
depend on the input being sorted. The shuffles use a fixed seed so that any
failure can be reproduced.
"""
import itertools

import numpy as np
import pandas as pd
import pytest

from vivarium.interpolation import Interpolation, validate_parameters, check_data_complete, Order0Interp

# Any fixed value works: the shuffle only has to yield an unsorted table, and a
# fixed seed makes a failing ordering replayable.
SHUFFLE_SEED = 1234


def _shuffle(df: pd.DataFrame) -> pd.DataFrame:
    """Return ``df`` with its rows in a (reproducible) random order."""
    return df.sample(frac=1, random_state=SHUFFLE_SEED)


def make_bin_edges(data: pd.DataFrame, col: str) -> pd.DataFrame:
    """Add ``{col}_left`` / ``{col}_right`` bin-edge columns built from the midpoints in ``col``.

    The distinct values of ``col`` are sorted and treated as bin midpoints.
    Interior edges sit halfway between consecutive midpoints. The first bin's
    left edge is the first midpoint itself, and the last bin's right edge lies
    one "previous bin width" beyond its left edge.

    For example, midpoints ``0.5, 1.5, 2.5`` give the bins
    ``[0.5, 1), [1, 2), [2, 3)``.

    Parameters
    ----------
    data
        Table containing a numeric midpoint column ``col``. It is not modified.
    col
        Name of the midpoint column.

    Returns
    -------
    pd.DataFrame
        A copy of ``data`` (same index) with the two edge columns appended.

    Raises
    ------
    ValueError
        If ``col`` holds fewer than two distinct values. The last bin's width
        is taken from the bin before it, so at least two are required.
    """
    midpoints = data[col].drop_duplicates().sort_values().reset_index(drop=True)
    if len(midpoints) < 2:
        raise ValueError(f"Column '{col}' needs at least two distinct midpoints to build bin edges.")

    # Halfway points between neighbours; the first midpoint has no left
    # neighbour and is used directly as its own left edge.
    left = (0.5 * (midpoints + midpoints.shift())).fillna(midpoints)
    right = left.shift(-1)
    # Close the final bin by repeating the width of the bin before it.
    right.iloc[-1] = 2 * left.iloc[-1] - left.iloc[-2]

    edges = pd.DataFrame({f'{col}_left': left.to_numpy(), f'{col}_right': right.to_numpy()},
                         index=midpoints.to_numpy())
    data = data.copy()
    for edge_col in edges.columns:
        # Explicit lookup by midpoint value rather than relying on
        # multi-column assignment alignment rules.
        data[edge_col] = data[col].map(edges[edge_col])
    return data


def _year_age_sex_pop_table() -> pd.DataFrame:
    """Shuffled table of every (year, age, sex) combination with ``pop = 11.1 * age``."""
    grid = [range(1990, 1995), range(25, 30), ['Male', 'Female']]
    df = pd.DataFrame(list(itertools.product(*grid)), columns=['year', 'age', 'sex'])
    df['pop'] = df.age * 11.1
    return _shuffle(df)


def _order_zero_2d_table() -> pd.DataFrame:
    """Shuffled 5x5 grid of bins on ``a`` and ``b``, midpoints 0.5, ..., 4.5 on each axis.

    ``c`` is three times the integer part of the ``b`` midpoint, and ``garbage``
    is a constant key column.
    """
    a = np.mgrid[0:5, 0:5][0].reshape(25)
    b = np.mgrid[0:5, 0:5][1].reshape(25)
    df = pd.DataFrame({'a': a + 0.5, 'b': b + 0.5, 'c': b * 3, 'garbage': ['test'] * len(a)})
    df = make_bin_edges(df, 'a')
    df = make_bin_edges(df, 'b')
    return _shuffle(df)


@pytest.mark.skip(reason="only order 0 interpolation currently supported")
def test_1d_interpolation():
    df = pd.DataFrame({'a': np.arange(100), 'b': np.arange(100), 'c': np.arange(100, 0, -1)})
    df = _shuffle(df)
    i = Interpolation(df, (), ('a',), 1, True)
    query = pd.DataFrame({'a': np.arange(100, step=0.01)})
    assert np.allclose(query.a, i(query).b)
    assert np.allclose(100 - query.a, i(query).c)


@pytest.mark.skip(reason="only order 0 interpolation currently supported")
def test_age_year_interpolation():
    years = list(range(1990, 2010))
    ages = list(range(0, 90))
    pops = np.array(ages) * 11.1
    data = [{'age': age, 'sex': sex, 'year': year, 'pop': pop}
            for age, pop in zip(ages, pops)
            for year in years
            for sex in ['Male', 'Female']]
    df = _shuffle(pd.DataFrame(data))
    i = Interpolation(df, ('sex', 'age'), ('year',), 1, True)
    query = pd.DataFrame({'year': [1990, 1990], 'age': [35, 35], 'sex': ['Male', 'Female']})
    assert np.allclose(i(query), 388.5)


@pytest.mark.skip(reason="only order 0 interpolation currently supported")
def test_interpolation_called_missing_key_col():
    df = _year_age_sex_pop_table()
    i = Interpolation(df, ['sex'], ['year', 'age'], 1, True)
    query = pd.DataFrame({'year': [1990, 1990], 'age': [35, 35]})
    with pytest.raises(ValueError):
        i(query)


@pytest.mark.skip(reason="only order 0 interpolation currently supported")
def test_interpolation_called_missing_param_col():
    df = _year_age_sex_pop_table()
    i = Interpolation(df, ['sex'], ['year', 'age'], 1, True)
    query = pd.DataFrame({'year': [1990, 1990], 'sex': ['Male', 'Female']})
    with pytest.raises(ValueError):
        i(query)


@pytest.mark.skip(reason="only order 0 interpolation currently supported")
def test_2d_interpolation():
    a = np.mgrid[0:5, 0:5][0].reshape(25)
    b = np.mgrid[0:5, 0:5][1].reshape(25)
    df = _shuffle(pd.DataFrame({'a': a, 'b': b, 'c': b, 'd': a}))
    i = Interpolation(df, (), ('a', 'b'), 1, True)
    query = pd.DataFrame({'a': np.arange(4, step=0.01), 'b': np.arange(4, step=0.01)})
    assert np.allclose(query.b, i(query).c)
    assert np.allclose(query.a, i(query).d)


@pytest.mark.skip(reason="only order 0 interpolation currently supported")
def test_interpolation_with_categorical_parameters():
    a = ['one'] * 100 + ['two'] * 100
    b = np.append(np.arange(100), np.arange(100))
    c = np.append(np.arange(100), np.arange(100, 0, -1))
    df = _shuffle(pd.DataFrame({'a': a, 'b': b, 'c': c}))
    i = Interpolation(df, ('a',), ('b',), 1, True)
    query_one = pd.DataFrame({'a': 'one', 'b': np.arange(100, step=0.01)})
    query_two = pd.DataFrame({'a': 'two', 'b': np.arange(100, step=0.01)})
    assert np.allclose(np.arange(100, step=0.01), i(query_one).c)
    assert np.allclose(np.arange(100, 0, step=-0.01), i(query_two).c)


def test_order_zero_2d():
    df = _order_zero_2d_table()
    i = Interpolation(df, ('garbage',), [('a', 'a_left', 'a_right'), ('b', 'b_left', 'b_right')],
                      order=0, extrapolate=True)
    column = np.arange(0.5, 4, step=0.011)
    query = pd.DataFrame({'a': column, 'b': column, 'garbage': ['test'] * len(column)})
    assert np.allclose(query.b.astype(int) * 3, i(query).c)


def test_order_zero_2d_fails_on_extrapolation():
    df = _order_zero_2d_table()
    i = Interpolation(df, ('garbage',), [('a', 'a_left', 'a_right'), ('b', 'b_left', 'b_right')],
                      order=0, extrapolate=False)
    # Starts at 0, below the lowest left edge (0.5), so extrapolation is required.
    column = np.arange(4, step=0.011)
    query = pd.DataFrame({'a': column, 'b': column, 'garbage': ['test'] * len(column)})
    with pytest.raises(ValueError) as error:
        i(query)
    message = error.value.args[0]
    assert 'Extrapolation' in message and 'a' in message


def test_order_zero_1d_no_extrapolation():
    s = pd.Series({0: 0, 1: 1}).reset_index()
    s = make_bin_edges(s, 'index')
    f = Interpolation(s, tuple(), [['index', 'index_left', 'index_right']], order=0, extrapolate=False)
    assert f(pd.DataFrame({'index': [0]}))[0][0] == 0, 'should be precise at index values'
    assert f(pd.DataFrame({'index': [0.999]}))[0][0] == 1
    with pytest.raises(ValueError) as error:
        f(pd.DataFrame({'index': [1]}))
    message = error.value.args[0]
    assert 'Extrapolation' in message and 'index' in message


def test_order_zero_1d_constant_extrapolation():
    s = pd.Series({0: 0, 1: 1}).reset_index()
    s = make_bin_edges(s, 'index')
    f = Interpolation(s, tuple(), [['index', 'index_left', 'index_right']], order=0, extrapolate=True)
    assert f(pd.DataFrame({'index': [1]}))[0][0] == 1
    assert f(pd.DataFrame({'index': [2]}))[0][0] == 1, 'should be constant extrapolation outside of input range'
    assert f(pd.DataFrame({'index': [-1]}))[0][0] == 0


def test_validate_parameters__empty_data():
    with pytest.raises(ValueError) as error:
        validate_parameters(pd.DataFrame(columns=["age_left", "age_right", "sex",
                                                  "year_left", "year_right", "value"]),
                            ["sex"],
                            [("age", "age_left", "age_right"), ["year", "year_left", "year_right"]])
    message = error.value.args[0]
    assert 'empty' in message


def test_check_data_complete_gaps():
    data = pd.DataFrame({'year_start': [1990, 1990, 1995, 1995],
                         'year_end': [1995, 1995, 2000, 2000],
                         'age_start': [16, 10, 10, 16],
                         'age_end': [20, 15, 15, 20]})
    with pytest.raises(NotImplementedError) as error:
        check_data_complete(data, [('year', 'year_start', 'year_end'), ['age', 'age_start', 'age_end']])
    message = error.value.args[0]
    assert "age_start" in message and "age_end" in message


def test_check_data_complete_overlap():
    data = pd.DataFrame({'year_start': [1995, 1995, 2000, 2005, 2010],
                         'year_end': [2000, 2000, 2005, 2010, 2015]})
    with pytest.raises(ValueError) as error:
        check_data_complete(data, [('year', 'year_start', 'year_end')])
    message = error.value.args[0]
    assert "year_start" in message and "year_end" in message


def test_check_data_missing_combos():
    data = pd.DataFrame({'year_start': [1990, 1990, 1995],
                         'year_end': [1995, 1995, 2000],
                         'age_start': [10, 15, 10],
                         'age_end': [15, 20, 15]})
    with pytest.raises(ValueError) as error:
        check_data_complete(data, [['year', 'year_start', 'year_end'], ('age', 'age_start', 'age_end')])
    message = error.value.args[0]
    assert 'combination' in message


def test_order0interp():
    data = pd.DataFrame({'year_start': [1990, 1990, 1990, 1990, 1995, 1995, 1995, 1995],
                         'year_end': [1995, 1995, 1995, 1995, 2000, 2000, 2000, 2000],
                         'age_start': [15, 10, 10, 15, 10, 10, 15, 15],
                         'age_end': [20, 15, 15, 20, 15, 15, 20, 20],
                         'height_start': [140, 160, 140, 160, 140, 160, 140, 160],
                         'height_end': [160, 180, 160, 180, 160, 180, 160, 180],
                         'value': [5, 3, 1, 7, 8, 6, 4, 2]})
    interp = Order0Interp(data,
                          [('age', 'age_start', 'age_end'),
                           ('year', 'year_start', 'year_end'),
                           ('height', 'height_start', 'height_end')],
                          ['value'], True)
    interpolants = pd.DataFrame({'age': [12, 17, 8, 24, 12],
                                 'year': [1992, 1998, 1985, 1992, 1992],
                                 'height': [160, 145, 140, 179, 160]})
    result = interp(interpolants)
    assert result.equals(pd.DataFrame({'value': [3, 4, 1, 7, 3]}))


def test_order_zero_1d_with_key_column():
    data = pd.DataFrame({'year_start': [1990, 1990, 1995, 1995],
                         'year_end': [1995, 1995, 2000, 2000],
                         'sex': ['Male', 'Female', 'Male', 'Female'],
                         'value_1': [10, 7, 2, 12],
                         'value_2': [1200, 1350, 1476, 1046]})
    i = Interpolation(data, ['sex'], [('year', 'year_start', 'year_end')], 0, True)
    query = pd.DataFrame({'year': [1992, 1993], 'sex': ['Male', 'Female']})
    expected_result = pd.DataFrame({'value_1': [10.0, 7.0], 'value_2': [1200.0, 1350.0]})
    assert i(query).equals(expected_result)


def test_order_zero_non_numeric_values():
    data = pd.DataFrame({'year_start': [1990, 1990],
                         'year_end': [1995, 1995],
                         'age_start': [15, 24],
                         'age_end': [24, 30],
                         'value_1': ['blue', 'red']})
    i = Interpolation(data, tuple(), [('year', 'year_start', 'year_end'), ('age', 'age_start', 'age_end')],
                      0, True)
    # Non-default index order checks that the result keeps the query's index.
    query = pd.DataFrame({'year': [1990, 1990], 'age': [15, 24]}, index=[1, 0])
    expected_result = pd.DataFrame({'value_1': ['blue', 'red']}, index=[1, 0])
    assert i(query).equals(expected_result)


def test_order_zero_3d_with_key_col():
    data = pd.DataFrame({'year_start': [1990, 1990, 1990, 1990, 1995, 1995, 1995, 1995] * 2,
                         'year_end': [1995, 1995, 1995, 1995, 2000, 2000, 2000, 2000] * 2,
                         'age_start': [15, 10, 10, 15, 10, 10, 15, 15] * 2,
                         'age_end': [20, 15, 15, 20, 15, 15, 20, 20] * 2,
                         'height_start': [140, 160, 140, 160, 140, 160, 140, 160] * 2,
                         'height_end': [160, 180, 160, 180, 160, 180, 160, 180] * 2,
                         'sex': ['Male'] * 8 + ['Female'] * 8,
                         'value': [5, 3, 1, 7, 8, 6, 4, 2, 6, 4, 2, 8, 9, 7, 5, 3]})
    interp = Interpolation(data, ('sex',),
                           [('age', 'age_start', 'age_end'),
                            ('year', 'year_start', 'year_end'),
                            ('height', 'height_start', 'height_end')],
                           0, True)
    interpolants = pd.DataFrame({'age': [12, 17, 8, 24, 12],
                                 'year': [1992, 1998, 1985, 1992, 1992],
                                 'height': [160, 145, 140, 185, 160],
                                 'sex': ['Male', 'Female', 'Female', 'Male', 'Male']},
                                index=[10, 4, 7, 0, 9])
    result = interp(interpolants)
    assert result.equals(pd.DataFrame({'value': [3.0, 5.0, 2.0, 7.0, 3.0]}, index=[10, 4, 7, 0, 9]))


def test_order_zero_diff_bin_sizes():
    data = pd.DataFrame({'year_start': [1990, 1995, 1996, 2005, 2005.5],
                         'year_end': [1995, 1996, 2005, 2005.5, 2010],
                         'value': [1, 5, 2.3, 6, 100]})
    i = Interpolation(data, tuple(), [('year', 'year_start', 'year_end')], 0, False)
    query = pd.DataFrame({'year': [2007, 1990, 2005.4, 1994, 2004, 1995, 2002, 1995.5, 1996]})
    expected_result = pd.DataFrame({'value': [100, 1, 6, 1, 2.3, 5, 2.3, 5, 2.3]})
    assert i(query).equals(expected_result)


def test_order_zero_given_call_column():
    # Same as test_order_zero_diff_bin_sizes, but the data also carries the call column 'year'.
    data = pd.DataFrame({'year_start': [1990, 1995, 1996, 2005, 2005.5],
                         'year_end': [1995, 1996, 2005, 2005.5, 2010],
                         'year': [1992.5, 1995.5, 2000, 2005.25, 2007.75],
                         'value': [1, 5, 2.3, 6, 100]})
    i = Interpolation(data, tuple(), [('year', 'year_start', 'year_end')], 0, False)
    query = pd.DataFrame({'year': [2007, 1990, 2005.4, 1994, 2004, 1995, 2002, 1995.5, 1996]})
    expected_result = pd.DataFrame({'value': [100, 1, 6, 1, 2.3, 5, 2.3, 5, 2.3]})
    assert i(query).equals(expected_result)