Python numpy.tril() Examples
The following are 30 code examples of numpy.tril().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module numpy, or try the search function.
Example #1
Source File: test_matfuncs.py From GraphicDesignPatternByPython with MIT License | 6 votes |
def test_al_mohy_higham_2012_experiment_1(self):
    """Check fractional powers of the Al-Mohy/Higham (2012) experiment-1 matrix.

    The square root is computed four ways (``funm`` with ``np.sqrt``,
    ``sqrtm``, the private remainder-power helper and
    ``fractional_matrix_power``) and the results are compared.  Afterwards
    ``A**p`` is raised to ``1/p`` and compared with ``A`` itself.
    """
    # Fractional powers of a tricky upper triangular matrix.
    matrix = _get_al_mohy_higham_2012_experiment_1()

    # Test remainder matrix power.
    sqrt_via_funm, _ = funm(matrix, np.sqrt, disp=False)
    sqrt_via_sqrtm, _ = sqrtm(matrix, disp=False)
    sqrt_via_remainder = _matfuncs_inv_ssq._remainder_matrix_power(matrix, 0.5)
    sqrt_via_power = fractional_matrix_power(matrix, 0.5)
    assert_array_equal(sqrt_via_remainder, sqrt_via_power)
    assert_allclose(sqrt_via_sqrtm, sqrt_via_power)
    assert_allclose(sqrt_via_sqrtm, sqrt_via_funm)

    # Test more fractional powers.
    for exponent in (1/2, 5/3):
        powered = fractional_matrix_power(matrix, exponent)
        restored = fractional_matrix_power(powered, 1/exponent)
        # The round trip is only loosely accurate overall ...
        assert_allclose(restored, matrix, rtol=1e-2)
        # ... but everything up to the first super-diagonal must match
        # at the default tolerance.
        assert_allclose(np.tril(restored, 1), np.tril(matrix, 1))
Example #2
Source File: test_twodim_base.py From recruit with Apache License 2.0 | 6 votes |
def test_tril_triu_ndim3():
    """Verify ``np.tril``/``np.triu`` on a stack of three 2x2 matrices.

    For every float and integer type code the triangular parts are compared
    element-wise with hand-written expectations, and the result dtype must
    equal the input dtype.
    """
    stacked = [
        [[1, 1], [1, 1]],
        [[1, 1], [1, 0]],
        [[1, 1], [0, 0]],
    ]
    lower_expected = [
        [[1, 0], [1, 1]],
        [[1, 0], [1, 0]],
        [[1, 0], [0, 0]],
    ]
    upper_expected = [
        [[1, 1], [0, 1]],
        [[1, 1], [0, 0]],
        [[1, 1], [0, 0]],
    ]

    type_codes = np.typecodes['AllFloat'] + np.typecodes['AllInteger']
    for code in type_codes:
        source = np.array(stacked, dtype=code)
        want_lower = np.array(lower_expected, dtype=code)
        want_upper = np.array(upper_expected, dtype=code)

        got_upper = np.triu(source)
        got_lower = np.tril(source)

        assert_array_equal(got_upper, want_upper)
        assert_array_equal(got_lower, want_lower)
        assert_equal(got_upper.dtype, source.dtype)
        assert_equal(got_lower.dtype, source.dtype)
Example #3
Source File: test_twodim_base.py From recruit with Apache License 2.0 | 6 votes |
def test_tril_triu_dtype():
    """Check that ``np.triu`` and ``np.tril`` keep the dtype of their input.

    Every type code in ``np.typecodes['All']`` except ``'V'`` is exercised,
    followed by a ``datetime64`` array and a structured ``'f4,f4'`` array.
    """
    def check_dtype_preserved(arr):
        # triu first, then tril
        for extract in (np.triu, np.tril):
            assert_equal(extract(arr).dtype, arr.dtype)

    # Issue 4916
    # tril and triu should return the same dtype as input
    for code in np.typecodes['All']:
        if code != 'V':
            check_dtype_preserved(np.zeros((3, 3), dtype=code))

    # check special cases
    check_dtype_preserved(
        np.array([['2001-01-01T12:00', '2002-02-03T13:56'],
                  ['2004-01-01T12:00', '2003-01-03T13:45']],
                 dtype='datetime64'))
    check_dtype_preserved(np.zeros((3, 3), dtype='f4,f4'))
Example #4
Source File: test_linalg_execute.py From mars with Apache License 2.0 | 6 votes |
def testSolveSymPos(self):
    """Solve ``A x = b`` with ``sym_pos=True`` and check the result twice.

    ``A`` is built as ``L.dot(L.T)`` from a lower-triangular integer matrix
    ``L``.  The executed solution is compared with ``scipy.linalg.solve``
    and, separately, ``A.dot(x)`` is compared with the right-hand side.
    """
    import scipy.linalg

    np.random.seed(1)
    # Draw order matters for reproducibility: matrix first, then vector.
    lower = np.tril(np.random.randint(1, 10, (20, 20)))
    matrix = lower.dot(lower.T)
    rhs = np.random.randint(1, 10, (20, ))

    A = tensor(matrix, chunk_size=5)
    b = tensor(rhs, chunk_size=5)
    x = solve(A, b, sym_pos=True)

    solved = self.executor.execute_tensor(x, concat=True)[0]
    np.testing.assert_allclose(solved, scipy.linalg.solve(matrix, rhs))

    product = self.executor.execute_tensor(A.dot(x), concat=True)[0]
    np.testing.assert_allclose(product, rhs)
Example #5
Source File: bottom_up.py From Dispersion-based-Clustering with MIT License | 6 votes |
def select_merge_data(self, u_feas, label, label_to_images, ratio_n, dists):
    """Rank all sample pairs by penalised distance, smallest first.

    ``dists`` is modified in place before ranking:

    1. 100000 is added to the lower triangle (diagonal included) so only
       pairs ``(i, j)`` with ``i < j`` stay competitive;
    2. ``ratio_n * (|A| + |B|)`` is added, where ``|A|``/``|B|`` are the
       member counts of the clusters of the two samples;
    3. upper-triangle pairs that share a label are set to exactly 100000.

    Args:
        u_feas: sequence whose length gives the number of samples ``n``.
        label: per-sample cluster label, indexable by sample position.
        label_to_images: mapping from label to its members; only the
            length of each entry is used.
        ratio_n: weight of the cluster-size penalty.
        dists: ``(n, n)`` torch tensor; must live on the CPU because it is
            converted with ``.numpy()``.  Presumably pairwise distances --
            confirm against the caller.

    Returns:
        ``(idx1, idx2)``: NumPy index arrays such that
        ``dists[idx1[k], idx2[k]]`` is non-decreasing in ``k``.
    """
    n = len(u_feas)

    # blocking the triangle
    dists.add_(torch.tril(100000 * torch.ones(n, n)))

    cnt = torch.FloatTensor(
        [len(label_to_images[label[idx]]) for idx in range(n)])
    dists += ratio_n * (cnt.view(1, n) + cnt.view(n, 1))  # dist += |A|+|B|

    # Set the distance within the same cluster.  One boolean mask replaces
    # the former O(n^2) Python loop of single-element tensor writes.
    labels = np.asarray(label)[:n]
    same_cluster = np.triu(labels[:, None] == labels[None, :], k=1)
    dists[torch.from_numpy(same_cluster)] = 100000

    dists = dists.numpy()
    # with axis=None all numbers are sorted and unravel_index transforms the
    # sorted flat index into one index array per dimension
    idx1, idx2 = np.unravel_index(np.argsort(dists, axis=None), dists.shape)
    return idx1, idx2
Example #6
Source File: bottom_up.py From Dispersion-based-Clustering with MIT License | 6 votes |
def select_merge_data_v2(self, u_feas, labels, linkages):
    """Rank all sample pairs by penalised linkage value, smallest first.

    ``linkages`` is modified in place: 100000 is added to its lower
    triangle (diagonal included), and every upper-triangle pair whose two
    samples share a label is set to exactly 100000.

    Args:
        u_feas: sequence whose length gives the number of samples ``n``.
        labels: per-sample cluster label, indexable by sample position.
        linkages: ``(n, n)`` float NumPy array (assumed; integer arrays
            cannot take the in-place float addition).

    Returns:
        ``(idx1, idx2)``: index arrays such that
        ``linkages[idx1[k], idx2[k]]`` is non-decreasing in ``k``.
    """
    n = len(u_feas)

    linkages += np.tril(100000 * np.ones((n, n)))  # blocking the triangle
    print('Linkage adding')

    # Set the distance within the same cluster.  One boolean mask replaces
    # the former O(n^2) Python loop of single-element writes.
    label_arr = np.asarray(labels)[:n]
    same_cluster = np.triu(label_arr[:, None] == label_arr[None, :], k=1)
    linkages[same_cluster] = 100000

    # with axis=None all numbers are sorted and unravel_index transforms the
    # sorted flat index into one index array per dimension
    idx1, idx2 = np.unravel_index(np.argsort(linkages, axis=None),
                                  linkages.shape)
    print('Linkage add finished')
    return idx1, idx2
Example #7
Source File: bottom_up.py From Dispersion-based-Clustering with MIT License | 6 votes |
def select_merge_data(self, u_feas, label, label_to_images, ratio_n, dists):
    """Return the indices of all sample pairs ordered by penalised distance.

    ``dists`` is updated in place in three steps: 100000 is added to the
    lower triangle (diagonal included); ``ratio_n`` times the sum of the
    two samples' cluster sizes is added everywhere; finally upper-triangle
    pairs with equal labels are overwritten with 100000.

    Args:
        u_feas: sequence whose length gives the number of samples ``n``.
        label: per-sample cluster label, indexable by sample position.
        label_to_images: mapping from label to its members; only the
            length of each entry is used.
        ratio_n: weight of the cluster-size penalty.
        dists: ``(n, n)`` CPU torch tensor (``.numpy()`` is called on it).

    Returns:
        ``(idx1, idx2)``: NumPy arrays of row and column indices, sorted so
        that the corresponding entries of ``dists`` are non-decreasing.
    """
    n = len(u_feas)

    # Push the lower triangle out of reach so each pair is ranked once.
    dists.add_(torch.tril(100000 * torch.ones(n, n)))

    sizes = torch.FloatTensor(
        [len(label_to_images[label[idx]]) for idx in range(n)])
    dists += ratio_n * (sizes.view(1, n) + sizes.view(n, 1))

    # Vectorised replacement for the nested Python loop: mark every pair
    # (i, j), i < j, whose labels are equal.
    labels = np.asarray(label)[:n]
    same_cluster = np.triu(labels[:, None] == labels[None, :], k=1)
    dists[torch.from_numpy(same_cluster)] = 100000

    dists = dists.numpy()
    idx1, idx2 = np.unravel_index(np.argsort(dists, axis=None), dists.shape)
    return idx1, idx2
Example #8
Source File: test_twodim_base.py From lambda-packs with MIT License | 6 votes |
def test_tril_triu_ndim3():
    """Verify ``np.tril``/``np.triu`` on a stack of three 2x2 matrices.

    For every float and integer type code the triangular parts are compared
    element-wise with hand-written expectations, and the result dtype must
    equal the input dtype.

    The checks are plain assertions rather than nose-style ``yield``
    tuples: generator tests are not executed by pytest >= 4, so the yielded
    assertions would never run there.
    """
    for dtype in np.typecodes['AllFloat'] + np.typecodes['AllInteger']:
        a = np.array([
            [[1, 1], [1, 1]],
            [[1, 1], [1, 0]],
            [[1, 1], [0, 0]],
            ], dtype=dtype)
        a_tril_desired = np.array([
            [[1, 0], [1, 1]],
            [[1, 0], [1, 0]],
            [[1, 0], [0, 0]],
            ], dtype=dtype)
        a_triu_desired = np.array([
            [[1, 1], [0, 1]],
            [[1, 1], [0, 0]],
            [[1, 1], [0, 0]],
            ], dtype=dtype)
        a_triu_observed = np.triu(a)
        a_tril_observed = np.tril(a)
        assert_array_equal(a_triu_observed, a_triu_desired)
        assert_array_equal(a_tril_observed, a_tril_desired)
        assert_equal(a_triu_observed.dtype, a.dtype)
        assert_equal(a_tril_observed.dtype, a.dtype)
Example #9
Source File: test_twodim_base.py From lambda-packs with MIT License | 6 votes |
def test_tril_triu_dtype():
    """Check that ``np.triu`` and ``np.tril`` keep the dtype of their input.

    Every type code in ``np.typecodes['All']`` except ``'V'`` is exercised,
    followed by a ``datetime64`` array and a structured ``'f4,f4'`` array.
    """
    def check_dtype_preserved(arr):
        # triu first, then tril
        for extract in (np.triu, np.tril):
            assert_equal(extract(arr).dtype, arr.dtype)

    # Issue 4916
    # tril and triu should return the same dtype as input
    for code in np.typecodes['All']:
        if code != 'V':
            check_dtype_preserved(np.zeros((3, 3), dtype=code))

    # check special cases
    check_dtype_preserved(
        np.array([['2001-01-01T12:00', '2002-02-03T13:56'],
                  ['2004-01-01T12:00', '2003-01-03T13:45']],
                 dtype='datetime64'))
    check_dtype_preserved(np.zeros((3, 3), dtype='f4,f4'))
Example #10
Source File: test_twodim_base.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def test_tril_triu_ndim3():
    """Check ``np.triu``/``np.tril`` values and dtypes on a 3-D input.

    A stack of three 2x2 matrices is converted to every float and integer
    type code; the upper and lower triangles must equal the expected
    stacks and carry the input's dtype.

    Direct assertions are used instead of ``yield``-ing check tuples,
    because pytest >= 4 no longer runs generator-style tests.
    """
    stacked = [
        [[1, 1], [1, 1]],
        [[1, 1], [1, 0]],
        [[1, 1], [0, 0]],
    ]
    lower_expected = [
        [[1, 0], [1, 1]],
        [[1, 0], [1, 0]],
        [[1, 0], [0, 0]],
    ]
    upper_expected = [
        [[1, 1], [0, 1]],
        [[1, 1], [0, 0]],
        [[1, 1], [0, 0]],
    ]

    for dtype in np.typecodes['AllFloat'] + np.typecodes['AllInteger']:
        a = np.array(stacked, dtype=dtype)
        a_triu_observed = np.triu(a)
        a_tril_observed = np.tril(a)
        assert_array_equal(a_triu_observed,
                           np.array(upper_expected, dtype=dtype))
        assert_array_equal(a_tril_observed,
                           np.array(lower_expected, dtype=dtype))
        assert_equal(a_triu_observed.dtype, a.dtype)
        assert_equal(a_tril_observed.dtype, a.dtype)
Example #11
Source File: test_twodim_base.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def test_tril_triu_dtype():
    """Assert that triangular extraction never changes the array dtype.

    Covers all type codes of ``np.typecodes['All']`` apart from ``'V'``,
    then a ``datetime64`` array and a structured ``'f4,f4'`` array.
    """
    def same_dtype_after(arr):
        upper_dtype = np.triu(arr).dtype
        assert_equal(upper_dtype, arr.dtype)
        lower_dtype = np.tril(arr).dtype
        assert_equal(lower_dtype, arr.dtype)

    # Issue 4916
    # tril and triu should return the same dtype as input
    usable_codes = [c for c in np.typecodes['All'] if c != 'V']
    for c in usable_codes:
        same_dtype_after(np.zeros((3, 3), dtype=c))

    # check special cases
    timestamps = [['2001-01-01T12:00', '2002-02-03T13:56'],
                  ['2004-01-01T12:00', '2003-01-03T13:45']]
    same_dtype_after(np.array(timestamps, dtype='datetime64'))
    same_dtype_after(np.zeros((3, 3), dtype='f4,f4'))
Example #12
Source File: test_twodim_base.py From vnpy_crypto with MIT License | 6 votes |
def test_tril_triu_ndim3():
    """Verify ``np.tril``/``np.triu`` on a stack of three 2x2 matrices.

    Runs once per float and integer type code, comparing both triangular
    parts with hand-written expectations and checking that the dtype of
    each result equals the dtype of the input.

    Uses plain assertions: the previous ``yield assert_..., args`` form is
    a nose-style generator test, which pytest >= 4 does not execute.
    """
    for dtype in np.typecodes['AllFloat'] + np.typecodes['AllInteger']:
        a = np.array([
            [[1, 1], [1, 1]],
            [[1, 1], [1, 0]],
            [[1, 1], [0, 0]],
            ], dtype=dtype)
        a_tril_desired = np.array([
            [[1, 0], [1, 1]],
            [[1, 0], [1, 0]],
            [[1, 0], [0, 0]],
            ], dtype=dtype)
        a_triu_desired = np.array([
            [[1, 1], [0, 1]],
            [[1, 1], [0, 0]],
            [[1, 1], [0, 0]],
            ], dtype=dtype)
        a_triu_observed = np.triu(a)
        a_tril_observed = np.tril(a)
        assert_array_equal(a_triu_observed, a_triu_desired)
        assert_array_equal(a_tril_observed, a_tril_desired)
        assert_equal(a_triu_observed.dtype, a.dtype)
        assert_equal(a_tril_observed.dtype, a.dtype)
Example #13
Source File: test_twodim_base.py From vnpy_crypto with MIT License | 6 votes |
def test_tril_triu_dtype():
    """Check that ``np.triu`` and ``np.tril`` keep the dtype of their input.

    Every type code in ``np.typecodes['All']`` except ``'V'`` is exercised,
    followed by a ``datetime64`` array and a structured ``'f4,f4'`` array.
    """
    def check_dtype_preserved(arr):
        # triu first, then tril
        for extract in (np.triu, np.tril):
            assert_equal(extract(arr).dtype, arr.dtype)

    # Issue 4916
    # tril and triu should return the same dtype as input
    for code in np.typecodes['All']:
        if code != 'V':
            check_dtype_preserved(np.zeros((3, 3), dtype=code))

    # check special cases
    check_dtype_preserved(
        np.array([['2001-01-01T12:00', '2002-02-03T13:56'],
                  ['2004-01-01T12:00', '2003-01-03T13:45']],
                 dtype='datetime64'))
    check_dtype_preserved(np.zeros((3, 3), dtype='f4,f4'))
Example #14
Source File: _binary_codes.py From OpenFermion with Apache License 2.0 | 6 votes |
def parity_code(n_modes):
    """ The parity transform (arXiv:1208.5986) as binary code. This code is
    very similar to the Jordan-Wigner transform, but with long update strings
    instead of parity strings. It does not save qubits: n_qubits = n_modes.

    Args:
        n_modes (int): number of modes

    Returns (BinaryCode): The parity transform BinaryCode
    """
    # Decoder: ones on the main diagonal and on the first sub-diagonal.
    # For n_modes >= 2 this is the same matrix the previous flat-list
    # construction produced; unlike that construction it is also valid for
    # n_modes == 1, where the list had 3 entries and the reshape to (1, 1)
    # raised ValueError.
    dec_mtx = (numpy.eye(n_modes, dtype=int) +
               numpy.eye(n_modes, k=-1, dtype=int))
    # Encoder: lower-triangular matrix of ones (diagonal included).
    enc_mtx = numpy.tril(numpy.ones((n_modes, n_modes), dtype=int))

    return BinaryCode(enc_mtx, linearize_decoder(dec_mtx))
Example #15
Source File: test_blas.py From Computable with MIT License | 6 votes |
def test_syrk(self):
    """Compare each ``syrk`` routine with the fixture's reference products.

    Four call variants are checked: default (upper triangle), ``lower=1``,
    accumulation into ``c`` with ``beta=1.`` and ``trans=1``.  ``self.a``,
    ``self.t`` and ``self.tt`` are prepared outside this block.
    """
    for syrk in _get_func('syrk'):
        upper = syrk(a=self.a, alpha=1.)
        assert_array_almost_equal(np.triu(upper), np.triu(self.t))

        lower = syrk(a=self.a, alpha=1., lower=1)
        assert_array_almost_equal(np.tril(lower), np.tril(self.t))

        ones = np.ones(self.t.shape)
        accumulated = syrk(a=self.a, alpha=1., beta=1., c=ones)
        assert_array_almost_equal(np.triu(accumulated), np.triu(self.t+ones))

        transposed = syrk(a=self.a, alpha=1., trans=1)
        assert_array_almost_equal(np.triu(transposed), np.triu(self.tt))

#prints '0-th dimension must be fixed to 3 but got 5', FIXME: suppress?
# FIXME: how to catch the _fblas.error?
Example #16
Source File: test_blas.py From Computable with MIT License | 6 votes |
def test_syr2k(self):
    """Compare each ``syr2k`` routine with the fixture's reference products.

    Four call variants are checked: default (upper triangle), ``lower=1``,
    accumulation into ``c`` with ``beta=1.`` and ``trans=1``.  ``self.a``,
    ``self.b``, ``self.t`` and ``self.tt`` are prepared outside this block.
    """
    for syr2k in _get_func('syr2k'):
        upper = syr2k(a=self.a, b=self.b, alpha=1.)
        assert_array_almost_equal(np.triu(upper), np.triu(self.t))

        lower = syr2k(a=self.a, b=self.b, alpha=1., lower=1)
        assert_array_almost_equal(np.tril(lower), np.tril(self.t))

        ones = np.ones(self.t.shape)
        accumulated = syr2k(a=self.a, b=self.b, alpha=1., beta=1., c=ones)
        assert_array_almost_equal(np.triu(accumulated), np.triu(self.t+ones))

        transposed = syr2k(a=self.a, b=self.b, alpha=1., trans=1)
        assert_array_almost_equal(np.triu(transposed), np.triu(self.tt))

#prints '0-th dimension must be fixed to 3 but got 5', FIXME: suppress?
Example #17
Source File: test_matfuncs.py From Computable with MIT License | 6 votes |
def test_al_mohy_higham_2012_experiment_1(self):
    """Exercise fractional matrix powers on the experiment-1 test matrix.

    First the power 0.5 obtained from ``funm``, ``sqrtm``, the private
    remainder-power routine and ``fractional_matrix_power`` is
    cross-checked; then powers ``p`` followed by ``1/p`` must reproduce
    the input.
    """
    # Fractional powers of a tricky upper triangular matrix.
    A = _get_al_mohy_higham_2012_experiment_1()

    # Test remainder matrix power.
    root_funm, _ = funm(A, np.sqrt, disp=False)
    root_sqrtm, _ = sqrtm(A, disp=False)
    root_remainder = _matfuncs_inv_ssq._remainder_matrix_power(A, 0.5)
    root_fractional = fractional_matrix_power(A, 0.5)
    assert_array_equal(root_remainder, root_fractional)
    assert_allclose(root_sqrtm, root_fractional)
    assert_allclose(root_sqrtm, root_funm)

    # Test more fractional powers.
    for p in (1/2, 5/3):
        forward = fractional_matrix_power(A, p)
        back = fractional_matrix_power(forward, 1/p)
        assert_allclose(back, A, rtol=1e-2)
        assert_allclose(np.tril(back, 1), np.tril(A, 1))
Example #18
Source File: test_serialization_fileview.py From dimod with Apache License 2.0 | 6 votes |
def test_readinto_partial_sliding17(self, name, BQM, version):
    """Read 17-byte windows at every offset and compare with ``readall``.

    A 5-variable binary model is built from the lower triangle of
    ``arange(25).reshape((5, 5))`` with offset -6.  For each position below
    ``fv.quadratic_end`` the view is seeked there, up to 17 bytes are read
    into a scratch buffer, and the bytes read must equal the matching
    slice of the complete serialization.
    """
    WINDOW = 17

    coefficients = np.tril(np.arange(25).reshape((5, 5)))
    bqm = BQM(coefficients, 'BINARY')
    bqm.offset = -6

    fv = FileView(bqm, version=version)
    reference = fv.readall()

    for start in range(fv.quadratic_end):
        self.assertEqual(start, fv.seek(start))

        window = bytearray(WINDOW)
        count = fv.readinto(window)
        self.assertGreater(count, 0)
        self.assertEqual(window[:count], reference[start:start+count])

# Ocean only supports 64bit python
Example #19
Source File: test_twodim_base.py From Mastering-Elasticsearch-7.0 with MIT License | 6 votes |
def test_tril_triu_ndim3():
    """Check ``np.triu``/``np.tril`` values and dtypes on a 3-D input.

    A stack of three 2x2 matrices is converted to every float and integer
    type code; the upper and lower triangles must equal the expected
    stacks and carry the input's dtype.
    """
    cases = {
        'input': [[[1, 1], [1, 1]],
                  [[1, 1], [1, 0]],
                  [[1, 1], [0, 0]]],
        'tril': [[[1, 0], [1, 1]],
                 [[1, 0], [1, 0]],
                 [[1, 0], [0, 0]]],
        'triu': [[[1, 1], [0, 1]],
                 [[1, 1], [0, 0]],
                 [[1, 1], [0, 0]]],
    }

    for code in np.typecodes['AllFloat'] + np.typecodes['AllInteger']:
        given, want_tril, want_triu = (
            np.array(cases[key], dtype=code)
            for key in ('input', 'tril', 'triu'))

        got_triu = np.triu(given)
        got_tril = np.tril(given)

        assert_array_equal(got_triu, want_triu)
        assert_array_equal(got_tril, want_tril)
        assert_equal(got_triu.dtype, given.dtype)
        assert_equal(got_tril.dtype, given.dtype)
Example #20
Source File: test_twodim_base.py From Mastering-Elasticsearch-7.0 with MIT License | 6 votes |
def test_tril_triu_dtype():
    """Assert that triangular extraction never changes the array dtype.

    Covers all type codes of ``np.typecodes['All']`` apart from ``'V'``,
    then a ``datetime64`` array and a structured ``'f4,f4'`` array.
    """
    def same_dtype_after(arr):
        upper_dtype = np.triu(arr).dtype
        assert_equal(upper_dtype, arr.dtype)
        lower_dtype = np.tril(arr).dtype
        assert_equal(lower_dtype, arr.dtype)

    # Issue 4916
    # tril and triu should return the same dtype as input
    usable_codes = [c for c in np.typecodes['All'] if c != 'V']
    for c in usable_codes:
        same_dtype_after(np.zeros((3, 3), dtype=c))

    # check special cases
    timestamps = [['2001-01-01T12:00', '2002-02-03T13:56'],
                  ['2004-01-01T12:00', '2003-01-03T13:45']]
    same_dtype_after(np.array(timestamps, dtype='datetime64'))
    same_dtype_after(np.zeros((3, 3), dtype='f4,f4'))
Example #21
Source File: array_ops.py From trax with Apache License 2.0 | 6 votes |
def tril(m, k=0):  # pylint: disable=missing-docstring
  """Keep the entries of `m` selected by a `tri` mask; zero the rest.

  The mask is `tri(rows, cols, k=k, dtype=bool)` over the last two
  dimensions of `m`, broadcast to the full shape of `m`.

  Args:
    m: array-like input, converted with `asarray`.
    k: diagonal offset forwarded to `tri`.

  Returns:
    The masked tensor wrapped by `utils.tensor_to_ndarray`.

  Raises:
    ValueError: if `m` has rank below 2, or if either of its last two
      dimensions is statically unknown.
  """
  tensor = asarray(m).data
  dims = tensor.shape.as_list()

  if len(dims) < 2:
    raise ValueError('Argument to tril must have rank at least 2')

  rows, cols = dims[-2:]
  if cols is None or rows is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  zero = tf.constant(0, tensor.dtype)
  keep = tri(rows, cols, k=k, dtype=bool)
  masked = tf.where(tf.broadcast_to(keep, tf.shape(tensor)), tensor, zero)
  return utils.tensor_to_ndarray(masked)
Example #22
Source File: attention.py From trax with Apache License 2.0 | 6 votes |
def forward(self, inputs):
  """Apply dot-product attention with a lower-triangular mask.

  In 'predict' mode the layer state is refreshed through
  `_fast_inference_update_state`, and keys, values and mask are taken from
  that state.  Otherwise a boolean lower-triangular mask of shape
  `(1, n, n)` is built, with `n = q.shape[-2]`.

  Args:
    inputs: a `(q, k, v)` triple.

  Returns:
    The result of `DotProductAttention` for the (possibly replaced) keys,
    values and mask.
  """
  q, k, v = inputs

  if self._mode == 'predict':
    self.state = _fast_inference_update_state(inputs, self.state)
    k, v, mask, _ = self.state
  else:
    mask_size = q.shape[-2]
    # Not all backends define jnp.tril. However, using np.tril is inefficient
    # in that it creates a large global constant. TODO(kitaev): try to find an
    # alternative that works across all backends.
    xp = jnp if fastmath.backend_name() == 'jax' else np
    all_true = xp.ones((1, mask_size, mask_size), dtype=np.bool_)
    mask = xp.tril(all_true, k=0)

  return DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)
Example #23
Source File: test_twodim_base.py From GraphicDesignPatternByPython with MIT License | 6 votes |
def test_tril_triu_ndim3():
    """Verify ``np.tril``/``np.triu`` on a stack of three 2x2 matrices.

    For every float and integer type code the triangular parts are compared
    element-wise with hand-written expectations, and the result dtype must
    equal the input dtype.
    """
    stacked = [
        [[1, 1], [1, 1]],
        [[1, 1], [1, 0]],
        [[1, 1], [0, 0]],
    ]
    lower_expected = [
        [[1, 0], [1, 1]],
        [[1, 0], [1, 0]],
        [[1, 0], [0, 0]],
    ]
    upper_expected = [
        [[1, 1], [0, 1]],
        [[1, 1], [0, 0]],
        [[1, 1], [0, 0]],
    ]

    type_codes = np.typecodes['AllFloat'] + np.typecodes['AllInteger']
    for code in type_codes:
        source = np.array(stacked, dtype=code)
        want_lower = np.array(lower_expected, dtype=code)
        want_upper = np.array(upper_expected, dtype=code)

        got_upper = np.triu(source)
        got_lower = np.tril(source)

        assert_array_equal(got_upper, want_upper)
        assert_array_equal(got_lower, want_lower)
        assert_equal(got_upper.dtype, source.dtype)
        assert_equal(got_lower.dtype, source.dtype)
Example #24
Source File: test_twodim_base.py From GraphicDesignPatternByPython with MIT License | 6 votes |
def test_tril_triu_dtype():
    """Check that ``np.triu`` and ``np.tril`` keep the dtype of their input.

    Every type code in ``np.typecodes['All']`` except ``'V'`` is exercised,
    followed by a ``datetime64`` array and a structured ``'f4,f4'`` array.
    """
    def check_dtype_preserved(arr):
        # triu first, then tril
        for extract in (np.triu, np.tril):
            assert_equal(extract(arr).dtype, arr.dtype)

    # Issue 4916
    # tril and triu should return the same dtype as input
    for code in np.typecodes['All']:
        if code != 'V':
            check_dtype_preserved(np.zeros((3, 3), dtype=code))

    # check special cases
    check_dtype_preserved(
        np.array([['2001-01-01T12:00', '2002-02-03T13:56'],
                  ['2004-01-01T12:00', '2003-01-03T13:45']],
                 dtype='datetime64'))
    check_dtype_preserved(np.zeros((3, 3), dtype='f4,f4'))
Example #25
Source File: test_blas.py From GraphicDesignPatternByPython with MIT License | 6 votes |
def test_syrk(self):
    """Run every ``syrk`` variant against the expected triangular results.

    Covers the default call, ``lower=1``, ``beta=1.`` with an all-ones
    ``c`` and ``trans=1``; only the relevant triangle of each result is
    compared.  ``self.a``, ``self.t`` and ``self.tt`` come from the test
    fixture, which is outside this block.
    """
    for routine in _get_func('syrk'):
        result = routine(a=self.a, alpha=1.)
        assert_array_almost_equal(np.triu(result), np.triu(self.t))

        result = routine(a=self.a, alpha=1., lower=1)
        assert_array_almost_equal(np.tril(result), np.tril(self.t))

        initial_c = np.ones(self.t.shape)
        result = routine(a=self.a, alpha=1., beta=1., c=initial_c)
        expected = np.triu(self.t+initial_c)
        assert_array_almost_equal(np.triu(result), expected)

        result = routine(a=self.a, alpha=1., trans=1)
        assert_array_almost_equal(np.triu(result), np.triu(self.tt))

# prints '0-th dimension must be fixed to 3 but got 5',
# FIXME: suppress?
# FIXME: how to catch the _fblas.error?
Example #26
Source File: test_blas.py From GraphicDesignPatternByPython with MIT License | 6 votes |
def test_syr2k(self):
    """Run every ``syr2k`` variant against the expected triangular results.

    Covers the default call, ``lower=1``, ``beta=1.`` with an all-ones
    ``c`` and ``trans=1``; only the relevant triangle of each result is
    compared.  ``self.a``, ``self.b``, ``self.t`` and ``self.tt`` come from
    the test fixture, which is outside this block.
    """
    for routine in _get_func('syr2k'):
        result = routine(a=self.a, b=self.b, alpha=1.)
        assert_array_almost_equal(np.triu(result), np.triu(self.t))

        result = routine(a=self.a, b=self.b, alpha=1., lower=1)
        assert_array_almost_equal(np.tril(result), np.tril(self.t))

        initial_c = np.ones(self.t.shape)
        result = routine(a=self.a, b=self.b, alpha=1., beta=1., c=initial_c)
        expected = np.triu(self.t+initial_c)
        assert_array_almost_equal(np.triu(result), expected)

        result = routine(a=self.a, b=self.b, alpha=1., trans=1)
        assert_array_almost_equal(np.triu(result), np.triu(self.tt))

# prints '0-th dimension must be fixed to 3 but got 5', FIXME: suppress?
Example #27
Source File: test_pardiso.py From simnibs with GNU General Public License v3.0 | 6 votes |
def create_matrix(dim, alpha=0.95, smallest_coef=0.1, largest_coef=.9):
    """Build a random sparse system ``A x = b`` with ``A = chol.T @ chol``.

    Based on scikit-learn's make_sparse_spd_matrix.  Random numbers are
    drawn from NumPy's global ``np.random`` state.

    Args:
        dim: size of the square matrix and length of the vectors.
        alpha: uniform draws below this threshold become zero entries of
            the factor.
        smallest_coef: lower bound of the surviving off-diagonal entries.
        largest_coef: upper bound of the surviving off-diagonal entries.

    Returns:
        ``(A, b, x)``: ``A`` as a ``csc_matrix``, the right-hand side
        ``b = A.dot(x)`` and the random vector ``x``.
    """
    off_diag = np.random.rand(dim, dim)
    off_diag[off_diag < alpha] = 0
    survivors = off_diag > alpha
    coef_range = largest_coef - smallest_coef
    off_diag[survivors] = (smallest_coef
                           + coef_range * np.random.rand(np.sum(survivors)))
    off_diag = np.tril(off_diag, k=-1)

    # Permute the lines: we don't want to have asymmetries in the final
    # SPD matrix
    order = np.random.permutation(dim)
    off_diag = off_diag[order].T[order]

    # Factor with -1 on the diagonal plus the permuted strictly-triangular
    # part.
    factor = -np.eye(dim)
    factor += off_diag
    A = sp.csc_matrix(np.dot(factor.T, factor))

    x = np.random.rand(dim)
    b = A.dot(x)

    return A, b, x
Example #28
Source File: model.py From qrn with MIT License | 5 votes |
def __init__(self, batch_size, mem_size, hidden_size):
    """Store the size settings and precompute two triangular masks.

    Args:
        batch_size: stored as ``self.batch_size``.
        mem_size: stored as ``self.mem_size``; also the side length ``M``
            of both masks.
        hidden_size: stored as ``self.hidden_size``.
    """
    self.hidden_size = hidden_size
    self.mem_size = mem_size
    self.batch_size = batch_size

    all_ones = np.ones((mem_size, mem_size), dtype='float32')
    # L: lower triangle including the diagonal.
    self.L = np.tril(all_ones)
    # sL: strictly lower triangle (diagonal excluded).
    self.sL = np.tril(all_ones, k=-1)
Example #29
Source File: model.py From qrn with MIT License | 5 votes |
def __init__(self, batch_size, mem_size, hidden_size):
    """Record the three size parameters and build the masks ``L`` and ``sL``.

    ``L`` is an ``M x M`` float32 lower-triangular matrix of ones
    (diagonal included); ``sL`` is the same without the diagonal.  ``M`` is
    ``mem_size``.
    """
    self.batch_size = batch_size
    self.mem_size = mem_size
    self.hidden_size = hidden_size

    side = mem_size

    def lower_ones(offset):
        # Lower-triangular float32 ones up to the given diagonal offset.
        return np.tril(np.ones([side, side], dtype='float32'), k=offset)

    self.L = lower_ones(0)
    self.sL = lower_ones(-1)
Example #30
Source File: cmp_utils.py From yolo_v2 with Apache License 2.0 | 5 votes |
def get_visual_frustum(map_size, shape_like, expand_dims=(0, 0)):
  """Build a constant 0/1 mask and broadcast it against `shape_like`.

  The mask is 1 where a cell lies both in the lower triangle of a
  `map_size` grid (diagonal included) and in the left-right mirror of that
  triangle, and 0 elsewhere.

  Args:
    map_size: shape passed to `np.ones` for the 2-D grid.
    shape_like: tensor whose shape and a float32 `ones_like` are used to
      broadcast the mask.
    expand_dims: iterable of axes; `np.expand_dims` is applied once per
      entry, in order.  The default is an immutable tuple rather than a
      shared list.

  Returns:
    A float32 tensor: `tf.ones_like(shape_like) * mask`.
  """
  with tf.name_scope('visual_frustum'):
    lower = np.tril(np.ones(map_size))
    # A cell scores 2 only where the triangle and its mirror overlap.
    overlap = lower + lower[:, ::-1]
    frustum = (overlap == 2).astype(np.float32)
    for axis in expand_dims:
      frustum = np.expand_dims(frustum, axis=axis)
    confs_probs = tf.constant(frustum, dtype=tf.float32)
    confs_probs = tf.ones_like(shape_like, dtype=tf.float32) * confs_probs
  return confs_probs