"""Regression tests comparing momi's expected SFS against stored results.

Reference statistics for the demographic models in ``MODELS`` are read from
``test_sfs.pickle`` (located next to this file) and recomputed with
:func:`compute_stats`. Running this module as a script with the single
argument ``generate`` rewrites that pickle (see the TODOs in the script
section -- that code path may be stale).
"""
import os
import pickle
import sys

import autograd.numpy as np
import pytest

import momi
from momi.demo_model import DemographicModel
from momi.events import get_event_from_old, LeafEvent, SizeEvent, JoinEvent, PulseEvent, GrowthEvent
from demo_utils import (simple_admixture_demo, simple_two_pop_demo,
                        piecewise_constant_demo, simple_five_pop_demo,
                        exp_growth_model, exp_growth_0_model)

# TODO add a test with archaic leafs

# Each entry holds a demo constructor, the lineage counts ('nlins') passed to
# it when generating data (presumably per-population sample sizes -- confirm
# against demo_utils), and the number of random parameters ('params') drawn
# for it in the generation script below.
MODELS = [{'demo': simple_admixture_demo, 'nlins': (5, 5), 'params': 7},
          {'demo': simple_two_pop_demo, 'nlins': (5, 8), 'params': 4},
          {'demo': piecewise_constant_demo, 'nlins': (10,), 'params': 9},
          {'demo': simple_five_pop_demo, 'nlins': tuple(range(1, 6)), 'params': 30},
          {'demo': exp_growth_model, 'nlins': (10,), 'params': 3},
          {'demo': exp_growth_0_model, 'nlins': (10,), 'params': 2},
          ]
# Index by constructor name; that name is the first element of every key
# stored in the pickle.
MODELS = {m['demo'].__name__: m for m in MODELS}

PICKLE = os.path.join(os.path.dirname(
    os.path.realpath(__file__)), "test_sfs.pickle")


def generate_sfs():
    """Yield ``(m_name, v, demo, sampled_sfs)`` test cases read from ``PICKLE``.

    Each pickle key is ``(model name, params, sampled_sfs)``. Each value ``v``
    is the tuple of reference statistics that :func:`compute_stats` returned
    when the pickle was generated; it is forwarded as ``true_sfs`` and
    ``true_branch_len``.
    """
    # NOTE: unpickling can execute arbitrary code; PICKLE is assumed to be a
    # trusted fixture shipped with these tests -- never point it at external
    # data.
    with open(PICKLE, "rb") as sfs_dict_file:
        sfs_dict = pickle.load(sfs_dict_file)
    for (m_name, params, sampled_sfs), v in sfs_dict.items():
        demo = MODELS[m_name]['demo'](np.array(params))
        yield m_name, v, demo, sampled_sfs


@pytest.mark.parametrize("m_name,v,demo,sampled_sfs", generate_sfs())
def test_generated_cases(m_name, v, demo, sampled_sfs):
    """Check recomputed statistics against the stored reference values."""
    compute_stats(demo, sampled_sfs, *v)


def compute_stats(demo, sampled_sfs, true_sfs=None, true_branch_len=None):
    """Compute ``demo``'s expected SFS and total branch length.

    Parameters
    ----------
    demo :
        Model object exposing ``leafs``, ``set_data``, ``set_mut_rate``,
        ``expected_branchlen``, ``expected_sfs`` and ``N_e``.
    sampled_sfs :
        Hashable SFS representation as produced by :func:`from_dict`.
    true_sfs, true_branch_len :
        Optional reference values. When given, they are asserted to match
        the computed values with ``rtol=1e-4``.

    Returns
    -------
    tuple
        ``(exp_sfs, exp_branch_len)``, both divided by ``4 * demo.N_e``
        (ms units). ``exp_sfs`` is ordered by the sorted configs of the
        observed SFS.
    """
    sampled_sfs = momi.site_freq_spectrum(demo.leafs, to_dict(sampled_sfs))
    demo.set_data(sampled_sfs, length=1)
    demo.set_mut_rate(1)

    exp_branch_len = demo.expected_branchlen()
    exp_sfs = demo.expected_sfs()

    # Sorting gives a canonical config order, so arrays produced at
    # generation time and at test time are compared element-by-element.
    configs = sorted(tuple(map(tuple, c)) for c in sampled_sfs.configs)
    exp_sfs = np.array([exp_sfs[c] for c in configs])

    # use ms units
    exp_branch_len = exp_branch_len / 4.0 / demo.N_e
    exp_sfs = exp_sfs / 4.0 / demo.N_e

    if true_sfs is not None:
        assert np.allclose(true_sfs, exp_sfs, rtol=1e-4)
    if true_branch_len is not None:
        assert np.allclose(true_branch_len, exp_branch_len, rtol=1e-4)
    return exp_sfs, exp_branch_len


def from_dict(sampled_sfs):
    """Convert a list of per-locus dicts into a hashable tuple of item tuples."""
    # make it hashable
    return tuple([tuple(locus.items()) for locus in sampled_sfs])


def to_dict(sampled_sfs):
    """Inverse of :func:`from_dict`: rebuild the list of per-locus dicts."""
    # make it a dictionary
    return [dict(locus) for locus in sampled_sfs]


if __name__ == "__main__":
    # TODO check this simulation code still works!
    # NOTE(review): generate_sfs() builds demos as demo(params), whereas this
    # script calls demo(params, nlins) and relies on demo.demo_hist,
    # demo.pops and demo.n -- confirm these match the current
    # demo_utils / momi API before regenerating the pickle.
    # TODO fix this simulation code!!! (it previously simulated with ms)
    num_bases = 1000
    mu = 1.
    n_loci = 100

    results = {}
    for m_name, m_val in MODELS.items():
        print("# GENERATING %s" % m_name)
        for _ in range(10):
            x = np.random.normal(size=m_val['params'])
            demo = m_val['demo'](x, m_val['nlins'])
            demo.demo_hist = demo.demo_hist.rescaled()
            sfs = demo.demo_hist.simulate_data(
                demo.pops, demo.n, mutation_rate=mu / num_bases,
                recombination_rate=0, length=num_bases,
                num_replicates=n_loci).sfs
            sampled_sfs = from_dict(sfs.to_dict(vector=True))
            results[(m_name, tuple(x), sampled_sfs)] = compute_stats(
                demo, sampled_sfs)

    if len(sys.argv) == 2 and sys.argv[1] == "generate":
        # Context manager guarantees the pickle is flushed and closed.
        with open(PICKLE, "wb") as pickle_file:
            pickle.dump(results, pickle_file)