Python pytest.xfail() Examples

The following are 30 code examples showing how to use pytest.xfail(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module pytest, or try the search function.

Example 1
Project: py   Author: pytest-dev   File: test_svnwc.py    License: MIT License 6 votes vote down vote up
def test_status_update(self, path1):
        """Check that status(updates=1) lists files with pending updates.

        Currently an expected failure; see the xfail reason below.
        """
        # xfail() is called imperatively instead of via a marker because the
        # module-level "pytestmark" would overwrite a mark placed here.
        # It raises immediately, so the remainder never runs until removed.
        pytest.xfail("svn-1.7 has buggy 'status --xml' output")
        wc = path1
        try:
            wc.update(rev=1)
            status = wc.status(updates=1, rec=1)
            import pprint
            pprint.pprint(status.allpath())
            # Compare basenames only: full paths are unpredictable on
            # Windows (long vs. 8.3 names).
            wanted = wc.join('anotherfile').basename
            assert wanted in [entry.basename
                              for entry in status.update_available]
        finally:
            wc.update()
Example 2
Project: qutebrowser   Author: qutebrowser   File: test_navigate.py    License: GNU General Public License v3.0 6 votes vote down vote up
def test_incdec(self, incdec, value, url, config_stub):
        """Check navigate.incdec() moves the number embedded in *url* by one.

        The value '{}foo' inside the path-with-spaces URL is a known failure
        (qutebrowser issue #4917) and is reported as xfail.
        """
        known_broken = (value == '{}foo'
                        and url == 'http://example.com/path with {} spaces')
        if known_broken:
            pytest.xfail("https://github.com/qutebrowser/qutebrowser/issues/4917")

        config_stub.val.url.incdec_segments = ['host', 'path', 'query', 'anchor']

        # Any starting number greater than 1 gives the same outcome.
        start = 20
        step = 1 if incdec == 'increment' else -1
        before = value.format(start)
        after = value.format(start + step)

        start_url = QUrl(url.format(before))
        want_url = QUrl(url.format(after))

        assert navigate.incdec(start_url, 1, incdec) == want_url
Example 3
Project: recruit   Author: Frank-qlu   File: test_coercion.py    License: Apache License 2.0 6 votes vote down vote up
def test_setitem_series_bool(self, val, exp_dtype):
        """Check dtype coercion when assigning *val* into a bool Series.

        For int64 / float64 / complex128 targets pandas currently keeps the
        bool dtype (GH12747).  That current behaviour is asserted, then the
        test is reported as an expected failure.  Any other *exp_dtype* is
        checked directly.
        """
        # ``np.bool`` was only an alias for the builtin ``bool``.  It was
        # deprecated in NumPy 1.20 and removed in 1.24 (AttributeError), so
        # the builtin is used instead.
        obj = pd.Series([True, False, True, False])
        assert obj.dtype == bool

        # Identity checks (``is``) are kept on purpose, as in the original.
        known_gaps = (
            (np.int64, "TODO_GH12747 The result must be int"),
            (np.float64, "TODO_GH12747 The result must be float"),
            (np.complex128, "TODO_GH12747 The result must be complex"),
        )
        for gap_dtype, reason in known_gaps:
            if exp_dtype is gap_dtype:
                exp = pd.Series([True, True, True, False])
                self._assert_setitem_series_conversion(obj, val, exp, bool)
                pytest.xfail(reason)

        exp = pd.Series([True, val, True, False])
        self._assert_setitem_series_conversion(obj, val, exp, exp_dtype)
Example 4
Project: recruit   Author: Frank-qlu   File: test_coercion.py    License: Apache License 2.0 6 votes vote down vote up
def test_insert_index_datetimes(self, fill_val, exp_dtype):
        """Check DatetimeIndex.insert() with *fill_val* and its error paths.

        Mismatched-timezone timestamps must raise ValueError and an integer
        label must raise TypeError.  The test then ends as xfail because
        coercion to object is still a TODO.
        """
        tz = fill_val.tz
        days = ['2011-01-01', '2011-01-02', '2011-01-03', '2011-01-04']
        obj = pd.DatetimeIndex(days, tz=tz)
        assert obj.dtype == exp_dtype

        expected = pd.DatetimeIndex(days[:1] + [fill_val.date()] + days[1:],
                                    tz=tz)
        self._assert_insert_conversion(obj, fill_val, expected, exp_dtype)

        tz_msg = "Passed item and index have different timezone"
        if tz:
            # Inserting a naive timestamp into an aware index must raise.
            with pytest.raises(ValueError, match=tz_msg):
                obj.insert(1, pd.Timestamp('2012-01-01'))

        # Inserting an Asia/Tokyo timestamp is expected to raise as well.
        with pytest.raises(ValueError, match=tz_msg):
            obj.insert(1, pd.Timestamp('2012-01-01', tz='Asia/Tokyo'))

        label_msg = "cannot insert DatetimeIndex with incompatible label"
        with pytest.raises(TypeError, match=label_msg):
            obj.insert(1, 1)

        pytest.xfail("ToDo: must coerce to object")
Example 5
Project: recruit   Author: Frank-qlu   File: test_coercion.py    License: Apache License 2.0 6 votes vote down vote up
def test_replace_series_datetime_tz(self):
        """Series.replace() mapping US/Eastern datetimes to timedelta64 values.

        Only the Series form of the replacer is exercised (``how = 'series'``).
        The dict form is kept as an alternative builder.
        """
        how = 'series'
        from_key = 'datetime64[ns, US/Eastern]'
        to_key = 'timedelta64[ns]'

        idx = pd.Index([3, 4], name='xxx')
        obj = pd.Series(self.rep[from_key], index=idx, name='yyy')
        assert obj.dtype == from_key

        make_replacer = {
            'dict': lambda: dict(zip(self.rep[from_key], self.rep[to_key])),
            'series': lambda: pd.Series(self.rep[to_key],
                                        index=self.rep[from_key]),
        }.get(how)
        if make_replacer is None:
            raise ValueError

        result = obj.replace(make_replacer())
        expected = pd.Series(self.rep[to_key], index=idx, name='yyy')
        assert expected.dtype == to_key

        tm.assert_series_equal(result, expected)

    # TODO(jreback) commented out to only have a single xfail printed
Example 6
Project: recruit   Author: Frank-qlu   File: test_transform.py    License: Apache License 2.0 6 votes vote down vote up
def test_transform_numeric_ret(cols, exp, comp_func, agg_func):
    """GH 19200: groupby('b')[cols].transform(agg_func) gives *exp*.

    'size' with a list of columns is an expected failure.  For 'rank' the
    expectation is compared as float.
    """
    if agg_func == 'size' and isinstance(cols, list):
        pytest.xfail("'size' transformation not supported with "
                     "NDFrameGroupy")

    frame = pd.DataFrame({
        'a': pd.date_range('2018-01-01', periods=3),
        'b': range(3),
        'c': range(7, 10),
    })
    result = frame.groupby('b')[cols].transform(agg_func)

    expected = exp.astype('float') if agg_func == 'rank' else exp
    comp_func(result, expected)
Example 7
Project: recruit   Author: Frank-qlu   File: test_timedelta64.py    License: Apache License 2.0 6 votes vote down vote up
def test_td64arr_mod_int(self, box_with_array):
        """Modulo and divmod of a boxed 1..10 ns timedelta array by 2."""
        boxed = tm.box_expected(timedelta_range('1 ns', '10 ns', periods=10),
                                box_with_array)
        remainders = tm.box_expected(TimedeltaIndex(['1 ns', '0 ns'] * 5),
                                     box_with_array)

        tm.assert_equal(boxed % 2, remainders)

        # The reflected form (int % timedeltas) must raise.
        with pytest.raises(TypeError):
            2 % boxed

        if box_with_array is pd.DataFrame:
            pytest.xfail("DataFrame does not have __divmod__ or __rdivmod__")

        pair = divmod(boxed, 2)
        tm.assert_equal(pair[1], remainders)
        tm.assert_equal(pair[0], boxed // 2)
Example 8
Project: recruit   Author: Frank-qlu   File: test_timedelta64.py    License: Apache License 2.0 6 votes vote down vote up
def test_td64arr_rmod_tdscalar(self, box_with_array, three_days):
        """Reflected modulo and divmod of *three_days* by a boxed 1..9 day array."""
        boxed = tm.box_expected(timedelta_range('1 Day', '9 days'),
                                box_with_array)

        remainders = TimedeltaIndex(['0 Days', '1 Day', '0 Days']
                                    + ['3 Days'] * 6)
        remainders = tm.box_expected(remainders, box_with_array)

        tm.assert_equal(three_days % boxed, remainders)

        if box_with_array is pd.DataFrame:
            pytest.xfail("DataFrame does not have __divmod__ or __rdivmod__")

        pair = divmod(three_days, boxed)
        tm.assert_equal(pair[1], remainders)
        tm.assert_equal(pair[0], three_days // boxed)

    # ------------------------------------------------------------------
    # Operations with invalid others
Example 9
Project: recruit   Author: Frank-qlu   File: test_datetime64.py    License: Apache License 2.0 6 votes vote down vote up
def test_dt64arr_aware_sub_dt64ndarray_raises(self, tz_aware_fixture,
                                                  box_with_array):
        """Subtraction between a boxed aware array and its raw ndarray raises.

        Both ``boxed - raw`` and ``raw - boxed`` must raise TypeError.
        """
        if box_with_array is pd.DataFrame:
            pytest.xfail("FIXME: ValueError with transpose; "
                         "alignment error without")

        index = pd.date_range('2016-01-01', periods=3, tz=tz_aware_fixture)
        raw = index.values
        boxed = tm.box_expected(index, box_with_array)

        for attempt in (lambda: boxed - raw, lambda: raw - boxed):
            with pytest.raises(TypeError):
                attempt()

    # -------------------------------------------------------------
    # Addition of datetime-like others (invalid)
Example 10
Project: recruit   Author: Frank-qlu   File: test_datetime64.py    License: Apache License 2.0 6 votes vote down vote up
def test_dt64arr_add_dt64ndarray_raises(self, tz_naive_fixture,
                                            box_with_array):
        """Adding a raw datetime64 ndarray to its boxed array raises TypeError.

        Both operand orders are checked.
        """
        if box_with_array is pd.DataFrame:
            pytest.xfail("FIXME: ValueError with transpose; "
                         "alignment error without")

        index = pd.date_range('2016-01-01', periods=3, tz=tz_naive_fixture)
        raw = index.values
        boxed = tm.box_expected(index, box_with_array)

        for attempt in (lambda: boxed + raw, lambda: raw + boxed):
            with pytest.raises(TypeError):
                attempt()
Example 11
Project: bayeslite   Author: probcomp   File: test_pretty.py    License: Apache License 2.0 6 votes vote down vote up
def test_pretty_unicomb():
    """pp_list column alignment with combining marks and non-Latin text.

    Currently an expected failure: per the xfail reason, pp_list measures
    cell width in code points rather than grapheme clusters.
    """
    # pytest.xfail() raises immediately, so nothing below executes until
    # this call is removed.
    pytest.xfail('pp_list counts code points, not grapheme clusters.')
    labels = ['name', 'age', 'favourite food']
    table = [
        ['Spot', 3, 'kibble'],
        ['Skruffles', 2, 'kibble'],
        ['Zorb', 2, 'zorblaxian kibble'],
        ['Zörb', 87, 'zørblaχian ﻛبﻞ'],
        [u'Zörb', 42, u'zörblǎxïǎn kïbble'],
        ['Zörb', 87, 'zørblaχian ﻛِبّﻞ'],
    ]
    # NOTE(review): `StringIO.StringIO` exists only on Python 2 (io.StringIO
    # on Python 3).  The expected output below lists five data rows in a
    # different order, while `table` has six.  Its last row also reads
    # 'zørblaxian' where the table has 'zørblaχian'.  The expectation likely
    # needs revisiting before the xfail can be removed.
    out = StringIO.StringIO()
    pretty.pp_list(out, table, labels)
    assert out.getvalue() == \
        u'     name | age |    favourite food\n' \
        u'----------+-----+------------------\n' \
        u'     Spot |   3 |            kibble\n' \
        u'Skruffles |   2 |            kibble\n' \
        u'     Zorb |   2 | zorblaxian kibble\n' \
        u'     Zörb |  42 | zörblǎxïǎn kïbble\n' \
        u'     Zörb |  87 |    zørblaxian ﻛِبّﻞ\n' 
Example 12
Project: filesystem_spec   Author: intake   File: test_local.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_compressions(fmt, mode, tmpdir):
    """Write then read back data through OpenFile using compression *fmt*.

    The on-disk size must equal the payload length when uncompressed and
    differ from it otherwise.  Reading back in *mode* ("rb" or text) must
    return the original payload.
    """
    if fmt == "zip" and sys.version_info < (3, 6):
        pytest.xfail("zip compression requires python3.6 or higher")

    path = os.path.join(str(tmpdir), ".tmp.getsize")
    fs = LocalFileSystem()
    payload = b"Long line of readily compressible text"

    with OpenFile(fs, path, compression=fmt, mode="wb") as sink:
        sink.write(payload)

    size_on_disk = fs.size(path)
    if fmt is None:
        assert size_on_disk == len(payload)
    else:
        assert size_on_disk != len(payload)

    expected = payload if mode == "rb" else payload.decode()
    with OpenFile(fs, path, compression=fmt, mode=mode) as source:
        assert source.read() == expected
Example 13
Project: filesystem_spec   Author: intake   File: test_local.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_glob_weird_characters(tmpdir, sep, chars):
    """glob() finds the single file inside a directory named with *chars*."""
    base = make_path_posix(str(tmpdir))
    subdir = base + sep + "test" + chars + "x"

    try:
        os.makedirs(subdir, exist_ok=True)
    except OSError as err:
        # Only the Windows "label syntax" rejection is tolerated.
        if not (WIN and "label syntax" in str(err)):
            raise
        pytest.xfail("Illegal windows directory name")

    with open(subdir + sep + "tmp", "w") as handle:
        handle.write("hi")

    matches = LocalFileSystem().glob(subdir + sep + "*")
    assert len(matches) == 1
    assert "/" in matches[0]
    assert "tmp" in matches[0]
Example 14
Project: vidgear   Author: abhiTronix   File: test_netgear.py    License: Apache License 2.0 6 votes vote down vote up
def test_client_reliablity(options):
    """
    Testing reliability of a NetGear receive-mode client with no server.

    The client is created with the given ``options`` and ``recv()`` is called
    without any connection.  A ``None`` frame is treated as the expected
    outcome and reported as xfail.  Any other exception is logged rather
    than failing the test.  The client is always closed.
    """
    client = None
    try:
        # define params
        client = NetGear(
            pattern=1, port=5554, receive_mode=True, logging=True, **options
        )
        # get data without any connection
        frame_client = client.recv()
        if frame_client is None:
            raise RuntimeError
    except RuntimeError:
        pytest.xfail("Reconnection ran successfully.")
    except Exception as e:
        # Deliberately best-effort: unexpected errors are only logged.
        logger.exception(str(e))
    finally:
        # clean resources
        if client is not None:
            client.close()
Example 15
Project: python-netsurv   Author: sofia-netsurv   File: pytester.py    License: MIT License 6 votes vote down vote up
def spawn(self, cmd, expect_timeout=10.0):
        """Run a command using pexpect.

        The pexpect child is returned.

        Skips when pexpect >= 3.0 is unavailable or on 64-bit PyPy, and
        xfails on FreeBSD.  Child output is logged to ``spawn.out`` under
        ``self.tmpdir``.  *expect_timeout* becomes the child's timeout.
        """
        pexpect = pytest.importorskip("pexpect", "3.0")
        if hasattr(sys, "pypy_version_info") and "64" in platform.machine():
            pytest.skip("pypy-64 bit not supported")
        if sys.platform.startswith("freebsd"):
            pytest.xfail("pexpect does not work reliably on freebsd")

        logfile = self.tmpdir.join("spawn.out").open("wb")
        # Register the close immediately, so the file is not leaked when
        # building the environment or pexpect.spawn() raises.
        self.request.addfinalizer(logfile.close)

        # Do not load user config.
        env = os.environ.copy()
        env.update(self._env_run_update)

        child = pexpect.spawn(cmd, logfile=logfile, env=env)
        child.timeout = expect_timeout
        return child
Example 16
Project: python-netsurv   Author: sofia-netsurv   File: skipping.py    License: MIT License 6 votes vote down vote up
def pytest_addoption(parser):
    """Register the ``--runxfail`` flag and the ``xfail_strict`` ini option."""
    general = parser.getgroup("general")
    general.addoption(
        "--runxfail",
        dest="runxfail",
        action="store_true",
        default=False,
        help="report the results of xfail tests as if they were not marked",
    )

    strict_help = (
        "default for the strict parameter of xfail "
        "markers when not given explicitly (default: False)"
    )
    parser.addini("xfail_strict", strict_help, type="bool", default=False)
Example 17
Project: python-netsurv   Author: sofia-netsurv   File: skipping.py    License: MIT License 6 votes vote down vote up
def pytest_runtest_setup(item):
    """Apply skipif/skip marks to *item*, then prepare its xfail evaluation.

    ``item._skipped_by_mark`` records whether a mark caused the skip.  The
    xfail MarkEvaluator is stored on ``item._evalxfail`` before
    ``check_xfail_no_run`` is called.
    """
    item._skipped_by_mark = False

    skipif = MarkEvaluator(item, "skipif")
    if skipif.istrue():
        item._skipped_by_mark = True
        skip(skipif.getexplanation())

    for mark in item.iter_markers(name="skip"):
        item._skipped_by_mark = True
        # Reason priority: reason= kwarg, then first positional arg.
        if "reason" in mark.kwargs:
            why = mark.kwargs["reason"]
        elif mark.args:
            why = mark.args[0]
        else:
            why = "unconditional skip"
        skip(why)

    item._evalxfail = MarkEvaluator(item, "xfail")
    check_xfail_no_run(item)
Example 18
Project: python-netsurv   Author: sofia-netsurv   File: pytester.py    License: MIT License 6 votes vote down vote up
def spawn(self, cmd, expect_timeout=10.0):
        """Run a command using pexpect.

        The pexpect child is returned.

        Skips when pexpect >= 3.0 is unavailable or on 64-bit PyPy, and
        xfails on FreeBSD.  Child output is logged to ``spawn.out`` under
        ``self.tmpdir``.  *expect_timeout* becomes the child's timeout.
        """
        pexpect = pytest.importorskip("pexpect", "3.0")
        if hasattr(sys, "pypy_version_info") and "64" in platform.machine():
            pytest.skip("pypy-64 bit not supported")
        if sys.platform.startswith("freebsd"):
            pytest.xfail("pexpect does not work reliably on freebsd")

        logfile = self.tmpdir.join("spawn.out").open("wb")
        # Register the close immediately, so the file is not leaked when
        # building the environment or pexpect.spawn() raises.
        self.request.addfinalizer(logfile.close)

        # Do not load user config.
        env = os.environ.copy()
        env.update(self._env_run_update)

        child = pexpect.spawn(cmd, logfile=logfile, env=env)
        child.timeout = expect_timeout
        return child
Example 19
Project: python-netsurv   Author: sofia-netsurv   File: skipping.py    License: MIT License 6 votes vote down vote up
def pytest_addoption(parser):
    """Register the ``--runxfail`` flag and the ``xfail_strict`` ini option."""
    general = parser.getgroup("general")
    general.addoption(
        "--runxfail",
        dest="runxfail",
        action="store_true",
        default=False,
        help="report the results of xfail tests as if they were not marked",
    )

    strict_help = (
        "default for the strict parameter of xfail "
        "markers when not given explicitly (default: False)"
    )
    parser.addini("xfail_strict", strict_help, type="bool", default=False)
Example 20
Project: python-netsurv   Author: sofia-netsurv   File: skipping.py    License: MIT License 6 votes vote down vote up
def pytest_runtest_setup(item):
    """Apply skipif/skip marks to *item*, then prepare its xfail evaluation.

    ``item._skipped_by_mark`` records whether a mark caused the skip.  The
    xfail MarkEvaluator is stored on ``item._evalxfail`` before
    ``check_xfail_no_run`` is called.
    """
    item._skipped_by_mark = False

    skipif = MarkEvaluator(item, "skipif")
    if skipif.istrue():
        item._skipped_by_mark = True
        skip(skipif.getexplanation())

    for mark in item.iter_markers(name="skip"):
        item._skipped_by_mark = True
        # Reason priority: reason= kwarg, then first positional arg.
        if "reason" in mark.kwargs:
            why = mark.kwargs["reason"]
        elif mark.args:
            why = mark.args[0]
        else:
            why = "unconditional skip"
        skip(why)

    item._evalxfail = MarkEvaluator(item, "xfail")
    check_xfail_no_run(item)
Example 21
Project: vnpy_crypto   Author: birforce   File: test_coercion.py    License: MIT License 6 votes vote down vote up
def test_setitem_series_bool(self, val, exp_dtype):
        """Check dtype coercion when assigning *val* into a bool Series.

        For int64 / float64 / complex128 targets pandas currently keeps the
        bool dtype (GH12747).  That current behaviour is asserted, then the
        test is reported as an expected failure.  Any other *exp_dtype* is
        checked directly.
        """
        # ``np.bool`` was only an alias for the builtin ``bool``.  It was
        # deprecated in NumPy 1.20 and removed in 1.24 (AttributeError), so
        # the builtin is used instead.
        obj = pd.Series([True, False, True, False])
        assert obj.dtype == bool

        # Identity checks (``is``) are kept on purpose, as in the original.
        known_gaps = (
            (np.int64, "TODO_GH12747 The result must be int"),
            (np.float64, "TODO_GH12747 The result must be float"),
            (np.complex128, "TODO_GH12747 The result must be complex"),
        )
        for gap_dtype, reason in known_gaps:
            if exp_dtype is gap_dtype:
                exp = pd.Series([True, True, True, False])
                self._assert_setitem_series_conversion(obj, val, exp, bool)
                pytest.xfail(reason)

        exp = pd.Series([True, val, True, False])
        self._assert_setitem_series_conversion(obj, val, exp, exp_dtype)
Example 22
Project: vnpy_crypto   Author: birforce   File: test_coercion.py    License: MIT License 6 votes vote down vote up
def test_where_series_datetime64(self, fill_val, exp_dtype):
        """Series.where() on a naive datetime64 Series, filled from *fill_val*.

        Checks a scalar replacement, then a Series of replacements built
        with ``pd.date_range(fill_val, periods=4)``.  For an aware
        *fill_val*, the current UTC coercion is asserted and the test is
        reported as xfail.
        """
        ts = pd.Timestamp
        obj = pd.Series([ts('2011-01-01'), ts('2011-01-02'),
                         ts('2011-01-03'), ts('2011-01-04')])
        assert obj.dtype == 'datetime64[ns]'
        keep = pd.Series([True, False, True, False])

        expected = pd.Series([ts('2011-01-01'), fill_val,
                              ts('2011-01-03'), fill_val])
        self._assert_where_conversion(obj, keep, fill_val, expected, exp_dtype)

        replacements = pd.Series(pd.date_range(fill_val, periods=4))
        if fill_val.tz:
            expected = pd.Series([ts('2011-01-01'), ts('2012-01-02 05:00'),
                                  ts('2011-01-03'), ts('2012-01-04 05:00')])
            self._assert_where_conversion(obj, keep, replacements, expected,
                                          'datetime64[ns]')
            pytest.xfail("ToDo: do not coerce to UTC, must be object")

        expected = pd.Series([ts('2011-01-01'), replacements[1],
                              ts('2011-01-03'), replacements[3]])
        self._assert_where_conversion(obj, keep, replacements, expected,
                                      exp_dtype)
Example 23
Project: vnpy_crypto   Author: birforce   File: test_coercion.py    License: MIT License 6 votes vote down vote up
def test_replace_series_datetime_tz(self):
        """Series.replace() mapping US/Eastern datetimes to timedelta64 values.

        Only the Series form of the replacer is exercised (``how = 'series'``).
        The dict form is kept as an alternative builder.
        """
        how = 'series'
        from_key = 'datetime64[ns, US/Eastern]'
        to_key = 'timedelta64[ns]'

        idx = pd.Index([3, 4], name='xxx')
        obj = pd.Series(self.rep[from_key], index=idx, name='yyy')
        assert obj.dtype == from_key

        make_replacer = {
            'dict': lambda: dict(zip(self.rep[from_key], self.rep[to_key])),
            'series': lambda: pd.Series(self.rep[to_key],
                                        index=self.rep[from_key]),
        }.get(how)
        if make_replacer is None:
            raise ValueError

        result = obj.replace(make_replacer())
        expected = pd.Series(self.rep[to_key], index=idx, name='yyy')
        assert expected.dtype == to_key

        tm.assert_series_equal(result, expected)

    # TODO(jreback) commented out to only have a single xfail printed
Example 24
Project: vnpy_crypto   Author: birforce   File: test_transform.py    License: MIT License 6 votes vote down vote up
def test_transform_numeric_ret(cols, exp, comp_func, agg_func):
    """GH 19200: groupby('b')[cols].transform(agg_func) gives *exp*.

    'size' with a list of columns is an expected failure.  For 'rank' the
    expectation is compared as float.
    """
    if agg_func == 'size' and isinstance(cols, list):
        pytest.xfail("'size' transformation not supported with "
                     "NDFrameGroupy")

    frame = pd.DataFrame({
        'a': pd.date_range('2018-01-01', periods=3),
        'b': range(3),
        'c': range(7, 10),
    })
    result = frame.groupby('b')[cols].transform(agg_func)

    expected = exp.astype('float') if agg_func == 'rank' else exp
    comp_func(result, expected)
Example 25
Project: ngraph-python   Author: NervanaSystems   File: test_linear_layer.py    License: Apache License 2.0 6 votes vote down vote up
def test_linear_ones(input_size, input_placeholder, output_size):
    """Linear layer with all-ones inputs and all-ones weights.

    Every output element must equal ``input_size`` exactly (the sum of that
    output's weights).  This confirms the right number of operations ran.
    """
    inputs = np.ones(input_placeholder.axes.lengths)
    layer = Linear(nout=output_size, init=UniformInit(1.0, 1.0))

    with ExecutorFactory() as factory:
        if factory.transformer.transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        out = layer(input_placeholder)
        run = factory.executor([out, layer.W], input_placeholder)
        actual, _ = run(inputs)

    wanted = np.ones(out.axes.lengths) * input_size
    ng.testing.assert_allclose(wanted, actual, atol=0.0, rtol=0.0)
Example 26
Project: ngraph-python   Author: NervanaSystems   File: test_linear_layer.py    License: Apache License 2.0 6 votes vote down vote up
def test_linear_keep_axes_ones(batch_axis, input_size, input_placeholder, output_size,
                               transformer_factory):
    """Linear layer with ``keep_axes=[]`` and all-ones inputs and weights.

    The expected value, ``input_size * batch_axis.length``, reflects that no
    axis (including the batch axis) is kept.  Exact equality is required.
    """
    inputs = np.ones(input_placeholder.axes.lengths)
    layer = Linear(nout=output_size, keep_axes=[], init=UniformInit(1.0, 1.0))

    with ExecutorFactory() as factory:
        if factory.transformer.transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        out = layer(input_placeholder)
        run = factory.executor([out, layer.W], input_placeholder)
        actual, _ = run(inputs)

    wanted = np.ones(out.axes.lengths) * input_size * batch_axis.length
    assert np.allclose(wanted, actual, atol=0.0, rtol=0.0)
Example 27
Project: ngraph-python   Author: NervanaSystems   File: test_linear_layer.py    License: Apache License 2.0 6 votes vote down vote up
def test_linear_keep_batch_axes_ones(batch_axis, input_size, input_placeholder, output_size,
                                     transformer_factory):
    """Linear layer with ``keep_axes=[batch_axis]``, all-ones inputs and weights.

    Keeping the batch axis, every output element must equal ``input_size``
    exactly.
    """
    inputs = np.ones(input_placeholder.axes.lengths)
    layer = Linear(nout=output_size, keep_axes=[batch_axis],
                   init=UniformInit(1.0, 1.0))

    with ExecutorFactory() as factory:
        if factory.transformer.transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        out = layer(input_placeholder)
        run = factory.executor([out, layer.W], input_placeholder)
        actual, _ = run(inputs)

    wanted = np.ones(out.axes.lengths) * input_size
    assert np.allclose(wanted, actual, atol=0.0, rtol=0.0)
Example 28
Project: ngraph-python   Author: NervanaSystems   File: test_hetr_integration.py    License: Apache License 2.0 6 votes vote down vote up
def test_multi_computations(hetr_device):
    """Two computations created on one hetr transformer both run correctly.

    ``out = y - mean(x ** 2)`` is built with x split over devices '0' and
    '1' along ``ax_A``.  ``x ** 2`` is also run as a separate computation.
    """
    if hetr_device == 'gpu':
        pytest.xfail("enable after gpu exgraph")

    x_axes = ng.make_axes([ax_A, ax_B])
    x = ng.placeholder(axes=x_axes)
    y = ng.placeholder(())
    with ng.metadata(device_id=('0', '1'), parallel=ax_A):
        squared = x ** 2
        out = y - ng.mean(squared, out_axes=())

    x_val = np.random.randint(10, size=x_axes.lengths)
    y_val = np.random.randint(10)
    factory = ngt.make_transformer_factory('hetr', device=hetr_device)
    with closing(factory()) as transformer:
        mean_comp = transformer.computation(out, x, y)
        square_comp = transformer.computation(squared, x)

        got_mean = mean_comp(x_val, y_val)
        got_square = square_comp(x_val)
        np.testing.assert_array_equal(got_mean, y_val - np.mean(x_val ** 2))
        np.testing.assert_array_equal(got_square, x_val ** 2)
Example 29
Project: ngraph-python   Author: NervanaSystems   File: test_hetr_integration.py    License: Apache License 2.0 6 votes vote down vote up
def test_repeat_computation(hetr_device, config):
    """Two identical ``x + 1`` computations on one hetr transformer both work.

    Device ids, axes and the parallel axis come from the *config* mapping.
    """
    if hetr_device == 'gpu':
        pytest.xfail("enable after gpu exgraph")
    device_id = config['device_id']
    axes = config['axes']
    parallel_axis = config['parallel_axis']

    with ng.metadata(device=hetr_device):
        x = ng.placeholder(axes=axes)
        with ng.metadata(device_id=device_id, parallel=parallel_axis):
            incremented = x + 1

        x_val = np.random.randint(100, size=axes.lengths)
        factory = ngt.make_transformer_factory('hetr', device=hetr_device)
        with closing(factory()) as transformer:
            computations = [transformer.computation(incremented, x)
                            for _ in range(2)]
            for computation in computations:
                np.testing.assert_array_equal(computation(x_val), x_val + 1)
Example 30
Project: ngraph-python   Author: NervanaSystems   File: test_hetr_integration.py    License: Apache License 2.0 6 votes vote down vote up
def test_reduce_scalar(hetr_device):
    """
    A scalar is produced by sum() on each worker
    in this case, should be mean reduced before being returned

    x is split along its batch axis over devices '0' and '1'.  Each worker
    sums its half, and the mean of the two partial sums is sum(x) / 2.
    """
    if hetr_device == 'gpu':
        # Implicit concatenation: the previous backslash-newline inside the
        # literal embedded a run of indentation spaces in the reason text.
        pytest.xfail("gather/reduce work-around for gpus does not choose "
                     "between mean or sum, it uses only the value on the "
                     "first device and ignores the values on other devices")

    N = ng.make_axis(length=8, name='batch')
    x = ng.placeholder(axes=[N])
    with ng.metadata(device=hetr_device, device_id=('0', '1'), parallel=N):
        out = ng.sum(x)

    np_x = np.random.randint(100, size=[N.length])
    with closing(ngt.make_transformer_factory('hetr', device=hetr_device)()) as transformer:
        computation = transformer.computation(out, x)
        res = computation(np_x)

        # Mean of the two per-worker sums.
        np.testing.assert_array_equal(res, np.sum(np_x) / 2.)