Python tensorflow.batch_gather() Examples

The following are 12 code examples of tensorflow.batch_gather(), collected from open-source projects; the originating project and source file are noted above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
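For orientation, here is a minimal sketch (assuming TensorFlow 1.x graph mode, where tf.batch_gather is available) of what the op computes: result[b, k] = params[b, indices[b, k]], i.e. a gather performed independently within each batch element. Newer TF versions express the same thing as tf.gather(params, indices, batch_dims=1).

import tensorflow as tf

params = tf.constant([[10, 11, 12],
                      [20, 21, 22]])           # shape [2, 3]
indices = tf.constant([[2, 0],
                       [1, 1]])                # shape [2, 2]

gathered = tf.batch_gather(params, indices)    # per-batch lookup
# Equivalent in newer TF: tf.gather(params, indices, batch_dims=1)

with tf.Session() as sess:
    print(sess.run(gathered))                  # [[12 10]
                                               #  [21 21]]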
Example #1
Source File: neigh_samplers.py    From DGFraud with Apache License 2.0
def _call(self, inputs):
        eps = 0.001
        ids, num_samples, features, batch_size = inputs
        adj_lists = tf.gather(self.adj_info, ids)
        node_features = tf.gather(features, ids)
        feature_size = tf.shape(features)[-1]
        node_feature_repeat = tf.tile(node_features, [1,self.num_neighs])
        node_feature_repeat = tf.reshape(node_feature_repeat, [batch_size, self.num_neighs, feature_size])
        neighbor_feature =  tf.gather(features, adj_lists)
        distance = tf.sqrt(tf.reduce_sum(tf.square(node_feature_repeat - neighbor_feature), -1))
        prob = tf.exp(-distance)
        prob_sum = tf.reduce_sum(prob, -1, keepdims=True)
        prob_sum = tf.tile(prob_sum, [1,self.num_neighs])
        prob = tf.divide(prob, prob_sum)
        prob = tf.where(prob > eps, prob, 0 * prob)  # optionally filter out probabilities smaller than eps
        samples_idx = tf.random.categorical(tf.math.log(prob), num_samples)
        selected = tf.batch_gather(adj_lists, samples_idx)
        return selected 
Example #2
Source File: neural_dater.py    From NeuralDater with Apache License 2.0
def gather(self, data, pl_idx, pl_mask, max_len, name=None):
		"""
		Lookup equivalent for tensors with dim > 2 (Can be simplified using tf.batch_gather)

		Parameters
		----------
		data:		Tensor in which lookup has to be performed
		pl_idx:		The indices to be taken
		pl_mask:	For handling padding in pl_idx
		max_len:	Maximum length of indices

		Returns
		-------
		et_vecs * mask_vec:	Extracted vectors at given indices
		
		"""
		idx1  = tf.range(self.p.batch_size, dtype=tf.int32)
		idx1  = tf.reshape(idx1, [-1, 1])
		idx1_ = tf.reshape(tf.tile(idx1, [1, max_len]) , [-1, 1])
		idx_reshape = tf.reshape(pl_idx, [-1, 1])
		indices = tf.concat((idx1_, idx_reshape), axis=1)
		et_vecs = tf.gather_nd(data, indices)
		et_vecs = tf.reshape(et_vecs, [self.p.batch_size, self.max_et, -1])
		mask_vec = tf.expand_dims(pl_mask, axis=2)
		return et_vecs * mask_vec 
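As the docstring notes, the same lookup collapses into a single tf.batch_gather call. A minimal sketch of that simplification, with assumed toy shapes (data as [batch, seq, dim], indices as [batch, max_len]):

import tensorflow as tf

data = tf.random.normal([4, 10, 8])             # [batch, seq, dim] (toy)
idx  = tf.constant([[0, 3, 5], [1, 1, 2],
                    [9, 0, 4], [2, 2, 2]])      # [batch, max_len]
mask = tf.ones([4, 3])                          # 1 = real index, 0 = padding

et_vecs = tf.batch_gather(data, idx)            # [4, 3, 8]
result  = et_vecs * tf.expand_dims(mask, 2)     # zero out padded positions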
Example #3
Source File: augdesc.py    From pyslam with GNU General Public License v3.0
def _interpolate(self, xy1, xy2, points2):
        batch_size = tf.shape(xy1)[0]
        ndataset1 = tf.shape(xy1)[1]

        eps = 1e-6
        dist_mat = tf.matmul(xy1, xy2, transpose_b=True)
        norm1 = tf.reduce_sum(xy1 * xy1, axis=-1, keepdims=True)
        norm2 = tf.reduce_sum(xy2 * xy2, axis=-1, keepdims=True)
        dist_mat = tf.sqrt(norm1 - 2 * dist_mat + tf.linalg.matrix_transpose(norm2) + eps)
        dist, idx = tf.math.top_k(tf.negative(dist_mat), k=3)

        dist = tf.maximum(dist, 1e-10)
        norm = tf.reduce_sum((1.0 / dist), axis=2, keepdims=True)
        norm = tf.tile(norm, [1, 1, 3])
        weight = (1.0 / dist) / norm
        idx = tf.reshape(idx, (batch_size, -1))
        nn_points = tf.batch_gather(points2, idx)
        nn_points = tf.reshape(nn_points, (batch_size, ndataset1, 3, points2.get_shape()[-1].value))
        interpolated_points = tf.reduce_sum(weight[..., tf.newaxis] * nn_points, axis=-2)

        return interpolated_points 
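The distance computation above relies on the expansion ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2, applied batch-wise. A small sketch (toy shapes assumed) that compares it against the brute-force broadcasted distance:

import tensorflow as tf

xy1 = tf.random.normal([2, 5, 3])    # [batch, n1, coords] (toy shapes)
xy2 = tf.random.normal([2, 7, 3])    # [batch, n2, coords]

inner = tf.matmul(xy1, xy2, transpose_b=True)                # [2, 5, 7]
norm1 = tf.reduce_sum(xy1 * xy1, axis=-1, keepdims=True)     # [2, 5, 1]
norm2 = tf.reduce_sum(xy2 * xy2, axis=-1, keepdims=True)     # [2, 7, 1]
dist_sq = norm1 - 2 * inner + tf.linalg.matrix_transpose(norm2)

# Brute-force reference: broadcast the pairwise differences directly.
diff = xy1[:, :, None, :] - xy2[:, None, :, :]               # [2, 5, 7, 3]
dist_sq_ref = tf.reduce_sum(diff * diff, axis=-1)            # equals dist_sq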
Example #4
Source File: modeling.py    From grover with Apache License 2.0
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling. If ignore_ids is set, those logits are zeroed out.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param k: top-k cutoff to use, either an int or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_k_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)
        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')

        # find the top kth index to cut off. careful we don't want to cut off everything!
        # result will be [batch_size, vocab_perm]
        k_expanded = k if isinstance(k, int) else k[:, None]
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    } 
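The sort/sample/unsort idiom of OPTION A is where batch_gather earns its keep: sample positions in the sorted space, then map them back to vocabulary ids. A standalone sketch with toy values (names are illustrative):

import tensorflow as tf

logits = tf.random.normal([2, 6])       # [batch, vocab] (toy)
k = 3

indices = tf.argsort(logits, direction='DESCENDING')         # [2, 6]
sorted_logits = tf.batch_gather(logits, indices)             # [2, 6]
exclude = tf.cast(tf.range(6)[None] >= k, tf.float32) * 1e10
sample_perm = tf.random.categorical(sorted_logits - exclude, num_samples=1)
sample = tf.batch_gather(indices, sample_perm)   # vocabulary ids, [2, 1]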
Example #5
Source File: nuelus_sampling_utils.py    From BERT with Apache License 2.0
def nucleus_sampling(logits, vocab_size, p=0.9, 
					input_ids=None, input_ori_ids=None,
					**kargs):
	input_shape_list = bert_utils.get_shape_list(logits, expected_rank=[2,3])
	if len(input_shape_list) == 3:
		logits = tf.reshape(logits, (-1, vocab_size))
	probs = tf.nn.softmax(logits, axis=-1)
	# [batch_size, seq, vocab_perm]
	# indices = tf.argsort(probs, direction='DESCENDING')
	indices = tf.contrib.framework.argsort(probs, direction='DESCENDING')

	cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)
	
	# find the top pth index to cut off. careful we don't want to cut off everything!
	# result will be [batch_size, seq, vocab_perm]
	exclude_mask = tf.logical_not(
		tf.logical_or(cumulative_probabilities < p, tf.range(vocab_size)[None] < 1))
	exclude_mask = tf.cast(exclude_mask, tf.float32)

	indices_v1 = tf.contrib.framework.argsort(indices)
	exclude_mask = reorder(exclude_mask, tf.cast(indices_v1, dtype=tf.int32))
	if len(input_shape_list) == 3:
		exclude_mask = tf.reshape(exclude_mask, input_shape_list)
		# logits = tf.reshape(logits, input_shape_list)

	if input_ids is not None and input_ori_ids is not None:
		exclude_mask, input_ori_ids = get_extra_mask(
								input_ids, input_ori_ids, 
								exclude_mask, vocab_size,
								**kargs)

		return [exclude_mask, input_ori_ids]
	else:
		return [exclude_mask] 
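The indices_v1 = argsort(indices) step computes the inverse permutation, so gathering with it maps the sorted-order mask back to vocabulary order (which is presumably what the external reorder helper implements). A minimal check of that identity, using tf.argsort from newer TF in place of tf.contrib.framework.argsort:

import tensorflow as tf

probs = tf.constant([[0.1, 0.6, 0.3]])
indices = tf.argsort(probs, direction='DESCENDING')   # [[1, 2, 0]]
inverse = tf.argsort(indices)                         # [[2, 0, 1]]

sorted_probs = tf.batch_gather(probs, indices)        # [[0.6, 0.3, 0.1]]
restored = tf.batch_gather(sorted_probs, inverse)     # [[0.1, 0.6, 0.3]]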
Example #6
Source File: backend.py    From bert4keras with Apache License 2.0
def batch_gather(params, indices):
    """同tf旧版本的batch_gather
    """
    try:
        return tf.gather(params, indices, batch_dims=K.ndim(indices) - 1)
    except Exception as e1:
        try:
            return tf.batch_gather(params, indices)
        except Exception as e2:
            raise ValueError('%s\n%s\n' % (e1, e2))  # exceptions have no .message in Python 3
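A brief usage sketch of this compatibility shim (the import path is an assumption):

import tensorflow as tf
from bert4keras.backend import batch_gather   # assumed import path

params  = tf.constant([[[1., 2.], [3., 4.], [5., 6.]]])   # [1, 3, 2]
indices = tf.constant([[2, 0]])                           # [1, 2]
out = batch_gather(params, indices)   # [[[5., 6.], [1., 2.]]], shape [1, 2, 2]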
Example #7
Source File: modeling.py    From grover with Apache License 2.0
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p sampling. If ignore_ids is set, those logits are zeroed out.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param p: top-p threshold to use, either a float or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(
                    logits=logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                    num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)

        # find the top pth index to cut off. careful we don't want to cut off everything!
        # result will be [batch_size, vocab_perm]
        p_expanded = p if isinstance(p, float) else p[:, None]
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

        # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
        # unperm_indices = tf.argsort(indices, direction='ASCENDING')
        # include_mask_unperm = tf.batch_gather(include_mask, unperm_indices)
        # logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10
        # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)

    return {
        'probs': probs,
        'sample': sample,
    } 
Example #8
Source File: modeling.py    From grover with Apache License 2.0
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p or top-k threshold
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :return: new_tokens, size [batch_size]
             new_probs, also size [batch_size]
             new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
                   news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )

    # Extract the FINAL SEQ LENGTH
    batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    if do_topk:
        sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    new_tokens = tf.squeeze(sample_info['sample'], 1)
    new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    } 
Example #9
Source File: beam_search.py    From BERT with Apache License 2.0
def fast_tpu_gather(params, indices, name=None):
  """Fast gather implementation for models running on TPU.

  This function uses one_hot and batch matmul to do the gather, which is faster
  than gather_nd on TPU. For params that have a dtype of int32 (sequences to
  gather from), batch_gather is used to preserve accuracy.

  Args:
    params: A tensor from which to gather values.
      [batch_size, original_size, ...]
    indices: A tensor used as the index to gather values.
      [batch_size, selected_size].
    name: A string, name of the operation (optional).

  Returns:
    gather_result: A tensor that has the same rank as params.
      [batch_size, selected_size, ...]
  """
  with tf.name_scope(name):
    dtype = params.dtype

    def _gather(params, indices):
      """Fast gather using one_hot and batch matmul."""
      if dtype != tf.float32:
        params = tf.to_float(params)
      shape = common_layers.shape_list(params)
      indices_shape = common_layers.shape_list(indices)
      ndims = params.shape.ndims
      # Adjust the shape of params to match one-hot indices, which is the
      # requirement of Batch MatMul.
      if ndims == 2:
        params = tf.expand_dims(params, axis=-1)
      if ndims > 3:
        params = tf.reshape(params, [shape[0], shape[1], -1])
      gather_result = tf.matmul(
          tf.one_hot(indices, shape[1], dtype=params.dtype), params)
      if ndims == 2:
        gather_result = tf.squeeze(gather_result, axis=-1)
      if ndims > 3:
        shape[1] = indices_shape[1]
        gather_result = tf.reshape(gather_result, shape)
      if dtype != tf.float32:
        gather_result = tf.cast(gather_result, dtype)
      return gather_result

    # If the dtype is integer, use gather instead of the one_hot matmul to avoid
    # precision loss. The max int value that can be represented by bfloat16 in
    # the MXU is 256, which is smaller than the possible id values.
    # Encoding/decoding could potentially be used to make it work, but the
    # benefit is small right now.
    if dtype.is_integer:
      gather_result = tf.batch_gather(params, indices)
    else:
      gather_result = _gather(params, indices)

    return gather_result 
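A usage sketch for fast_tpu_gather, assuming the surrounding module's imports (common_layers etc.) are available. Float params take the one_hot matmul path; integer params fall through to tf.batch_gather, and both should agree with a plain batch_gather:

import tensorflow as tf

params  = tf.random.normal([2, 4, 3])       # [batch, original_size, depth]
indices = tf.constant([[1, 3], [0, 0]])     # [batch, selected_size]

fast = fast_tpu_gather(params, indices, name='fast_gather')
ref  = tf.batch_gather(params, indices)     # same values, computed directly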
Example #10
Source File: beam_search.py    From training_results_v0.5 with Apache License 2.0
def fast_tpu_gather(params, indices, name=None):
  """Fast gather implementation for models running on TPU.

  This function uses one_hot and batch matmul to do the gather, which is faster
  than gather_nd on TPU. For params that have a dtype of int32 (sequences to
  gather from), batch_gather is used to preserve accuracy.

  Args:
    params: A tensor from which to gather values.
      [batch_size, original_size, ...]
    indices: A tensor used as the index to gather values.
      [batch_size, selected_size].
    name: A string, name of the operation (optional).

  Returns:
    gather_result: A tensor that has the same rank as params.
      [batch_size, selected_size, ...]
  """
  with tf.name_scope(name):
    dtype = params.dtype

    def _gather(params, indices):
      """Fast gather using one_hot and batch matmul."""
      if dtype != tf.float32:
        params = tf.to_float(params)
      shape = common_layers.shape_list(params)
      indices_shape = common_layers.shape_list(indices)
      ndims = params.shape.ndims
      # Adjust the shape of params to match one-hot indices, which is the
      # requirement of Batch MatMul.
      if ndims == 2:
        params = tf.expand_dims(params, axis=-1)
      if ndims > 3:
        params = tf.reshape(params, [shape[0], shape[1], -1])
      gather_result = tf.matmul(
          tf.one_hot(indices, shape[1], dtype=params.dtype), params)
      if ndims == 2:
        gather_result = tf.squeeze(gather_result, axis=-1)
      if ndims > 3:
        shape[1] = indices_shape[1]
        gather_result = tf.reshape(gather_result, shape)
      if dtype != tf.float32:
        gather_result = tf.cast(gather_result, dtype)
      return gather_result

    # If the dtype is int32, use gather instead of the one_hot matmul to avoid
    # precision loss. The max int value that can be represented by bfloat16 in
    # the MXU is 256, which is smaller than the possible id values.
    # Encoding/decoding could potentially be used to make it work, but the
    # benefit is small right now.
    if dtype == tf.int32:
      gather_result = tf.batch_gather(params, indices)
    else:
      gather_result = _gather(params, indices)

    return gather_result 
Example #11
Source File: net.py    From gcdn with MIT License
def gconv(self, h, name, in_feat, out_feat, stride_th1, stride_th2, compute_graph=True, return_graph=False, D=[]):

		if compute_graph:
			D = self.compute_graph(h)

		_, top_idx = tf.nn.top_k(-D, self.config.min_nn+1) # (B, N, d+1)
		top_idx2 = tf.reshape(tf.tile(tf.expand_dims(top_idx[:,:,0],2), [1, 1, self.config.min_nn-8]), [-1, self.N*(self.config.min_nn-8)]) # (B, N*d)
		top_idx = tf.reshape(top_idx[:,:,9:],[-1, self.N*(self.config.min_nn-8)]) # (B, N*d)

		x_tilde1 = tf.batch_gather(h, top_idx) # (B, K, dlm1)		
		x_tilde2 = tf.batch_gather(h, top_idx2) # (B, K, dlm1)
		labels = x_tilde1 - x_tilde2 # (B, K, dlm1)
		x_tilde1 = tf.reshape(x_tilde1, [-1, in_feat]) # (B*K, dlm1)
		labels = tf.reshape(labels, [-1, in_feat]) # (B*K, dlm1)
		d_labels = tf.reshape( tf.reduce_sum(labels*labels, 1), [-1, self.config.min_nn-8]) # (B*N, d)

		name_flayer = name + "_flayer0"
		labels = tf.nn.leaky_relu(tf.matmul(labels, self.W[name_flayer]) + self.b[name_flayer]) #  (B*K, F)
		name_flayer = name + "_flayer1"
		labels_exp = tf.expand_dims(labels, 1) # (B*K, 1, F)
		labels1 = labels_exp+0.0
		for ss in range(1, in_feat // stride_th1):  # integer division for Python 3
			labels1 = tf.concat( [labels1, self.myroll(labels_exp, shift=(ss+1)*stride_th1, axis=2)], axis=1 ) # (B*K, dlm1/stride, dlm1)
		labels2 = labels_exp+0.0
		for ss in range(1, out_feat // stride_th2):  # integer division for Python 3
			labels2 = tf.concat( [labels2, self.myroll(labels_exp, shift=(ss+1)*stride_th2, axis=2)], axis=1 ) # (B*K, dl/stride, dlm1)
		theta1 = tf.matmul( tf.reshape(labels1, [-1, in_feat]), self.W[name_flayer+"_th1"] )  # (B*K*dlm1/stride, R*stride)
		theta1 = tf.reshape(theta1, [-1, self.config.rank_theta, in_feat] ) + self.b[name_flayer+"_th1"]
		theta2 = tf.matmul( tf.reshape(labels2, [-1, in_feat]), self.W[name_flayer+"_th2"] )  # (B*K*dl/stride, R*stride)
		theta2 = tf.reshape(theta2, [-1, self.config.rank_theta,  out_feat] ) + self.b[name_flayer+"_th2"]	
		thetal = tf.expand_dims( tf.matmul(labels, self.W[name_flayer+"_thl"]) + self.b[name_flayer+"_thl"], 2 ) # (B*K, R, 1)

		x = tf.matmul(theta1, tf.expand_dims(x_tilde1,2)) # (B*K, R, 1)
		x = tf.multiply(x, thetal) # (B*K, R, 1)
		x = tf.matmul(theta2, x, transpose_a=True)[:,:,0] # (B*K, dl)

		x = tf.reshape(x, [-1, self.config.min_nn-8, out_feat]) # (N, d, dl)
		x = tf.multiply(x, tf.expand_dims(tf.exp(-tf.div(d_labels,10)),2)) # (N, d, dl)
		x = tf.reduce_mean(x, 1) # (N, dl)
		x = tf.reshape(x,[-1, self.N, out_feat]) # (B, N, dl)
		
		if return_graph:
			return x, D
		else:
			return x 
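The batch_gather role in this layer is neighbour selection: tf.nn.top_k on the negated distance matrix yields per-node neighbour indices, which are flattened to [batch, N*d] so tf.batch_gather can fetch the matching feature rows. A minimal sketch of just that step (toy shapes assumed):

import tensorflow as tf

h = tf.random.normal([2, 5, 8])        # [batch, nodes, features] (toy)
D = tf.random.uniform([2, 5, 5])       # pairwise distances (toy)

_, top_idx = tf.nn.top_k(-D, k=3)      # [2, 5, 3]: 3 smallest distances
flat_idx = tf.reshape(top_idx, [2, -1])        # [2, 15] for batch_gather
neighbours = tf.batch_gather(h, flat_idx)      # [2, 15, 8]
neighbours = tf.reshape(neighbours, [2, 5, 3, 8])   # [batch, node, nn, feat]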
Example #12
Source File: net_conv2.py    From gcdn with MIT License
def gconv_conv_inner(self, h, name, in_feat, out_feat, stride_th1, stride_th2, compute_graph=True, return_graph=False, D=[]):

		h = tf.expand_dims(h, 0) # (1,M,dl)
		p = tf.image.extract_image_patches(h, ksizes=[1, self.config.search_window[0], self.config.search_window[1], 1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID") # (1,X,Y,dlm1*W)
		p = tf.reshape(p,[-1, self.config.search_window[0], self.config.search_window[1], in_feat]) 
		p = tf.reshape(p,[-1, self.config.searchN, in_feat]) # (N,W,dlm1)

		if compute_graph:
			D = tf.map_fn(lambda feat: self.gconv_conv_inner2(feat), tf.reshape(p,[self.config.search_window[0],self.config.search_window[1],self.config.searchN, in_feat]), parallel_iterations=16, swap_memory=False) # (B,N/B,W)
			D = tf.reshape(D,[-1, self.config.searchN]) # (N,W)

		_, top_idx = tf.nn.top_k(-D, self.config.min_nn+1) # (N, d+1)
		#top_idx2 = tf.reshape(tf.tile(tf.expand_dims(top_idx[:,0],1), [1, self.config.min_nn[i]]), [-1])
		top_idx2 = tf.tile(tf.expand_dims(top_idx[:,0],1), [1, self.config.min_nn-8]) # (N, d)
		#top_idx = tf.reshape(top_idx[:,1:],[-1]) # (N*d,)
		top_idx = top_idx[:,9:] # (N, d)

		x_tilde1 = tf.batch_gather(p, top_idx) # (N, d, dlm1)	
		x_tilde1 = tf.reshape(x_tilde1, [-1, in_feat]) # (K, dlm1)
		x_tilde2 = tf.batch_gather(p, top_idx2) # (N, d, dlm1)
		x_tilde2 = tf.reshape(x_tilde2, [-1, in_feat]) # (K, dlm1)

		labels = x_tilde1 - x_tilde2 # (K, dlm1)
		d_labels = tf.reshape( tf.reduce_sum(labels*labels, 1), [-1, self.config.min_nn-8]) # (N, d)

		name_flayer = name + "_flayer0"
		labels = tf.nn.leaky_relu(tf.matmul(labels, self.W[name_flayer]) + self.b[name_flayer])
		name_flayer = name + "_flayer1"
		labels_exp = tf.expand_dims(labels, 1) # (B*K, 1, F)
		labels1 = labels_exp+0.0
		for ss in range(1, in_feat // stride_th1):  # integer division for Python 3
			labels1 = tf.concat( [labels1, self.myroll(labels_exp, shift=(ss+1)*stride_th1, axis=2)], axis=1 ) # (B*K, dlm1/stride, dlm1)
		labels2 = labels_exp+0.0
		for ss in range(1, out_feat // stride_th2):  # integer division for Python 3
			labels2 = tf.concat( [labels2, self.myroll(labels_exp, shift=(ss+1)*stride_th2, axis=2)], axis=1 ) # (B*K, dl/stride, dlm1)
		theta1 = tf.matmul( tf.reshape(labels1, [-1, in_feat]), self.W[name_flayer+"_th1"] )  # (B*K*dlm1/stride, R*stride)
		theta1 = tf.reshape(theta1, [-1, self.config.rank_theta, in_feat] ) + self.b[name_flayer+"_th1"]
		theta2 = tf.matmul( tf.reshape(labels2, [-1, in_feat]), self.W[name_flayer+"_th2"] )  # (B*K*dl/stride, R*stride)
		theta2 = tf.reshape(theta2, [-1, self.config.rank_theta,  out_feat] ) + self.b[name_flayer+"_th2"]	
		thetal = tf.expand_dims( tf.matmul(labels, self.W[name_flayer+"_thl"]) + self.b[name_flayer+"_thl"], 2 ) # (B*K, R, 1)

		x = tf.matmul(theta1, tf.expand_dims(x_tilde1,2)) # (K, R, 1)
		x = tf.multiply(x, thetal) # (K, R, 1)
		x = tf.matmul(theta2, x, transpose_a=True)[:,:,0] # (K, dl)

		x = tf.reshape(x, [-1, self.config.min_nn-8, out_feat]) # (N, d, dl)
		x = tf.multiply(x, tf.expand_dims(tf.exp(-tf.div(d_labels,10)),2)) # (N, d, dl)
		x = tf.reduce_mean(x, 1) # (N, dl)

		x = tf.expand_dims(x,0) # (1, N, dl)

		return [x, D]