Python numpy.testing.assert_equal() Examples

The following are 30 code examples of numpy.testing.assert_equal(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module numpy.testing, or try the search function.
Example #1
Source File: test_geometry.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_Grid2D(x_range, y_range):
    """Check construction, shape, and corner coordinates of a rectangular Grid2D."""
    step = 1
    grid = Grid2D(x_range, y_range, step=step, grid_type='rectangular')
    for actual, expected in ((grid.x_range, x_range),
                             (grid.y_range, y_range),
                             (grid.step, step),
                             (grid.type, 'rectangular')):
        npt.assert_equal(actual, expected)

    # Coordinates use indexing='xy': rows follow y, columns follow x.
    n_rows = np.abs(np.diff(y_range)) + 1
    n_cols = np.abs(np.diff(x_range)) + 1
    npt.assert_equal(grid.x.shape, (n_rows, n_cols))
    npt.assert_equal(grid.x.shape, grid.y.shape)
    npt.assert_equal(grid.x.shape, grid.shape)

    x_lo, x_hi = x_range[0], x_range[1]
    y_lo, y_hi = y_range[0], y_range[1]
    # x changes only across columns, y changes only across rows:
    corners = [
        (grid.x[0, 0], x_lo), (grid.x[0, -1], x_hi),
        (grid.x[-1, 0], x_lo), (grid.x[-1, -1], x_hi),
        (grid.y[0, 0], y_lo), (grid.y[0, -1], y_lo),
        (grid.y[-1, 0], y_hi), (grid.y[-1, -1], y_hi),
    ]
    for actual, expected in corners:
        npt.assert_almost_equal(actual, expected)
Example #2
Source File: test_prima.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_PhotovoltaicPixel():
    """Check parameters, slots, and plotting of a PhotovoltaicPixel."""
    electrode = PhotovoltaicPixel(0, 1, 2, 3, 4)
    # Positional arguments map to x, y, z, r, a in that order:
    npt.assert_almost_equal(electrode.x, 0)
    npt.assert_almost_equal(electrode.y, 1)
    npt.assert_almost_equal(electrode.z, 2)
    npt.assert_almost_equal(electrode.r, 3)
    npt.assert_almost_equal(electrode.a, 4)
    # Slots: the class declares __slots__ and instances have no __dict__
    npt.assert_equal(hasattr(electrode, '__slots__'), True)
    npt.assert_equal(hasattr(electrode, '__dict__'), False)
    # Plots: a RegularPolygon followed by a Circle, and no text labels
    ax = electrode.plot()
    npt.assert_equal(len(ax.texts), 0)
    npt.assert_equal(len(ax.patches), 2)
    npt.assert_equal(isinstance(ax.patches[0], RegularPolygon), True)
    npt.assert_equal(isinstance(ax.patches[1], Circle), True)
Example #3
Source File: test_electrodes.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_HexElectrode():
    """Validate HexElectrode argument checking, attributes, slots, and plot."""
    # Non-scalar arguments are rejected with a TypeError:
    for bad_args in ((0, 0, 0, [1, 2]), (0, np.array([0, 1]), 0, 1)):
        with pytest.raises(TypeError):
            HexElectrode(*bad_args)
    # A negative size (4th argument, stored as `a`) is rejected:
    with pytest.raises(ValueError):
        HexElectrode(0, 0, 0, -5)

    electrode = HexElectrode(0, 1, 2, 100)
    for attr, expected in (('x', 0), ('y', 1), ('z', 2), ('a', 100)):
        npt.assert_almost_equal(getattr(electrode, attr), expected)

    # __slots__ is declared and there is no per-instance __dict__:
    npt.assert_equal(hasattr(electrode, '__slots__'), True)
    npt.assert_equal(hasattr(electrode, '__dict__'), False)

    # Plotting draws exactly one RegularPolygon and no text:
    ax = electrode.plot()
    npt.assert_equal(len(ax.texts), 0)
    npt.assert_equal(len(ax.patches), 1)
    npt.assert_equal(isinstance(ax.patches[0], RegularPolygon), True)
Example #4
Source File: test_electrodes.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_SquareElectrode():
    """Validate SquareElectrode arguments, attributes, slots, and plot."""
    # Non-scalar arguments raise TypeError; a negative size raises ValueError:
    bad_calls = [
        (TypeError, (0, 0, 0, [1, 2])),
        (TypeError, (0, np.array([0, 1]), 0, 1)),
        (ValueError, (0, 0, 0, -5)),
    ]
    for error, args in bad_calls:
        with pytest.raises(error):
            SquareElectrode(*args)

    electrode = SquareElectrode(0, 1, 2, 100)
    expected = {'x': 0, 'y': 1, 'z': 2, 'a': 100}
    for name, value in expected.items():
        npt.assert_almost_equal(getattr(electrode, name), value)

    # Uses __slots__ (no per-instance __dict__):
    npt.assert_equal(hasattr(electrode, '__slots__'), True)
    npt.assert_equal(hasattr(electrode, '__dict__'), False)

    # The plot consists of a single Rectangle patch and no text:
    ax = electrode.plot()
    npt.assert_equal(len(ax.texts), 0)
    npt.assert_equal(len(ax.patches), 1)
    npt.assert_equal(isinstance(ax.patches[0], Rectangle), True)
Example #5
Source File: test_kproxy_supercell_hf.py (from pyscf, Apache License 2.0; 6 votes)
def test_class(self):
        """Check the supercell TDProxy container against the reference TDHF model."""
        model = kproxy_supercell.TDProxy(self.model_krhf, "hf", [self.k, 1, 1], density_fitting_hf)
        reference = self.td_model_krhf
        model.nroots = reference.nroots
        assert not model.fast
        model.kernel()

        # Eigenvalues agree with the reference and are purely real:
        testing.assert_allclose(model.e, reference.e, atol=1e-5)
        testing.assert_allclose(model.e.imag, 0, atol=1e-8)

        # xy has one entry per root, then (2, k, k, nocc, nvirt):
        nocc = nvirt = 4
        expected_shape = (len(model.e), 2, self.k, self.k, nocc, nvirt)
        testing.assert_equal(model.xy.shape, expected_shape)

        # Only compare eigenvectors of roots that are not within 1e-8 of a
        # neighbouring root (degenerate roots have no unique eigenvector).
        close_to_next = abs(model.e[1:] - model.e[:-1]) < 1e-8
        degenerate = numpy.logical_or(
            numpy.concatenate(([False], close_to_next)),
            numpy.concatenate((close_to_next, [False])),
        )
        keep = numpy.logical_not(degenerate)
        assert_vectors_close(reference.xy[keep], model.xy[keep], atol=1e-5)
Example #6
Source File: test_bva.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_BVA24_stim():
    """Check that BVA24 accepts a stimulus by attribute, constructor, or array."""
    # Assign a stimulus after construction:
    implant = BVA24()
    implant.stim = {'1': 1}
    npt.assert_equal(implant.stim.electrodes, ['1'])
    npt.assert_equal(implant.stim.time, None)
    npt.assert_equal(implant.stim.data, [[1]])

    # You can also assign the stimulus in the constructor. Rebind `implant`
    # so the assertions below check the newly constructed object rather than
    # the one configured above:
    implant = BVA24(stim={'1': 1})
    npt.assert_equal(implant.stim.electrodes, ['1'])
    npt.assert_equal(implant.stim.time, None)
    npt.assert_equal(implant.stim.data, [[1]])

    # Set a stimulus via a 35-element array:
    implant = BVA24(stim=np.ones(35))
    npt.assert_equal(implant.stim.shape, (35, 1))
    npt.assert_almost_equal(implant.stim.data, 1)
Example #7
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_fetch_url(tmp_data_dir):
    """Download two PDFs and check that ``fetch_url`` verifies checksums."""
    first_url = 'https://www.nature.com/articles/s41598-019-45416-4.pdf'
    first_path = os.path.join(tmp_data_dir, 'paper1.pdf')
    first_checksum = 'e8a2db25916cdd15a4b7be75081ef3e57328fa5f335fb4664d1fb7090dcd6842'
    fetch_url(first_url, first_path, remote_checksum=first_checksum)
    npt.assert_equal(os.path.exists(first_path), True)

    second_url = 'https://bionicvisionlab.org/publication/2019-optimal-surgical-placement/2019-optimal-surgical-placement.pdf'
    second_path = os.path.join(tmp_data_dir, 'paper2.pdf')
    second_checksum = 'e2d0cbecc9c2826f66f60576b44fe18ad6a635d394ae02c3f528b89cffcd9450'
    # A checksum belonging to a different file is rejected:
    with pytest.raises(IOError):
        fetch_url(second_url, second_path, remote_checksum=first_checksum)
    # The matching checksum succeeds and the file ends up on disk:
    fetch_url(second_url, second_path, remote_checksum=second_checksum)
    npt.assert_equal(os.path.exists(second_path), True)
Example #8
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_BaseModel():
    """Exercise printing, overrides, freezing, and the build flag of a model."""
    # The string representation lists the parameters and their values:
    npt.assert_equal(str(ValidBaseModel()), 'ValidBaseModel(a=1, b=2)')

    # Keyword arguments override default values:
    model = ValidBaseModel(b=3)
    npt.assert_almost_equal(model.b, 3)

    # The object is frozen against new attributes:
    with pytest.raises(FreezeError):
        model.c = 3

    # `build` flips `is_built` and accepts parameter overrides:
    npt.assert_equal(model.is_built, False)
    model.build(a=3)
    npt.assert_almost_equal(model.a, 3)
    npt.assert_equal(model.is_built, True)

    # Unknown parameters cannot be passed, and `is_built` cannot be set:
    with pytest.raises(AttributeError):
        ValidBaseModel(c=3)
    with pytest.raises(AttributeError):
        ValidBaseModel().is_built = True
Example #9
Source File: testing.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def assert_warns_msg(expected_warning, func, msg, *args, **kwargs):
    r"""Assert a call leads to a warning with a specific message

    Test whether a function call leads to exactly one warning of type
    ``expected_warning`` whose message contains the string ``msg``.

    Parameters
    ----------
    expected_warning : warning class
        The class of warning to be checked; e.g., DeprecationWarning
    func : object
        The class, method, property, or function to be called as
        ``func(*args, **kwargs)``
    msg : str or None
        The message or a substring of the message to test for. If None,
        only the warning type and the number of warnings are checked.
    \*args : positional arguments to ``func``
    \*\*kwargs : keyword arguments to ``func``

    """
    # The docstring is a raw string: a plain string containing ``\*`` is an
    # invalid escape sequence (DeprecationWarning, SyntaxWarning on 3.12+).
    with pytest.warns(expected_warning) as record:
        func(*args, **kwargs)
    # Exactly one warning must have been emitted:
    npt.assert_equal(len(record), 1)
    if msg is not None:
        npt.assert_equal(msg in record[0].message.args[0], True)
Example #10
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_Model_set_params():
    """`set_params` routes each parameter to the sub-model that defines it."""
    # With a single sub-model, its parameter is visible on both the model
    # and the sub-model:
    for kind, sub_model_cls, param in (('spatial', ValidSpatialModel, 'xystep'),
                                       ('temporal', ValidTemporalModel, 'dt')):
        model = Model(**{kind: sub_model_cls()})
        model.set_params({param: 2.33})
        npt.assert_almost_equal(getattr(model, param), 2.33)
        npt.assert_almost_equal(getattr(getattr(model, kind), param), 2.33)

    # With both sub-models, one call sets both parameters, and neither
    # sub-model receives the other's parameter:
    xystep, dt = 5, 2.33
    model = Model(spatial=ValidSpatialModel(), temporal=ValidTemporalModel())
    model.set_params({'xystep': xystep, 'dt': dt})
    npt.assert_almost_equal(model.xystep, xystep)
    npt.assert_almost_equal(model.spatial.xystep, xystep)
    npt.assert_equal(hasattr(model.temporal, 'xystep'), False)
    npt.assert_almost_equal(model.dt, dt)
    npt.assert_almost_equal(model.temporal.dt, dt)
    npt.assert_equal(hasattr(model.spatial, 'dt'), False)
Example #11
Source File: test_convolution.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_conv(mode, method):
    """`convolution.conv` must match `np.convolve` and reject invalid options."""
    reload(convolution)
    # time vector for stimulus (long): 0.5 s sampled every 1 microsecond
    stim_dur = 0.5  # seconds
    tsample = 0.001 / 1000  # seconds (1 microsecond)
    t = np.arange(0, stim_dur, tsample)

    # stimulus: a +1 sample every 1000 samples (every 1 ms at this tsample,
    # i.e. a 1 kHz pulse train), each followed 100 samples (0.1 ms) later by
    # a -1 sample of opposite polarity
    stim = np.zeros_like(t)
    stim[::1000] = 1
    stim[100::1000] = -1

    # kernel
    _, gg = gamma(1, 0.005, tsample)

    # make sure conv returns the same result as np.convolve for all modes:
    npconv = np.convolve(stim, gg, mode=mode)
    conv = convolution.conv(stim, gg, mode=mode, method=method)
    npt.assert_equal(conv.shape, npconv.shape)
    npt.assert_almost_equal(conv, npconv)

    # Unknown `mode` and `method` values are rejected:
    with pytest.raises(ValueError):
        convolution.conv(gg, stim, mode="invalid")
    with pytest.raises(ValueError):
        convolution.conv(gg, stim, method="invalid")
Example #12
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 6 votes)
def test_gamma():
    """Check argument validation, time axis, normalization, and peak of `gamma`."""
    tsample = 0.005 / 1000

    # Invalid order, time constant, or sampling step is rejected:
    with pytest.raises(ValueError):
        t, g = gamma(0, 0.1, tsample)
    with pytest.raises(ValueError):
        t, g = gamma(2, -0.1, tsample)
    with pytest.raises(ValueError):
        t, g = gamma(2, 0.1, -tsample)

    # np.trapz was renamed np.trapezoid in NumPy 2.0 and the old name is
    # deprecated; use whichever name this NumPy provides:
    trapezoid = getattr(np, 'trapezoid', None)
    if trapezoid is None:
        trapezoid = np.trapz

    for tau in [0.001, 0.01, 0.1]:
        for n in [1, 2, 5]:
            t, g = gamma(n, tau, tsample)
            # Time axis starts at 0 and is evenly spaced by tsample:
            npt.assert_equal(np.arange(0, t[-1] + tsample / 2.0, tsample), t)
            if n > 1:
                npt.assert_equal(g[0], 0.0)

            # Make sure area under the curve is normalized
            npt.assert_almost_equal(trapezoid(np.abs(g), dx=tsample), 1.0,
                                    decimal=2)

            # Make sure peak sits correctly, at t = tau * (n - 1)
            npt.assert_almost_equal(g.argmax() * tsample, tau * (n - 1))
Example #13
Source File: test_extint128.py (from recruit, Apache License 2.0; 6 votes)
def test_safe_binop():
    """Compare `mt.extint_safe_binop` against exact Python integer arithmetic.

    Results outside [INT64_MIN, INT64_MAX] must raise OverflowError; all
    other results must equal Python's.
    """
    # (Python operator, opcode passed to extint_safe_binop)
    ops = [(operator.add, 1), (operator.sub, 2), (operator.mul, 3)]

    with exc_iter(ops, INT64_VALUES, INT64_VALUES) as it:
        for (pyop, code), a, b in it:
            expected = pyop(a, b)
            if INT64_MIN <= expected <= INT64_MAX:
                got = mt.extint_safe_binop(a, b, code)
                # Only call assert_equal on a mismatch: it is slow and this
                # loop runs over every pair of values.
                if got != expected:
                    assert_equal(got, expected)
            else:
                assert_raises(OverflowError, mt.extint_safe_binop, a, b, code)
Example #14
Source File: test_electrode_arrays.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_ElectrodeGrid___get_item__(gtype):
    """Electrodes can be looked up by flat index, (row, col) tuple, or name."""
    grid = ElectrodeGrid((2, 4), 20, names=('A', '1'), type=gtype,
                         etype=DiskElectrode, r=20)
    # Each key must resolve to the same electrode as the given name:
    equivalent_keys = [(0, 'A1'), ((0, 0), 'A1'), (1, 'A2'), ((0, 1), 'A2')]
    for key, name in equivalent_keys:
        npt.assert_equal(grid[key], grid[name])
    # A list of mixed keys returns the electrodes in the same order:
    npt.assert_equal(grid[['A1', 1, (0, 2)]],
                     [grid['A1'], grid['A2'], grid['A3']])
Example #15
Source File: test_electrodes.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_DiskElectrode():
    """Check DiskElectrode validation, attributes, potential, slots, and plot."""
    with pytest.raises(TypeError):
        DiskElectrode(0, 0, 0, [1, 2])
    with pytest.raises(TypeError):
        DiskElectrode(0, np.array([0, 1]), 0, 1)
    # Invalid radius:
    with pytest.raises(ValueError):
        DiskElectrode(0, 0, 0, -5)
    # Check params: center (0, 1, 2), radius 100
    electrode = DiskElectrode(0, 1, 2, 100)
    npt.assert_almost_equal(electrode.x, 0)
    npt.assert_almost_equal(electrode.y, 1)
    npt.assert_almost_equal(electrode.z, 2)
    npt.assert_almost_equal(electrode.r, 100)
    # On the electrode surface: z=2 and within the radius of the center
    # (0, 1), i.e. x^2 + (y-1)^2 <= 100^2 (points on the rim included)
    npt.assert_almost_equal(electrode.electric_potential(0, 1, 2, 1), 1)
    npt.assert_almost_equal(electrode.electric_potential(30, -30, 2, 1), 1)
    npt.assert_almost_equal(electrode.electric_potential(0, 101, 2, 1), 1)
    npt.assert_almost_equal(electrode.electric_potential(0, -99, 2, 1), 1)
    npt.assert_almost_equal(electrode.electric_potential(100, 1, 2, 1), 1)
    npt.assert_almost_equal(electrode.electric_potential(-100, 1, 2, 1), 1)
    # Right off the surface: z=2 and x^2 + (y-1)^2 just above 100^2
    npt.assert_almost_equal(electrode.electric_potential(0, 102, 2, 1), 0.910,
                            decimal=3)
    npt.assert_almost_equal(electrode.electric_potential(0, -100, 2, 1), 0.910,
                            decimal=3)
    # Some distance away from the electrode, directly above its center (z>2):
    npt.assert_almost_equal(electrode.electric_potential(0, 1, 38, 1), 0.780,
                            decimal=3)
    # Slots:
    npt.assert_equal(hasattr(electrode, '__slots__'), True)
    npt.assert_equal(hasattr(electrode, '__dict__'), False)
    # Plots: a single Circle patch and no text
    ax = electrode.plot()
    npt.assert_equal(len(ax.texts), 0)
    npt.assert_equal(len(ax.patches), 1)
    npt.assert_equal(isinstance(ax.patches[0], Circle), True)
Example #16
Source File: test_beyeler2019.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_AxonMapModel_predict_percept(engine):
    """Check AxonMapModel percepts for one electrode and for all electrodes."""
    model = AxonMapModel(xystep=0.55, axlambda=100, thresh_percept=0,
                         xrange=(-20, 20), yrange=(-15, 15),
                         engine=engine)
    model.build()
    # Single-electrode stim:
    img_stim = np.zeros(60)
    img_stim[47] = 1
    percept = model.predict_percept(ArgusII(stim=img_stim))
    # Single bright pixel, rest of arc is less bright:
    npt.assert_equal(np.sum(percept.data > 0.8), 1)
    npt.assert_equal(np.sum(percept.data > 0.6), 3)
    npt.assert_equal(np.sum(percept.data > 0.1), 21)
    npt.assert_equal(np.sum(percept.data > 0.0001), 70)
    # Overall only a few bright pixels:
    npt.assert_almost_equal(np.sum(percept.data), 8.0898, decimal=3)
    # Brightest pixel is in lower right:
    npt.assert_almost_equal(percept.data[33, 46, 0], np.max(percept.data))
    # Top half is empty:
    npt.assert_almost_equal(np.sum(percept.data[:27, :, 0]), 0)
    # Same for lower band:
    npt.assert_almost_equal(np.sum(percept.data[39:, :, 0]), 0)

    # Full Argus II with small lambda: 60 bright spots
    model = AxonMapModel(engine='serial', xystep=1, rho=100, axlambda=40,
                         xrange=(-20, 20), yrange=(-15, 15))
    model.build()
    percept = model.predict_percept(ArgusII(stim=np.ones(60)))
    # Most spots are pretty bright, but there are 2 dimmer ones (due to their
    # location on the retina):
    npt.assert_equal(np.sum(percept.data > 0.5), 28)
    npt.assert_equal(np.sum(percept.data > 0.275), 58)

    # Model gives same outcome as Spatial. The spatial model needs the same
    # parameters *and* the same grid extent, and its own `predict_percept`
    # must be called; calling `model` again would compare it with itself.
    spatial = AxonMapSpatial(engine='serial', xystep=1, rho=100, axlambda=40,
                             xrange=(-20, 20), yrange=(-15, 15))
    spatial.build()
    spatial_percept = spatial.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_almost_equal(percept.data, spatial_percept.data)
    npt.assert_equal(percept.time, None)
Example #17
Source File: test_electrode_arrays.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_ElectrodeArray_add_electrode():
    """Electrodes are added by name and retrieved by name or insertion position."""
    earray = ElectrodeArray([])
    npt.assert_equal(earray.n_electrodes, 0)

    # An ElectrodeArray cannot be added as an electrode:
    with pytest.raises(TypeError):
        earray.add_electrode('A01', ElectrodeArray([]))

    first_name = 'A04'
    earray.add_electrode(first_name, PointSource(0, 1, 2))
    npt.assert_equal(earray.n_electrodes, 1)
    # Look-up works by name and by position:
    for key in (first_name, 0):
        found = earray[key]
        npt.assert_equal(isinstance(found, PointSource), True)
        for attr, value in zip('xyz', (0, 1, 2)):
            npt.assert_almost_equal(getattr(found, attr), value)
    # Re-using a name is an error:
    with pytest.raises(ValueError):
        earray.add_electrode(first_name, PointSource(0, 1, 2))

    # Position follows insertion order, not name order ('A01' < 'A04'):
    second_name = 'A01'
    earray.add_electrode(second_name, DiskElectrode(4, 5, 6, 7))
    npt.assert_equal(earray.n_electrodes, 2)
    for key in (second_name, 1):
        found = earray[key]
        npt.assert_equal(isinstance(found, DiskElectrode), True)
        for attr, value in zip('xyzr', (4, 5, 6, 7)):
            npt.assert_almost_equal(getattr(found, attr), value)

    # A list of keys (names, positions, or a mix) returns a list:
    key_lists = ([first_name, second_name], [0, second_name],
                 [first_name, 1], [0, 1])
    for keys in key_lists:
        selected = earray[keys]
        npt.assert_equal(isinstance(selected, list), True)
        npt.assert_equal(isinstance(selected[0], PointSource), True)
        npt.assert_equal(isinstance(selected[1], DiskElectrode), True)
Example #18
Source File: test_beyeler2019.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_fetch_beyeler2019():
    """Check the structure, shuffling, and pandas dependency of the dataset."""
    data = datasets.fetch_beyeler2019(shuffle=False)

    npt.assert_equal(isinstance(data, pd.DataFrame), True)
    columns = ['subject', 'amp', 'area', 'compactness', 'date', 'eccentricity',
               'electrode', 'filename', 'freq', 'image', 'orientation', 'pdur',
               'stim_class', 'x_center', 'y_center', 'img_shape']
    for expected_col in columns:
        npt.assert_equal(expected_col in data.columns, True)

    npt.assert_equal(data.shape, (400, 16))
    npt.assert_equal(data.subject.unique(), ['S1', 'S2', 'S3', 'S4'])
    npt.assert_equal(list(data[data.subject == 'S1'].img_shape.unique()[0]),
                     [384, 384])
    npt.assert_equal(list(data[data.subject != 'S1'].img_shape.unique()[0]),
                     [768, 1024])

    # Shuffle dataset (index will always be range(400), but rows are shuffled):
    data = datasets.fetch_beyeler2019(shuffle=True, random_state=42)
    npt.assert_equal(data.loc[0, 'subject'], 'S3')
    npt.assert_equal(data.loc[0, 'electrode'], 'A2')
    npt.assert_equal(data.loc[399, 'subject'], 'S2')
    npt.assert_equal(data.loc[399, 'electrode'], 'D4')

    # Without pandas, fetching the dataset must fail with ImportError:
    try:
        with mock.patch.dict("sys.modules", {"pandas": {}}):
            with pytest.raises(ImportError):
                reload(datasets)
                datasets.fetch_beyeler2019()
    finally:
        # Reload against the real pandas so the module is not left in the
        # state produced under the mock, even if an assertion above failed.
        reload(datasets)
Example #19
Source File: test_electrode_arrays.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_ElectrodeGrid_get_params(gtype):
    """A DiskElectrode grid reports the shape and grid type it was built with."""
    shape = (2, 3)
    egrid = ElectrodeGrid(shape, 40, type=gtype, etype=DiskElectrode, r=20)
    npt.assert_equal(egrid.shape, shape)
    npt.assert_equal(egrid.type, gtype)
Example #20
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_Model_predict_percept():
    """`predict_percept` returns None when there is nothing to compute, and
    rejects invalid calls with the appropriate exception."""
    # A Model without sub-models never produces a percept:
    model = Model()
    npt.assert_equal(model.predict_percept(ArgusI()), None)
    npt.assert_equal(model.predict_percept(ArgusI(stim={'A1': 1})), None)
    npt.assert_equal(model.predict_percept(ArgusI(stim={'A1': 1}),
                                           t_percept=[0, 1]), None)

    # A built model with a single sub-model and no stimulus yields None:
    for kind, sub_model_cls in (('spatial', ValidSpatialModel),
                                ('temporal', ValidTemporalModel)):
        model = Model(**{kind: sub_model_cls()}).build()
        npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Invalid calls on a model with both spatial and temporal sub-models:
    model = Model(spatial=ValidSpatialModel(), temporal=ValidTemporalModel())
    with pytest.raises(NotBuiltError):
        # `build` must be called first:
        model.predict_percept(ArgusI())
    model.build()
    with pytest.raises(ValueError):
        # Requested percept times must be multiples of `dt`:
        model.predict_percept(ArgusI(stim={'A1': np.ones(16)}),
                              t_percept=[0.1, 0.11])
    with pytest.raises(ValueError):
        # A temporal model needs a stimulus with a time axis:
        ValidTemporalModel().predict_percept(Stimulus(3))
    with pytest.raises(ValueError):
        # Without a stimulus time axis, `t_percept` must stay None:
        model.predict_percept(ArgusI(stim=np.ones(16)),
                              t_percept=[0, 1, 2])
    with pytest.raises(TypeError):
        # The argument must be an implant, not a bare Stimulus:
        model.predict_percept(Stimulus(3))
Example #21
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_Model_build():
    """`is_built` is False until `build` runs, except for an empty Model."""
    # A Model without sub-models has nothing to build, so it always reports
    # being built (this allows `predict_percept` to be called on it):
    model = Model()
    npt.assert_equal(model.is_built, True)
    model.build()
    npt.assert_equal(model.is_built, True)

    # Spatial only, temporal only, and both sub-models:
    factories = (
        lambda: Model(spatial=ValidSpatialModel()),
        lambda: Model(temporal=ValidTemporalModel()),
        lambda: Model(spatial=ValidSpatialModel(),
                      temporal=ValidTemporalModel()),
    )
    for make_model in factories:
        model = make_model()
        npt.assert_equal(model.is_built, False)
        model.build()
        npt.assert_equal(model.is_built, True)
Example #22
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_Percept():
    """Check Percept shape, metadata, spatial axes, and time labels."""
    ndarray = np.arange(15).reshape((3, 5, 1))
    n_rows, n_cols = ndarray.shape[0], ndarray.shape[1]

    # Without a grid, the spatial axes default to 0..n-1:
    percept = Percept(ndarray, metadata='meta')
    npt.assert_equal(percept.shape, ndarray.shape)
    npt.assert_equal(percept.metadata, 'meta')
    npt.assert_equal(hasattr(percept, 'xdva'), True)
    npt.assert_almost_equal(percept.xdva, np.arange(n_cols))
    npt.assert_equal(hasattr(percept, 'ydva'), True)
    npt.assert_almost_equal(percept.ydva, np.arange(n_rows))
    # A singleton time dimension is labelled None:
    npt.assert_equal(hasattr(percept, 'time'), True)
    npt.assert_equal(percept.time, None)

    # A scalar and a one-element list give the same time label:
    for time in (0.4, [0.4]):
        percept = Percept(ndarray, time=time)
        npt.assert_almost_equal(percept.time, [0.4])

    # Spatial labels taken from a Grid2D:
    x_range, y_range = (-2, 2), (-1, 1)
    grid = Grid2D(x_range, y_range)
    percept = Percept(ndarray, space=grid)
    npt.assert_almost_equal(percept.xdva, grid._xflat)
    npt.assert_almost_equal(percept.ydva, grid._yflat)
    npt.assert_equal(percept.time, None)

    # ... combined with an explicit time label:
    grid = Grid2D(x_range, y_range)
    percept = Percept(ndarray, space=grid, time=0)
    npt.assert_almost_equal(percept.xdva, grid._xflat)
    npt.assert_almost_equal(percept.ydva, grid._yflat)
    npt.assert_almost_equal(percept.time, [0])

    # `space` must be a grid, not a plain dict:
    with pytest.raises(TypeError):
        Percept(ndarray, space={'x': [0, 1, 2], 'y': [0, 1, 2, 3, 4]})
Example #23
Source File: test_nanduri2012.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_Nanduri2012Model():
    """Nanduri2012Model exposes temporal parameters that can be set."""
    model = Nanduri2012Model(engine='serial', xystep=5)
    npt.assert_equal(hasattr(model, 'has_time'), True)
    npt.assert_equal(model.has_time, True)

    # `dt` set on the temporal sub-model is visible on the model:
    model.temporal.dt = 1e-5
    npt.assert_almost_equal(model.dt, 1e-5)
    npt.assert_almost_equal(model.temporal.dt, 1e-5)
    # ... and `build` can override it:
    model.build(dt=3e-4)
    npt.assert_almost_equal(model.dt, 3e-4)
    npt.assert_almost_equal(model.temporal.dt, 3e-4)

    # The parameter set is frozen; new parameters cannot be added:
    with pytest.raises(FreezeError):
        model.rho = 100

    # A parameter shared by the spatial and temporal sub-models can be set
    # on both at once through `set_params`:
    thresh = 0.512
    model.set_params({'thresh_percept': thresh})
    npt.assert_almost_equal(model.spatial.thresh_percept, thresh)
    npt.assert_almost_equal(model.temporal.thresh_percept, thresh)
    # ... or on a single sub-model:
    doubled = 2 * thresh
    model.temporal.thresh_percept = doubled
    npt.assert_almost_equal(model.temporal.thresh_percept, doubled)
Example #24
Source File: test_parallel.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_parfor(engine, scheduler):
    """`parallel.parfor` maps a function over a list and reshapes the result."""
    my_array = np.arange(100).reshape(10, 10)
    # randint's upper bound is exclusive: draw from the full extent of each
    # axis so the last row/column can also be sampled.
    i = np.random.randint(0, my_array.shape[0])
    j = np.random.randint(0, my_array.shape[1])
    my_list = list(my_array.ravel())

    expected_00 = power_it(my_array[0, 0])
    expected_ij = power_it(my_array[i, j])

    with pytest.raises(ValueError):
        parallel.parfor(power_it, my_list, engine='unknown')
    with pytest.raises(ValueError):
        parallel.parfor(power_it, my_list, engine='dask', scheduler='unknown')

    # `scheduler` is only relevant for dask and is ignored by other engines,
    # which should thus still give the right result. Compute once and index.
    calculated = parallel.parfor(power_it, my_list, engine=engine,
                                 scheduler=scheduler,
                                 out_shape=my_array.shape)

    npt.assert_equal(expected_00, calculated[0, 0])
    npt.assert_equal(expected_ij, calculated[i, j])

    # Without dask, the dask engine must raise ImportError:
    try:
        with mock.patch.dict("sys.modules", {'dask': {}}):
            reload(parallel)
            with pytest.raises(ImportError):
                parallel.parfor(power_it, my_list, engine='dask',
                                out_shape=my_array.shape)
    finally:
        # Restore the real module even if an assertion above failed
        reload(parallel)
Example #25
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_Percept__iter__():
    """Iterating a Percept yields one 2D frame per entry of the last axis."""
    ndarray = np.zeros((2, 4, 3))
    ndarray[..., 1] = 1
    ndarray[..., 2] = 2
    percept = Percept(ndarray)
    n_frames = 0
    for i, frame in enumerate(percept):
        npt.assert_equal(frame.shape, (2, 4))
        # Frame i was filled with the constant value i above:
        npt.assert_almost_equal(frame, i)
        n_frames += 1
    # Guard against a vacuous pass if iteration yields no frames:
    npt.assert_equal(n_frames, ndarray.shape[-1])
Example #26
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_Percept_plot():
    """Check the axes returned by `Percept.plot` for each kind of plot."""
    x_range, y_range = (-2, 2), (-1, 1)
    grid = Grid2D(x_range, y_range)
    percept = Percept(np.arange(15).reshape((3, 5, 1)), space=grid)

    # pcolor: axis limits come from the grid, color limits from the
    # brightest frame
    ax = percept.plot(kind='pcolor')
    npt.assert_equal(isinstance(ax, Subplot), True)
    npt.assert_almost_equal(ax.axis(), [*x_range, *y_range])
    brightest = percept.get_brightest_frame()
    npt.assert_almost_equal(ax.collections[0].get_clim(),
                            [brightest.min(), brightest.max()])

    # hex: axis limits come from the first/last dva labels, color limits
    # from the first frame
    ax = percept.plot(kind='hex')
    npt.assert_equal(isinstance(ax, Subplot), True)
    xdva, ydva = percept.xdva, percept.ydva
    npt.assert_almost_equal(ax.axis(), [xdva[0], xdva[-1],
                                        ydva[0], ydva[-1]])
    first_frame = percept.data[..., 0]
    npt.assert_almost_equal(ax.collections[0].get_clim(),
                            [first_frame.min(), first_frame.max()])

    # The hex plot uses the gray colormap:
    npt.assert_equal(ax.collections[0].cmap, plt.cm.gray)

    # `figsize` is passed through to the figure:
    ax = percept.plot(kind='pcolor', figsize=(6, 4))
    npt.assert_almost_equal(ax.figure.get_size_inches(), (6, 4))

    # Invalid calls:
    with pytest.raises(ValueError):
        percept.plot(kind='invalid')
    with pytest.raises(TypeError):
        percept.plot(ax='invalid')

    # TODO: plotting a specific time point is not implemented yet
    with pytest.raises(NotImplementedError):
        percept.plot(time=3.3)
Example #27
Source File: test_base.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_Percept_save():
    """Save multi-frame percepts as movies/GIFs and single frames as images."""
    import tempfile

    ndarray = np.arange(256, dtype=np.float32).repeat(31).reshape((-1, 16, 16))
    percept = Percept(ndarray.transpose((2, 0, 1)))

    # Write into a temporary directory so no files are left behind in the
    # working directory, even when an assertion fails mid-loop.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save multiple frames as a gif or movie:
        for name in ['test.mp4', 'test.avi', 'test.mov', 'test.wmv', 'test.gif']:
            fname = os.path.join(tmp_dir, name)
            percept.save(fname)
            npt.assert_equal(os.path.isfile(fname), True)
            # Normalized to [0, 255] with some loss of precision:
            mov = mimread(fname)
            npt.assert_equal(np.min(mov) <= 2, True)
            npt.assert_equal(np.max(mov) >= 250, True)
            os.remove(fname)

        # Cannot save multiple frames as a single image:
        with pytest.raises(ValueError):
            percept.save(os.path.join(tmp_dir, 'test.jpg'))

        # But, can save single frame as image:
        percept = Percept(ndarray[..., :1])
        for name in ['test.jpg', 'test.png', 'test.tif', 'test.gif']:
            fname = os.path.join(tmp_dir, name)
            percept.save(fname)
            npt.assert_equal(os.path.isfile(fname), True)
            img = img_as_float(imread(fname))
            npt.assert_almost_equal(np.min(img), 0, decimal=3)
            npt.assert_almost_equal(np.max(img), 1.0, decimal=3)
            os.remove(fname)
Example #28
Source File: test_geometry.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_Grid2D_plot():
    """`Grid2D.plot` returns an ``Axes`` instance."""
    ax = Grid2D((-20, 20), (-40, 40), step=0.5).plot()
    npt.assert_equal(isinstance(ax, Axes), True)
Example #29
Source File: test_xrft.py (from xrft, MIT License; 5 votes)
def test_isotropic_ps():
    """Isotropic power spectrum of data that carries an extra coordinate."""
    values = np.random.rand(2, 5, 16, 32)
    coords = {
        'time': np.array(['2019-04-18', '2019-04-19'], dtype='datetime64'),
        'zz': ('z', np.arange(5)),
        'z': np.arange(5),
        'y': np.arange(16),
        'x': np.arange(32),
    }
    da = xr.DataArray(values, dims=['time', 'z', 'y', 'x'], coords=coords)

    # Three dimensions are rejected; two are accepted:
    with pytest.raises(ValueError):
        xrft.isotropic_power_spectrum(da, dim=['z', 'y', 'x'])
    iso_ps = xrft.isotropic_power_spectrum(da, dim=['y', 'x'])

    # Apart from the first radial bin, no value may be NaN or infinite:
    invalid = np.ma.masked_invalid(iso_ps.isel(freq_r=slice(1, None))).mask
    npt.assert_equal(invalid.sum(), 0.)
Example #30
Source File: test_beyeler2019.py (from pulse2percept, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_ScoreboardModel():
    """Exercise ScoreboardModel parameters and percept prediction."""
    # ScoreboardModel is spatial-only and gives its spatial model a `rho`:
    model = ScoreboardModel(engine='serial', xystep=5)
    npt.assert_equal(model.has_space, True)
    npt.assert_equal(model.has_time, False)
    npt.assert_equal(hasattr(model.spatial, 'rho'), True)

    def _check_rho(expected):
        # `rho` must agree on the model and on its spatial sub-model
        npt.assert_equal(model.rho, expected)
        npt.assert_equal(model.spatial.rho, expected)

    # `rho` can be set directly or passed to `build`:
    model.rho = 123
    _check_rho(123)
    model.build(rho=987)
    _check_rho(987)

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Zero in, zero out:
    zero_implant = ArgusI(stim=np.zeros(16))
    npt.assert_almost_equal(model.predict_percept(zero_implant).data, 0)

    # Frames are processed independently, so doubling the amplitude doubles
    # the peak brightness; the brightest pixel sits at (2, 3) in both frames:
    model = ScoreboardModel(engine='serial', rho=200, xystep=5,
                            xrange=(-20, 20), yrange=(-15, 15))
    model.build()
    percept = model.predict_percept(ArgusI(stim={'A1': [1, 2]}))
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [2])
    peak = percept.data.max(axis=(0, 1))
    npt.assert_almost_equal(percept.data[2, 3, :], peak)
    npt.assert_almost_equal(peak[1] / peak[0], 2.0)
    npt.assert_almost_equal(percept.time, [0, 1])