Python numpy.argwhere() Examples

The following are 30 code examples of numpy.argwhere(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module numpy, or try the search function.
Example #1
Source File: tcpr.py    From libTLDA with MIT License 7 votes vote down vote up
def add_intercept(self, X):
        """Ensure the data ends with exactly one intercept (all-ones) column.

        Any existing all-ones columns are removed and a single column of ones
        is appended as the last feature.

        Parameters
        ----------
        X : array_like, shape (N, D)
            Data matrix.

        Returns
        -------
        X : ndarray, shape (N, D')
            Data with a trailing column of ones.
        D' : int
            Number of columns of the returned matrix (D + 1 only when no
            existing intercept column had to be dropped).
        """
        X = np.asarray(X)
        # Data shape
        N, D = X.shape

        # Columns that already act as an intercept, i.e. consist only of ones.
        # (A `sum == N` test would also match columns such as [2, 0].)
        # With no rows every column would trivially qualify, so skip then.
        if N > 0:
            intercept_cols = np.all(X == 1, axis=0)
        else:
            intercept_cols = np.zeros(D, dtype=bool)

        if np.any(intercept_cols):

            # Report
            print('Intercept is not the last feature. Swapping..')

            # Drop the existing intercept column(s); a fresh one is appended below
            X = X[:, ~intercept_cols]

        # Add intercept as last column
        X = np.hstack((X, np.ones((N, 1))))

        # Report the real dimensionality of the returned matrix
        return X, X.shape[1]
Example #2
Source File: chacon_tarazona.py    From pytim with GNU General Public License v3.0 6 votes vote down vote up
def _points_next_to_surface(self, surf, modes, pivot):
        """ Searches for points within a distance self.tau from the
            interface.
        """
        pivot_pos = self.cluster_group[pivot].positions
        z_max = np.max(pivot_pos[:, 2])
        z_min = np.min(pivot_pos[:, 2])
        z_max += self.alpha * 2
        z_min -= self.alpha * 2
        positions = self.cluster_group.positions[:]
        # TODO other directions
        z = positions[:, 2]
        condition = np.logical_and(z > z_min, z < z_max)
        candidates = np.argwhere(condition)[:, 0]
        dists = surf.surface_from_modes(positions[candidates], modes)
        dists = dists - z[candidates]
        return candidates[dists * dists < self.tau**2] 
Example #3
Source File: finetune.py    From PSMNet with MIT License 6 votes vote down vote up
def test(imgL,imgR,disp_true):
        """Evaluate the global ``model`` on one stereo pair and return the 3-px error.

        Only pixels with a positive ground-truth disparity are evaluated. A
        pixel counts as correct when its absolute disparity error is below
        3 px or below 5 % of the ground-truth disparity.

        `disp_true` is not modified.

        Returns
        -------
        float
            Fraction of evaluated pixels that are *not* correct.

        Raises
        ------
        ZeroDivisionError
            If no pixel has a positive ground-truth disparity.
        """
        model.eval()
        imgL = Variable(torch.FloatTensor(imgL))
        imgR = Variable(torch.FloatTensor(imgR))
        if args.cuda:
            imgL, imgR = imgL.cuda(), imgR.cuda()

        with torch.no_grad():
            output3 = model(imgL, imgR)

        pred_disp = output3.data.cpu()

        # computing 3-px error
        # Keep the ground truth intact: the relative criterion must compare
        # the error against the *true* disparity, so the error is stored in a
        # separate tensor instead of being written back into `disp_true`.
        true_disp = torch.as_tensor(disp_true)
        valid = true_disp > 0
        gt = true_disp[valid]
        err = torch.abs(gt - pred_disp[valid])
        correct = (err < 3) | (err < gt * 0.05)
        torch.cuda.empty_cache()

        return 1 - (float(torch.sum(correct)) / float(valid.sum()))
Example #4
Source File: geometry.py    From phidl with MIT License 6 votes vote down vote up
def _merge_nearby_floating_points(x, tol = 1e-10):
    """ Takes an array `x` and merges any values within the tolerance `tol`
    So if given
    >>> x = [-2, -1, 0, 1.0001, 1.0002, 1.0003, 4, 5, 5.003, 6, 7, 8]
    >>> _merge_nearby_floating_points(x, tol = 1e-3)
    will then return:
    >>> [-2, -1, 0, 1.0001, 1.0001, 1.0001, 4, 5, 5.003, 6, 7, 8] """
    xargsort = np.argsort(x)
    xargunsort = np.argsort(xargsort)
    xsort = x[xargsort]
    xsortthreshold = (np.diff(xsort) < tol)
    xsortthresholdind = np.argwhere(xsortthreshold)

    # Merge nearby floating point values
    for xi in xsortthresholdind:
         xsort[xi+1] = xsort[xi]
    return xsort[xargunsort] 
Example #5
Source File: ImageView.py    From tf-pose with Apache License 2.0 6 votes vote down vote up
def timeIndex(self, slider):
        """Return ``(frame_index, time)`` for the current position of `slider`.

        - No image loaded: ``(0, 0)``.
        - ``self.tVals is None``: the slider value truncated to ``int`` is the
          frame index.
        - Fewer than two time values: ``(0, 0)``.
        - Otherwise the index is the last position in ``self.tVals`` whose
          value is strictly below the slider value, or ``(0, t)`` when there
          is no such position.
        """
        if self.image is None:
            return (0,0)

        t = slider.value()

        xv = self.tVals
        if xv is None:
            return int(t), t
        if len(xv) < 2:
            return (0,0)
        earlier = np.argwhere(xv < t)
        if len(earlier) < 1:
            return (0,t)
        return earlier[-1, 0], t
Example #6
Source File: ulocal.py    From pyscf with Apache License 2.0 6 votes vote down vote up
def lowdinPop(mol,coeff,ova,enorb,occ):
    """Print Lowdin populations of localized MOs and return 0.

    The MO coefficients are transformed with ``sqrtm(ova)``. The norm of
    ``lcoeff.T @ lcoeff - I`` is printed as a check. Then, for every orbital,
    the AOs whose squared transformed coefficient exceeds ``pthresh`` are
    listed with that population, followed by the summed occupation.

    Parameters
    ----------
    mol : object providing ``ao_labels(None)`` (a sequence of AO labels)
    coeff : 2-D array, one column per orbital
    ova : AO overlap matrix passed to ``sqrtm``
    enorb, occ : per-orbital values printed alongside (``occ`` is summed)
    """
    print('\nLowdin population for LMOs:')
    _, nc = coeff.shape
    s12 = sqrtm(ova)
    lcoeff = s12.dot(coeff)
    diff = lcoeff.T.dot(lcoeff) - numpy.identity(nc)
    print('diff=', numpy.linalg.norm(diff))
    pthresh = 0.05
    labels = mol.ao_labels(None)
    nelec = 0.0
    for iorb in range(nc):
        vec = lcoeff[:, iorb]**2
        # argwhere yields (k, 1) rows; flatten to scalar indices so they can
        # index the Python sequence `labels`.
        idx = numpy.argwhere(vec > pthresh).ravel()
        print(' iorb=', iorb, ' occ=', occ[iorb], ' <i|F|i>=', enorb[iorb])
        for iao in idx:
            print('    iao=', labels[iao], ' pop=', vec[iao])
        nelec += occ[iorb]
    print('nelec=', nelec)
    return 0
Example #7
Source File: test_base_execute.py    From mars with Apache License 2.0 6 votes vote down vote up
def testArgwhereExecution(self):
        # Case 1: small C-ordered tensor split into chunks of two elements.
        mt_small = arange(6, chunk_size=2).reshape(2, 3)
        got = self.executor.execute_tensor(argwhere(mt_small > 1), concat=True)[0]
        want = np.argwhere(np.arange(6).reshape(2, 3) > 1)
        np.testing.assert_array_equal(got, want)

        # Case 2: Fortran-ordered input; the result must match NumPy and be
        # laid out in Fortran (not C) order.
        raw = np.asfortranarray(np.random.rand(10, 20))
        mt_big = tensor(raw, chunk_size=10)
        got = self.executor.execute_tensor(argwhere(mt_big > 0.5), concat=True)[0]
        want = np.argwhere(raw > 0.5)
        np.testing.assert_array_equal(got, want)
        self.assertTrue(got.flags['F_CONTIGUOUS'])
        self.assertFalse(got.flags['C_CONTIGUOUS'])
Example #8
Source File: linalg_helper.py    From pyscf with Apache License 2.0 6 votes vote down vote up
def precond(r, e0, x0):
        """Precondition residual `r` using the enclosing-scope matrix ``a``.

        Components are divided by the shifted diagonal ``a.diagonal() - e0``.
        When more than two components of `x0` have magnitude above 0.1, the
        corresponding sub-block of ``a - e0*I`` is solved exactly on those
        components instead. The result is ``(r - e1*x0)`` preconditioned that
        way, with ``e1 = x0.(M r) / x0.(M x0)``. Otherwise ``r / diag`` is
        returned.
        """
        shifted_diag = a.diagonal() - e0
        sel = numpy.argwhere(abs(x0)>.1).ravel()
        #idx = numpy.arange(20)
        nsel = sel.size
        if nsel <= 2:
            return r / shifted_diag

        block = a[sel][:,sel] - numpy.eye(nsel)*e0

        px0 = x0 / shifted_diag
        # NOTE(review): unlike `pr` below, this solves against
        # x0[sel]/diag[sel] rather than x0[sel] -- confirm this is intended.
        px0[sel] = numpy.linalg.solve(block, px0[sel])

        pr = r / shifted_diag
        pr[sel] = numpy.linalg.solve(block, r[sel])

        e1 = numpy.dot(x0, pr) / numpy.dot(x0, px0)
        rhs = r - e1*x0
        out = rhs / shifted_diag
        out[sel] = numpy.linalg.solve(block, rhs[sel])
        return out
Example #9
Source File: postprocess.py    From argus-tgs-salt with MIT License 6 votes vote down vote up
def find_points(mask, x_shift=0, y_shift=0):
    """Find the points where the mask changes class along its four edges.

    The edge profiles are, in this order, ``mask[:, x_shift]``,
    ``mask[:, -1-x_shift]``, ``mask[y_shift, :]`` and ``mask[-1-y_shift, :]``.
    The module-level constants ``left``/``right``/``top``/``bottom`` index
    into this order (presumably 0, 1, 2, 3 -- TODO confirm).

    Returns
    -------
    (pos, neg) : tuple of two lists of four lists of int
        ``pos[i]`` holds positions ``k`` where edge ``i`` switches from
        background to foreground (``edge[k-1] == 0`` and ``edge[k] == 1``).
        ``neg[i]`` holds positions where it switches back. For foreground
        corners, ``0`` is prepended to ``pos`` or the edge length to ``neg``.
    """
    # Find points where mask change class on edges
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    mask = (np.asarray(mask) > 0).astype(int)
    edges = [mask[:, 0+x_shift], mask[:, -1-x_shift],
             mask[0+y_shift, :], mask[-1-y_shift, :]]
    diffs = [np.diff(edge, n=1) for edge in edges]
    # argwhere gives (k, 1); ravel so that int() receives scalars.
    pos = [[int(k) for k in (np.argwhere(d > 0) + 1).ravel()] for d in diffs]
    neg = [[int(k) for k in (np.argwhere(d < 0) + 1).ravel()] for d in diffs]
    # End-of-edge marker is each edge's own length (rows for column edges,
    # columns for row edges), so non-square masks are handled correctly.
    lengths = [len(edge) for edge in edges]
    if mask[0, 0] > 0:
        for i in [left, top]:
            pos[i] = [0] + pos[i]
    if mask[-1, 0] > 0:
        pos[bottom] = [0] + pos[bottom]
        neg[left] = [lengths[left]] + neg[left]
    if mask[0, -1] > 0:
        pos[right] = [0] + pos[right]
        neg[top] = [lengths[top]] + neg[top]
    if mask[-1, -1] > 0:
        for i in [right, bottom]:
            neg[i] = [lengths[i]] + neg[i]
    return(pos, neg)
Example #10
Source File: common_slow.py    From pyscf with Apache License 2.0 6 votes vote down vote up
def format_mask(x):
    """
    Formats a mask into a readable string.

    Consecutive indices are collapsed into ``start-end`` ranges. For example,
    the index array ``[0, 1, 2, 5, 7, 8]`` (or the equivalent boolean mask)
    becomes ``"0-2,5,7-8"``.

    Args:
        x (ndarray): a boolean mask or an array of integer indices;

    Returns:
        A readable string with the mask, or ``"(empty)"`` when the mask
        selects nothing (including an all-False boolean mask).
    """
    x = numpy.asanyarray(x)
    if x.dtype == bool:
        x = numpy.argwhere(x)[:, 0]
    # Checked after the bool -> index conversion so that an all-False mask is
    # reported as "(empty)" rather than as an empty string.
    if len(x) == 0:
        return "(empty)"
    # n minus its running position is constant within a run of consecutive integers
    grps = tuple(list(g) for _, g in groupby(x, lambda n, c=count(): n-next(c)))
    return ",".join("{:d}-{:d}".format(i[0], i[-1]) if len(i) > 1 else "{:d}".format(i[0]) for i in grps)
Example #11
Source File: train.py    From PathCon with MIT License 6 votes vote down vote up
def calculate_ranking_metrics(triplets, scores, true_relations):
    """Compute filtered ranking metrics for relation prediction.

    For every row ``i`` (triplet ``(head, tail, relation)``), the scores of
    the *other* relations listed in ``true_relations[head, tail]`` are lowered
    by 1.0 before ranking. This presumably pushes them below the target for
    scores in [0, 1] -- confirm. The rank of the target relation is then its
    1-based position in the descending score order.

    Args:
        triplets: sequence of ``(head, tail, relation)``; only the first
            ``scores.shape[0]`` entries are used.
        scores: 2-D array where ``scores[i, r]`` scores relation ``r`` for
            triplet ``i``. It is not modified.
        true_relations: mapping ``(head, tail) -> set`` of relation ids.

    Returns:
        ``(mrr, mr, hit1, hit3, hit5)`` as Python floats.
    """
    # Work on a copy so the filtering does not leak into the caller's array.
    scores = np.array(scores, copy=True)
    for i in range(scores.shape[0]):
        head, tail, relation = triplets[i]
        for j in true_relations[head, tail] - {relation}:
            scores[i, j] -= 1.0

    sorted_indices = np.argsort(-scores, axis=1)
    relations = np.array(triplets)[0:scores.shape[0], 2]
    # The target relation appears exactly once per row, so after subtracting
    # it the single zero marks its rank position (argwhere is row-ordered).
    sorted_indices -= np.expand_dims(relations, 1)
    zero_coordinates = np.argwhere(sorted_indices == 0)
    rankings = zero_coordinates[:, 1] + 1

    mrr = float(np.mean(1 / rankings))
    mr = float(np.mean(rankings))
    hit1 = float(np.mean(rankings <= 1))
    hit3 = float(np.mean(rankings <= 3))
    hit5 = float(np.mean(rankings <= 5))

    return mrr, mr, hit1, hit3, hit5
Example #12
Source File: utils.py    From conv-social-pooling with MIT License 6 votes vote down vote up
def getHistory(self,vehId,t,refVehId,dsId):
        """Return the history of vehicle `vehId` up to time `t`, relative to `refVehId`.

        Tracks are read as ``self.T[dsId-1][id-1].transpose()``. Column 0 is
        matched against `t` and columns 1:3 are the positions used below.
        The result holds the positions from ``t_h`` steps before `t` up to
        `t`, sampled every ``d_s`` rows, minus the reference vehicle's
        position at `t`.

        An empty ``(0, 2)`` array is returned when ``vehId == 0``, the vehicle
        index is out of range, the vehicle has no row at `t`, or fewer than
        ``t_h // d_s + 1`` samples are available.
        """
        if vehId == 0:
            return np.empty([0,2])
        if self.T.shape[1]<=vehId-1:
            return np.empty([0,2])
        refTrack = self.T[dsId-1][refVehId-1].transpose()
        vehTrack = self.T[dsId-1][vehId-1].transpose()
        refPos = refTrack[np.where(refTrack[:,0]==t)][0,1:3]

        if vehTrack.size == 0:
            return np.empty([0,2])
        # Locate the row for time t once (previously scanned three times).
        matches = np.argwhere(vehTrack[:, 0] == t)
        if matches.size == 0:
            return np.empty([0,2])
        cur = matches.item()
        stpt = np.maximum(0, cur - self.t_h)
        enpt = cur + 1
        hist = vehTrack[stpt:enpt:self.d_s,1:3]-refPos

        if len(hist) < self.t_h//self.d_s + 1:
            return np.empty([0,2])
        return hist



    ## Helper function to get track future 
Example #13
Source File: Spread.py    From pylops with GNU Lesser General Public License v3.0 6 votes vote down vote up
def _rmatvec_numpy(self, x):
        x = x.reshape(self.dimsd)
        y = np.zeros(self.dims, dtype=self.dtype)
        for it in range(self.dims[1]):
            for ix0 in range(self.dims[0]):
                if self.usetable:
                    indices = self.table[ix0, it]
                    if self.interp:
                        dindices = self.dtable[ix0, it]
                else:
                    if self.interp:
                        indices, dindices = self.fh(ix0, it)
                    else:
                        indices = self.fh(ix0, it)
                mask = np.argwhere(~np.isnan(indices))
                if mask.size > 0:
                    indices = (indices[mask]).astype(np.int)
                    if not self.interp:
                        y[ix0, it] = np.sum(x[mask, indices])
                    else:
                        y[ix0, it] = \
                            np.sum(x[mask, indices]*(1-dindices[mask])) + \
                            np.sum(x[mask, indices+1]*dindices[mask])
        return y.ravel() 
Example #14
Source File: pc_util.py    From H3DNet with MIT License 6 votes vote down vote up
def point_cloud_to_sem_vox(pt, sem_label, vs=0.06,xymin=-3.84, xymax=3.84, zmin=-0.2, zmax=2.68):
  """Voxelize a labelled point cloud into a dense volume of class ids.

  Args:
    pt: (N, 3) array of x, y, z coordinates. It is not modified.
    sem_label: length-N sequence of labels, one per point.
    vs: voxel size.
    xymin, xymax: extent covered along x and y.
    zmin, zmax: extent covered along z.

  Returns:
    float32 array of shape (vxy, vxy, vz), with vxy = int((xymax-xymin)/vs)
    and vz = int((zmax-zmin)/vs). A voxel holds ``k + 1``, where ``k`` is the
    first index at which the point's label occurs in the module-level
    ``choose_classes``; 0 means empty or not chosen. Coordinates outside the
    extent are clipped into the border voxels. When several points share a
    voxel, the last one wins.
  """
  # Work on a copy so the caller's points are not shifted/scaled in place;
  # promote non-float input so the subtraction below is not truncated.
  pt = np.array(pt, copy=True)
  if not np.issubdtype(pt.dtype, np.floating):
    pt = pt.astype(np.float64)
  pt[:,0]=pt[:,0]-xymin
  pt[:,1]=pt[:,1]-xymin
  pt[:,2]=pt[:,2]-zmin
  pt=pt/vs
  vxy=int((xymax-xymin)/vs)
  vz = int((zmax-zmin)/vs)
  pt = np.clip(pt, 0,vxy-1)
  pt[:,2] = np.clip(pt[:,2], 0,vz-1)
  vol=np.zeros((vxy,vxy,vz), np.float32)
  pt = pt.astype(np.int32)
  for i in range(pt.shape[0]):
    if sem_label[i] not in choose_classes:
      continue
    vol[pt[i,0], pt[i,1], pt[i,2]]=np.argwhere(choose_classes==sem_label[i])[0,0]+1
  return vol
Example #15
Source File: Spread.py    From pylops with GNU Lesser General Public License v3.0 6 votes vote down vote up
def _matvec_numpy(self, x):
        x = x.reshape(self.dims)
        y = np.zeros(self.dimsd, dtype=self.dtype)
        for it in range(self.dims[1]):
            for ix0 in range(self.dims[0]):
                if self.usetable:
                    indices = self.table[ix0, it]
                    if self.interp:
                        dindices = self.dtable[ix0, it]
                else:
                    if self.interp:
                        indices, dindices = self.fh(ix0, it)
                    else:
                        indices = self.fh(ix0, it)
                mask = np.argwhere(~np.isnan(indices))
                if mask.size > 0:
                    indices = (indices[mask]).astype(np.int)
                    if not self.interp:
                        y[mask, indices] += x[ix0, it]
                    else:
                        y[mask, indices] += (1-dindices[mask])*x[ix0, it]
                        y[mask, indices + 1] += dindices[mask] * x[ix0, it]
        return y.ravel() 
Example #16
Source File: eval_helpers.py    From PoseWarper with Apache License 2.0 6 votes vote down vote up
def VOCap(rec,prec):
    """Area under a precision/recall curve, PASCAL-VOC style.

    Recall is padded to ``[0, rec..., 1]`` and precision to
    ``[0, prec..., 0]``. Each precision entry is replaced by the maximum of
    itself and everything to its right. The area is then summed over the
    points where recall changes value.
    """
    n_rec = len(rec)
    n_prec = len(prec)

    precision = np.zeros([1, n_prec + 2])
    precision[0, 1:n_prec + 1] = prec

    recall = np.zeros([1, n_rec + 2])
    recall[0, 1:n_rec + 1] = rec
    recall[0, n_rec + 1] = 1.0

    # Precision envelope, filled from the right.
    for k in reversed(range(precision.size - 1)):
        precision[0, k] = max(precision[0, k], precision[0, k + 1])

    # Positions where the recall value changes.
    steps = np.argwhere(recall[0, 1:] != recall[0, :-1]).flatten() + 1

    widths = recall[0, steps] - recall[0, steps - 1]
    return np.sum(widths * precision[0, steps])
Example #17
Source File: residual_plots.py    From gmpe-smtk with GNU Affero General Public License v3.0 6 votes vote down vote up
def _tojson(*numpy_objs):
    '''Utility function which returns a list where each element of numpy_objs
    is converted to its python equivalent (float or list)'''
    ret = []
    # problem: browsers might not be happy with JSON 'NAN', so convert
    # NaNs to None. Unfortunately, the conversion must be done element wise
    # in numpy (seems not to exist a pandas na filter):
    for obj in numpy_objs:
        isscalar = np.isscalar(obj)
        nan_indices = None if isscalar else \
            np.argwhere(np.isnan(obj)).flatten()
        # note: numpy.float64(N).tolist() returns a python float, so:
        obj = None if isscalar and np.isnan(obj) else obj.tolist()
        if nan_indices is not None:
            for idx in nan_indices:
                obj[idx] = None
        ret.append(obj)

    return ret  # tuple(_.tolist() for _ in numpy_objs) 
Example #18
Source File: drapes.py    From pycolab with Apache License 2.0 5 votes vote down vote up
def _whole_pattern_position(self, character, error_name):
      """Find the absolute location of `character` in game world ASCII art."""
      pos = list(np.argwhere(self._whole_pattern_art == ord(character)))
      if not pos: raise RuntimeError(
          '{} found no instances of {} in the pattern art used to build this '
          'PatternInfo object.'.format(error_name, repr(character)))
      if len(pos) > 1: raise RuntimeError(
          '{} found multiple instances of {} in the pattern art used to build '
          'this PatternInfo object.'.format(error_name, repr(character)))
      return tuple(pos[0]) 
Example #19
Source File: fictitious_play.py    From Nashpy with MIT License 5 votes vote down vote up
def get_best_response_to_play_count(A, play_count):
    """
    Returns the best response to a belief based on the playing distribution of the opponent

    The expected utility of each row strategy is ``A @ play_count``. One of
    the strategies attaining the maximum is drawn with ``np.random.choice``,
    so ties are broken uniformly at random.
    """
    expected_utilities = A @ play_count
    best_value = np.max(expected_utilities)
    maximisers = np.argwhere(expected_utilities == best_value)[:, 0]
    return np.random.choice(maximisers)
Example #20
Source File: test_core.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def test_masked_array():
    # Index 2 holds a nonzero value but is masked, so it must not be reported.
    masked = np.ma.array([0, 1, 2, 3], mask=[0, 0, 1, 0])
    expected = [[1], [3]]
    assert_equal(np.argwhere(masked), expected)
Example #21
Source File: Filters.py    From tf-pose with Apache License 2.0 5 votes vote down vote up
def adjustXPositions(self, pts, data):
        """Return a list of Point() where the x position is set to the nearest x value in *data* for each point in *pts*.

        Also returns, per point, the ``np.argwhere`` result locating every
        sample of *data* at the minimal distance. This is a (k, 1) index
        array for 1-D *data*.
        """
        snapped = []
        nearest_indices = []
        for p in pts:
            dist = abs(data - p.x())
            nearest = np.argwhere(dist == dist.min())
            snapped.append(Point(data[nearest], p.y()))
            nearest_indices.append(nearest)
        return snapped, nearest_indices
Example #22
Source File: test_numeric.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def test_2D(self):
        x = np.arange(6).reshape((2, 3))
        assert_array_equal(np.argwhere(x > 1),
                           [[0, 2],
                            [1, 0],
                            [1, 1],
                            [1, 2]]) 
Example #23
Source File: np_conserved.py    From tenpy with GNU General Public License v3.0 5 votes vote down vote up
def get_block(self, qindices, insert=False):
        """Return the ndarray in ``_data`` representing the block corresponding to `qindices`.

        Parameters
        ----------
        qindices : 1D array of np.intp
            The qindices, for which we need to look in _qdata.
        insert : bool
            If True and no block with `qindices` exists, a new zero block is
            created, appended to ``self._data`` (its qindices appended to
            ``self._qdata``, which is then marked unsorted) and returned.
            Otherwise just return ``None``.

        Returns
        -------
        block: ndarray | ``None``
            The block in ``_data`` corresponding to qindices.
            If `insert` is False and there is no block with qindices, return ``None``.

        Raises
        ------
        IndexError
            If the block charge of `qindices` differs from ``self.qtotal``.
        """
        charge_ok = np.all(self._get_block_charge(qindices) == self.qtotal)
        if not charge_ok:
            raise IndexError("trying to get block for qindices incompatible with charges")
        # rows of self._qdata equal to qindices
        hits = np.argwhere(np.all(self._qdata == qindices, axis=1))[:, 0]
        if len(hits) > 0:
            return self._data[hits[0]]
        if not insert:
            return None
        block = np.zeros(self._get_block_shape(qindices), dtype=self.dtype)
        self._data.append(block)
        self._qdata = np.append(self._qdata, [qindices], axis=0)
        self._qdata_sorted = False
        return block
Example #24
Source File: cmath.py    From ehtplot with GNU General Public License v3.0 5 votes vote down vote up
def extrema(a):
    """Find extrema in an array.

    Returns the indices ``j`` for which
    ``(a[j] - a[j-1]) * (a[j+1] - a[j]) <= 0``. These are the points where
    the forward difference changes sign or touches zero, so plateaus are
    reported as well.
    """
    step = a[1:] - a[:-1]
    turning = step[:-1] * step[1:] <= 0.0
    return np.argwhere(turning)[:, 0] + 1
Example #25
Source File: cliff_walking.py    From reinforcement-learning-an-introduction with MIT License 5 votes vote down vote up
def sarsa(q_value, expected=False, step_size=ALPHA):
    """Run one episode of Sarsa (or Expected Sarsa), updating `q_value` in place.

    Args:
        q_value: array indexed as ``q_value[row, col, action]``; updated in place.
        expected: when True, bootstrap from the expectation of the next
            state's action values under the epsilon-greedy weights (greedy
            actions share ``1 - EPSILON``, every action gets
            ``EPSILON / len(ACTIONS)``). Otherwise bootstrap from the value of
            the sampled next action.
        step_size: learning rate applied to the TD error.

    Returns:
        The sum of rewards collected during the episode.
    """
    cur_state = START
    cur_action = choose_action(cur_state, q_value)
    total_reward = 0.0
    while cur_state != GOAL:
        nxt_state, reward = step(cur_state, cur_action)
        nxt_action = choose_action(nxt_state, q_value)
        total_reward += reward
        if expected:
            # calculate the expected value of new state
            row = q_value[nxt_state[0], nxt_state[1], :]
            greedy = np.argwhere(row == np.max(row))
            greedy_share = (1.0 - EPSILON) / len(greedy)
            explore_share = EPSILON / len(ACTIONS)
            target = 0.0
            for a in ACTIONS:
                if a in greedy:
                    target += (greedy_share + explore_share) * q_value[nxt_state[0], nxt_state[1], a]
                else:
                    target += explore_share * q_value[nxt_state[0], nxt_state[1], a]
        else:
            target = q_value[nxt_state[0], nxt_state[1], nxt_action]
        discounted = target * GAMMA
        td_error = reward + discounted - q_value[cur_state[0], cur_state[1], cur_action]
        q_value[cur_state[0], cur_state[1], cur_action] += step_size * td_error
        cur_state, cur_action = nxt_state, nxt_action
    return total_reward

# an episode with Q-Learning
# @q_value: values for state action pair, will be updated
# @step_size: step size for updating
# @return: total rewards within this episode 
Example #26
Source File: test_numeric.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def test_list(self):
        assert_equal(np.argwhere([4, 0, 2, 1, 3]), [[0], [2], [3], [4]]) 
Example #27
Source File: sm_database.py    From gmpe-smtk with GNU Affero General Public License v3.0 5 votes vote down vote up
def _get_site_id(self, str_id):
        """
        TODO 
        """
        if str_id not in self.site_ids:
            self.site_ids.append(str_id)
        _id = np.argwhere(str_id == np.array(self.site_ids))[0]
        return _id[0] 
Example #28
Source File: tensor_format.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def __init__(self,
               criterion,
               description=None,
               font_attr=DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR):
    """Constructor of HighlightOptions.

    Args:
      criterion: (callable) Takes the tensor ``X`` and returns a boolean
        ndarray of the same shape that is True exactly for the elements to
        highlight. That result is used as the argument of ``np.argwhere()``
        to locate the highlighted elements.
      description: (str) Description of the highlight criterion embodied by
        criterion.
      font_attr: (str) Font attribute to be applied to the
        highlighted elements.
    """
    self.criterion, self.description, self.font_attr = (
        criterion, description, font_attr)
Example #29
Source File: gene_expression.py    From tensor2tensor with Apache License 2.0 5 votes vote down vote up
def to_example_dict(encoder, inputs, mask, outputs):
  """Convert single h5 record to an example dict.

  Args:
    encoder: provides ``UNK``, ``BASES`` and ``encode(list_of_bases)``.
    inputs: 2-D boolean-like array with one row per position and at most one
      True column per row; the column index selects ``encoder.BASES``.
      All-False rows become ``encoder.UNK``.
    mask: iterable of target-mask values; ``mask.shape[0]`` must equal
      ``outputs.shape[0]``.
    outputs: target array of shape (n, m).

  Returns:
    dict with keys "inputs", "targets_mask", "targets", "targets_shape".
  """
  # Inputs
  bases = []
  last_idx = -1
  for row in np.argwhere(inputs):
    idx, base_id = row
    idx, base_id = int(idx), int(base_id)
    assert idx > last_idx  # if not, means 2 True values in 1 row
    # Some rows are all False. Those rows are mapped to UNK_ID.
    while idx != last_idx + 1:
      bases.append(encoder.UNK)
      last_idx += 1
    bases.append(encoder.BASES[base_id])
    last_idx = idx
  # All-False rows after the last True value are UNK as well; the loop above
  # only fills gaps *before* a True value.
  bases.extend([encoder.UNK] * (len(inputs) - len(bases)))
  assert len(inputs) == len(bases)

  input_ids = encoder.encode(bases)
  input_ids.append(text_encoder.EOS_ID)

  # Targets: mask and output
  targets_mask = [float(v) for v in mask]
  # The output is (n, m); store targets_shape so that it can be reshaped
  # properly on the other end.
  targets = [float(v) for v in outputs.flatten()]
  targets_shape = [int(dim) for dim in outputs.shape]
  assert mask.shape[0] == outputs.shape[0]

  example_keys = ["inputs", "targets_mask", "targets", "targets_shape"]
  ex_dict = dict(
      zip(example_keys, [input_ids, targets_mask, targets, targets_shape]))
  return ex_dict
Example #30
Source File: CEM_MAF_aen_PP.py    From AIX360 with Apache License 2.0 5 votes vote down vote up
def generate_PP(self, img_mask, orig_img, orig_class, model, mask_size, mask_mat):
        """Greedily grow a pertinent-positive mask until the predicted class matches.

        Cells are added in decreasing order of their ``img_mask`` value. Only
        cells with a positive value are considered. After each addition, the
        masked image ``working_mask * orig_img`` is classified, and the search
        stops as soon as the predicted class equals `orig_class`.

        Args:
            img_mask: importance value per candidate cell (flattened internally).
            orig_img: image multiplied by the binary working mask of shape
                (1, mask_size, mask_size, 1); presumably the same shape -- TODO confirm.
            orig_class: class that must be recovered.
            model: object exposing ``model.model.predict(batch)``.
            mask_size: spatial size of the working mask.
            mask_mat: ``mask_mat[k]`` is a 2-D 0/1 map of the pixels that
                belong to cell ``k``.

        Returns:
            ``(adv_img, success)``: the last masked image evaluated (all zeros
            if no cell has a positive value) and whether `orig_class` was
            reached.
        """
        def model_prediction(model, inputs):
            prob = model.model.predict(inputs)
            predicted_class = np.argmax(prob)
            prob_str = np.array2string(prob).replace('\n','')
            return prob, predicted_class, prob_str

        # ranking
        success = False
        print("Start ranking:")
        mask_vec = img_mask.reshape(-1)
        sort_idx = np.argsort(mask_vec)
        # Number of cells with a positive value (not the total number of cells).
        total_nonezero = np.count_nonzero(mask_vec > 0)
        working_mask = np.zeros((1,) + (mask_size, mask_size) + (1,))
        # Defined before the loop so the function can return even when no
        # cell is eligible.
        adv_img = working_mask * orig_img
        # sort_idx[-i] for i = 1..total_nonezero (inclusive) visits exactly the
        # positive cells, from the largest value downwards.
        for i in range(1, total_nonezero + 1):
            temp_index = sort_idx[-i]
            mask_position = np.argwhere(mask_mat[temp_index]==1)
            for index in mask_position:
                working_mask[(0,) + tuple(index) + (0,)] = 1
            adv_img = working_mask * orig_img
            img_prob, img_class, img_prob_str = model_prediction(model, adv_img)
            print("i:{}, index:{}, value:{}, class:{}".format(i, temp_index, mask_vec[temp_index], img_class))
            if img_class == orig_class:
                success = True
                break

        return adv_img, success