Python numpy.prod() Examples
The following are 30 code examples showing how to use numpy.prod(). They are extracted from open source projects; each example lists the project, author, source file, and license it comes from.
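As a quick refresher before the examples, numpy.prod multiplies array elements together, optionally along an axis. The minimal sketch below is plain NumPy, not taken from any of the projects listed; it shows the two usage patterns the examples rely on: collapsing a shape tuple into an element count, and reducing an array.

import numpy as np

shape = (3, 5, 2)
print(np.prod(shape))        # 30 -- total number of elements implied by a shape tuple
print(np.prod(shape[1:]))    # 10 -- product of the trailing dimensions only

a = np.arange(1, 7).reshape(2, 3)
print(np.prod(a))            # 720 -- product of every element
print(np.prod(a, axis=0))    # [ 4 10 18] -- column-wise products
print(np.prod(a, axis=1))    # [  6 120] -- row-wise products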
Example 1
Project: Att-ChemdNER Author: lingluodlut File: initializations.py License: Apache License 2.0

def get_fans(shape, dim_ordering='th'):
    if len(shape) == 2:
        fan_in = shape[0]
        fan_out = shape[1]
    elif len(shape) == 4 or len(shape) == 5:
        # assuming convolution kernels (2D or 3D).
        # TH kernel shape: (depth, input_depth, ...)
        # TF kernel shape: (..., input_depth, depth)
        if dim_ordering == 'th':
            receptive_field_size = np.prod(shape[2:])
            fan_in = shape[1] * receptive_field_size
            fan_out = shape[0] * receptive_field_size
        elif dim_ordering == 'tf':
            receptive_field_size = np.prod(shape[:2])
            fan_in = shape[-2] * receptive_field_size
            fan_out = shape[-1] * receptive_field_size
        else:
            raise ValueError('Invalid dim_ordering: ' + dim_ordering)
    else:
        # no specific assumptions
        fan_in = np.sqrt(np.prod(shape))
        fan_out = np.sqrt(np.prod(shape))
    return fan_in, fan_out
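To see what the np.prod calls above compute, here is a hypothetical standalone sketch (not part of the Att-ChemdNER project) that mirrors the Theano-ordering branch for a 3x3 convolution kernel:

import numpy as np

shape = (64, 32, 3, 3)                       # (depth, input_depth, rows, cols) in 'th' ordering
receptive_field_size = np.prod(shape[2:])    # 3 * 3 = 9
fan_in = shape[1] * receptive_field_size     # 32 * 9 = 288
fan_out = shape[0] * receptive_field_size    # 64 * 9 = 576
print(fan_in, fan_out)                       # 288 576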
Example 2
Project: disentangling_conditional_gans Author: zalandoresearch File: tfutil.py License: MIT License

def print_layers(self, title=None, hide_layers_with_no_params=False):
    if title is None:
        title = self.name
    print()
    print('%-28s%-12s%-24s%-24s' % (title, 'Params', 'OutputShape', 'WeightShape'))
    print('%-28s%-12s%-24s%-24s' % (('---',) * 4))

    total_params = 0
    for layer_name, layer_output, layer_trainables in self.list_layers():
        weights = [var for var in layer_trainables if var.name.endswith('/weight:0')]
        num_params = sum(np.prod(shape_to_list(var.shape)) for var in layer_trainables)
        total_params += num_params
        if hide_layers_with_no_params and num_params == 0:
            continue
        print('%-28s%-12s%-24s%-24s' % (
            layer_name,
            num_params if num_params else '-',
            layer_output.shape,
            weights[0].shape if len(weights) == 1 else '-'))

    print('%-28s%-12s%-24s%-24s' % (('---',) * 4))
    print('%-28s%-12s%-24s%-24s' % ('Total', total_params, '', ''))
    print()

# Construct summary ops to include histograms of all trainable parameters in TensorBoard.
Example 3
Project: disentangling_conditional_gans Author: zalandoresearch File: util_scripts.py License: MIT License

def generate_fake_images(run_id, snapshot=None, grid_size=[1,1], num_pngs=1, image_shrink=1, png_prefix=None, random_seed=1000, minibatch_size=8):
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    for png_idx in range(num_pngs):
        print('Generating png %d / %d...' % (png_idx, num_pngs))
        latents = misc.random_latents(np.prod(grid_size), Gs, random_state=random_state)
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents, labels, minibatch_size=minibatch_size, num_gpus=config.num_gpus, out_mul=127.5, out_add=127.5, out_shrink=image_shrink, out_dtype=np.uint8)
        misc.save_image_grid(images, os.path.join(result_subdir, '%s%06d.png' % (png_prefix, png_idx)), [0,255], grid_size)
    open(os.path.join(result_subdir, '_done.txt'), 'wt').close()

#----------------------------------------------------------------------------
# Generate MP4 video of random interpolations using a previously trained network.
# To run, uncomment the appropriate line in config.py and launch train.py.
Example 4
Project: dustmaps Author: gregreen File: map_base.py License: GNU General Public License v2.0

def ensure_flat_frame(f, frame=None):
    def _wrapper_func(self, coords, **kwargs):
        if (frame is not None) and (coords.frame.name != frame):
            coords_transf = coords.transform_to(frame)
        else:
            coords_transf = coords

        is_array = not coords.isscalar
        if is_array:
            orig_shape = coords.shape
            shape_flat = (np.prod(orig_shape),)
            coords_transf = coords_to_shape(coords_transf, shape_flat)
        else:
            coords_transf = coords_to_shape(coords_transf, (1,))

        out = f(self, coords_transf, **kwargs)

        if is_array:
            out.shape = orig_shape + out.shape[1:]
        else:
            out = out[0]

        return out

    return _wrapper_func
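The decorator flattens an arbitrarily shaped coordinate array to 1-D, calls the wrapped function, then restores the original shape. A minimal NumPy-only sketch of that round trip (illustrative only, not dustmaps code; the sum stands in for the wrapped query):

import numpy as np

coords = np.random.rand(4, 5, 3)             # pretend these are coordinates on a (4, 5) grid
orig_shape = coords.shape[:2]
flat = coords.reshape((np.prod(orig_shape),) + coords.shape[2:])   # (20, 3)
out = flat.sum(axis=1)                       # stand-in for the wrapped per-coordinate query
out = out.reshape(orig_shape + out.shape[1:])                      # back to (4, 5)
print(flat.shape, out.shape)                 # (20, 3) (4, 5)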
Example 5
Project: neuropythy Author: noahbenson File: core.py License: GNU Affero General Public License v3.0

def fapply(f, x, tz=False):
    '''
    fapply(f,x) yields the result of applying f either to x, if x is a normal value or array,
      or to x.data if x is a sparse matrix. Does not modify x (unless f modifies x).

    The optional argument tz (default: False) may be set to True to specify that, if x is a
      sparse matrix that contains at least 1 element that is a sparse-zero, then f(0) should
      replace all the sparse-zeros in x (unless f(0) == 0).
    '''
    if sps.issparse(x):
        y = x.copy()
        y.data = f(x.data)
        if tz and y.getnnz() < np.prod(y.shape):
            z = f(np.array(0))
            if z != 0:
                y = y.toarray()
                y[y == 0] = z
        return y
    else:
        return f(x)
Example 6
Project: dynamic-training-with-apache-mxnet-on-aws Author: awslabs File: metric.py License: Apache License 2.0

def update(self, labels, preds):
    """Updates the internal evaluation result.

    Parameters
    ----------
    labels : list of `NDArray`
        The labels of the data.

    preds : list of `NDArray`
        Predicted values.
    """
    labels, preds = check_label_shapes(labels, preds, True)

    for label, pred in zip(labels, preds):
        label = label.asnumpy()
        pred = pred.asnumpy()

        if len(label.shape) == 1:
            label = label.reshape(label.shape[0], 1)
        if len(pred.shape) == 1:
            pred = pred.reshape(pred.shape[0], 1)

        self.sum_metric += numpy.abs(label - pred).mean()
        self.num_inst += 1  # numpy.prod(label.shape)
Example 7
Project: dynamic-training-with-apache-mxnet-on-aws Author: awslabs File: metric.py License: Apache License 2.0

def update(self, labels, preds):
    """Updates the internal evaluation result.

    Parameters
    ----------
    labels : list of `NDArray`
        The labels of the data.

    preds : list of `NDArray`
        Predicted values.
    """
    labels, preds = check_label_shapes(labels, preds, True)

    for label, pred in zip(labels, preds):
        label = label.asnumpy()
        pred = pred.asnumpy()

        if len(label.shape) == 1:
            label = label.reshape(label.shape[0], 1)
        if len(pred.shape) == 1:
            pred = pred.reshape(pred.shape[0], 1)

        self.sum_metric += ((label - pred)**2.0).mean()
        self.num_inst += 1  # numpy.prod(label.shape)
Example 8
Project: dynamic-training-with-apache-mxnet-on-aws Author: awslabs File: initializer.py License: Apache License 2.0

def _init_weight(self, name, arr):
    shape = arr.shape
    hw_scale = 1.
    if len(shape) < 2:
        raise ValueError('Xavier initializer cannot be applied to vector {0}. It requires at'
                         ' least 2D.'.format(name))
    if len(shape) > 2:
        hw_scale = np.prod(shape[2:])
    fan_in, fan_out = shape[1] * hw_scale, shape[0] * hw_scale
    factor = 1.
    if self.factor_type == "avg":
        factor = (fan_in + fan_out) / 2.0
    elif self.factor_type == "in":
        factor = fan_in
    elif self.factor_type == "out":
        factor = fan_out
    else:
        raise ValueError("Incorrect factor type")
    scale = np.sqrt(self.magnitude / factor)
    if self.rnd_type == "uniform":
        random.uniform(-scale, scale, out=arr)
    elif self.rnd_type == "gaussian":
        random.normal(0, scale, out=arr)
    else:
        raise ValueError("Unknown random type")
Example 9
Project: dynamic-training-with-apache-mxnet-on-aws Author: awslabs File: ndarray.py License: Apache License 2.0

def size(self):
    """Number of elements in the array.

    Equivalent to the product of the array's dimensions.

    Examples
    --------
    >>> import numpy as np
    >>> x = mx.nd.zeros((3, 5, 2))
    >>> x.size
    30
    >>> np.prod(x.shape)
    30
    """
    size = 1
    for i in self.shape:
        size *= i
    return size
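The property multiplies the dimensions in a Python loop; with NumPy available, np.prod(x.shape) gives the same count, as this small sketch (plain NumPy, not MXNet) shows:

import numpy as np

x = np.zeros((3, 5, 2))
size = 1
for dim in x.shape:
    size *= dim
print(size, np.prod(x.shape), x.size)   # 30 30 30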
Example 10
Project: dynamic-training-with-apache-mxnet-on-aws Author: awslabs File: parameter.py License: Apache License 2.0

def _finish_deferred_init(self):
    """Finishes deferred initialization."""
    if not self._deferred_init:
        return
    init, ctx, default_init, data = self._deferred_init
    self._deferred_init = ()
    assert self.shape is not None and np.prod(self.shape) > 0, \
        "Cannot initialize Parameter '%s' because it has " \
        "invalid shape: %s. Please specify in_units, " \
        "in_channels, etc for `Block`s."%(
            self.name, str(self.shape))

    with autograd.pause():
        if data is None:
            data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
                                 ctx=context.cpu(), stype=self._stype)
            initializer.create(default_init)(
                initializer.InitDesc(self.name, {'__init__': init}), data)

        self._init_impl(data, ctx)
Example 11
Project: dynamic-training-with-apache-mxnet-on-aws Author: awslabs File: test_sparse_ndarray.py License: Apache License 2.0

def test_sparse_nd_storage_fallback():
    def check_output_fallback(shape):
        ones = mx.nd.ones(shape)
        out = mx.nd.zeros(shape=shape, stype='csr')
        mx.nd.broadcast_add(ones, ones * 2, out=out)
        assert(np.sum(out.asnumpy() - 3) == 0)

    def check_input_fallback(shape):
        ones = mx.nd.ones(shape)
        out = mx.nd.broadcast_add(ones.tostype('csr'), ones.tostype('row_sparse'))
        assert(np.sum(out.asnumpy() - 2) == 0)

    def check_fallback_with_temp_resource(shape):
        ones = mx.nd.ones(shape)
        out = mx.nd.sum(ones)
        assert(out.asscalar() == np.prod(shape))

    shape = rand_shape_2d()
    check_output_fallback(shape)
    check_input_fallback(shape)
    check_fallback_with_temp_resource(shape)
Example 12
Project: soccer-matlab Author: utra-robosoccer File: count_weights.py License: BSD 2-Clause "Simplified" License

def count_weights(scope=None, exclude=None, graph=None):
    """Count learnable parameters.

    Args:
      scope: Restrict the count to a variable scope.
      exclude: Regex to match variable names to exclude.
      graph: Operate on a graph other than the current default graph.

    Returns:
      Number of learnable parameters as integer.
    """
    if scope:
        scope = scope if scope.endswith('/') else scope + '/'
    graph = graph or tf.get_default_graph()
    vars_ = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    if scope:
        vars_ = [var for var in vars_ if var.name.startswith(scope)]
    if exclude:
        exclude = re.compile(exclude)
        vars_ = [var for var in vars_ if not exclude.match(var.name)]
    shapes = [var.get_shape().as_list() for var in vars_]
    return int(sum(np.prod(shape) for shape in shapes))
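The core of the count is the last two lines: one np.prod per variable shape, summed. A hypothetical shape list (no TensorFlow graph needed) illustrates the arithmetic:

import numpy as np

shapes = [[3, 3, 64, 128], [128], [128, 10], [10]]   # conv kernel, bias, dense kernel, bias
total = int(sum(np.prod(shape) for shape in shapes))
print(total)   # 3*3*64*128 + 128 + 128*10 + 10 = 75146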
Example 13
Project: soccer-matlab Author: utra-robosoccer File: count_weights.py License: BSD 2-Clause "Simplified" License

def count_weights(scope=None, exclude=None, graph=None):
    """Count learnable parameters.

    Args:
      scope: Restrict the count to a variable scope.
      exclude: Regex to match variable names to exclude.
      graph: Operate on a graph other than the current default graph.

    Returns:
      Number of learnable parameters as integer.
    """
    if scope:
        scope = scope if scope.endswith('/') else scope + '/'
    graph = graph or tf.get_default_graph()
    vars_ = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    if scope:
        vars_ = [var for var in vars_ if var.name.startswith(scope)]
    if exclude:
        exclude = re.compile(exclude)
        vars_ = [var for var in vars_ if not exclude.match(var.name)]
    shapes = [var.get_shape().as_list() for var in vars_]
    return int(sum(np.prod(shape) for shape in shapes))
Example 14
Project: Recipes Author: Lasagne File: densenet.py License: MIT License

def sample(self, shape):
    import numpy as np
    rng = lasagne.random.get_rng()
    if len(shape) >= 4:
        # convolutions use Gaussians with stddev of sqrt(2/fan_out), see
        # https://github.com/liuzhuang13/DenseNet/blob/cbb6bff/densenet.lua#L85-L86
        # and https://github.com/facebook/fb.resnet.torch/issues/106
        fan_out = shape[0] * np.prod(shape[2:])
        W = rng.normal(0, np.sqrt(2. / fan_out), size=shape)
    elif len(shape) == 2:
        # the dense layer uses Uniform of range sqrt(1/fan_in), see
        # https://github.com/torch/nn/blob/651103f/Linear.lua#L21-L43
        fan_in = shape[0]
        W = rng.uniform(-np.sqrt(1. / fan_in), np.sqrt(1. / fan_in), size=shape)
    return lasagne.utils.floatX(W)
Example 15
Project: lirpg Author: Hwhitetooth File: utils.py License: MIT License

def ortho_init(scale=1.0):
    def _ortho_init(shape, dtype, partition_info=None):
        # lasagne ortho init for tf
        shape = tuple(shape)
        if len(shape) == 2:
            flat_shape = shape
        elif len(shape) == 4:  # assumes NHWC
            flat_shape = (np.prod(shape[:-1]), shape[-1])
        else:
            raise NotImplementedError
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v  # pick the one with the correct shape
        q = q.reshape(shape)
        return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
    return _ortho_init
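For a 4-D NHWC kernel, the initializer collapses everything except the output channels into one axis before the SVD. A short, self-contained sketch of that flattening and the orthogonality it buys (illustrative only, not project code):

import numpy as np

shape = (3, 3, 16, 32)                          # (kh, kw, in_channels, out_channels)
flat_shape = (np.prod(shape[:-1]), shape[-1])   # (144, 32)
a = np.random.normal(0.0, 1.0, flat_shape)
u, _, v = np.linalg.svd(a, full_matrices=False)
q = u if u.shape == flat_shape else v
print(q.shape, np.allclose(q.T @ q, np.eye(shape[-1])))   # (144, 32) True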
Example 16
Project: torch-toolbox Author: PistonY File: summary.py License: BSD 3-Clause "New" or "Revised" License

def _cac_conv(layer, input, output):
    # bs, ic, ih, iw = input[0].shape
    oh, ow = output.shape[-2:]
    kh, kw = layer.kernel_size
    ic, oc = layer.in_channels, layer.out_channels
    g = layer.groups
    tb_params = 0
    ntb__params = 0
    flops = 0
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        params = np.prod(layer.weight.shape)
        t, n = _cac_grad_params(params, layer.weight)
        tb_params += t
        ntb__params += n
        flops += (2 * ic * kh * kw - 1) * oh * ow * (oc // g)
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        params = np.prod(layer.bias.shape)
        t, n = _cac_grad_params(params, layer.bias)
        tb_params += t
        ntb__params += n
        flops += oh * ow * (oc // g)
    return tb_params, ntb__params, flops
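The parameter count for the convolution is just np.prod over the weight shape, and the FLOP estimate follows the (2*ic*kh*kw - 1) multiply-add formula above. Plugging in hypothetical layer sizes (no PyTorch required):

import numpy as np

ic, oc, kh, kw, g = 64, 128, 3, 3, 1        # in/out channels, kernel size, groups (assumed values)
oh, ow = 56, 56                             # output spatial size (assumed)
params = np.prod((oc, ic // g, kh, kw))     # 73728 weights
flops = (2 * ic * kh * kw - 1) * oh * ow * (oc // g)
print(int(params), flops)                   # 73728 462020608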
Example 17
Project: torch-toolbox Author: PistonY File: summary.py License: BSD 3-Clause "New" or "Revised" License

def _cac_xx_norm(layer, input, output):
    tb_params = 0
    ntb__params = 0
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        params = np.prod(layer.weight.shape)
        t, n = _cac_grad_params(params, layer.weight)
        tb_params += t
        ntb__params += n
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        params = np.prod(layer.bias.shape)
        t, n = _cac_grad_params(params, layer.bias)
        tb_params += t
        ntb__params += n
    if hasattr(layer, 'running_mean') and hasattr(layer.running_mean, 'shape'):
        params = np.prod(layer.running_mean.shape)
        ntb__params += params
    if hasattr(layer, 'running_var') and hasattr(layer.running_var, 'shape'):
        params = np.prod(layer.running_var.shape)
        ntb__params += params
    in_shape = input[0]
    flops = np.prod(in_shape.shape)
    if layer.affine:
        flops *= 2
    return tb_params, ntb__params, flops
Example 18
Project: torch-toolbox Author: PistonY File: summary.py License: BSD 3-Clause "New" or "Revised" License

def _cac_linear(layer, input, output):
    ic, oc = layer.in_features, layer.out_features
    tb_params = 0
    ntb__params = 0
    flops = 0
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        params = np.prod(layer.weight.shape)
        t, n = _cac_grad_params(params, layer.weight)
        tb_params += t
        ntb__params += n
        flops += (2 * ic - 1) * oc
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        params = np.prod(layer.bias.shape)
        t, n = _cac_grad_params(params, layer.bias)
        tb_params += t
        ntb__params += n
        flops += oc
    return tb_params, ntb__params, flops
Example 19
Project: A2C Author: lnpalmer File: models.py License: MIT License

def ortho_weights(shape, scale=1.):
    """ PyTorch port of ortho_init from baselines.a2c.utils """
    shape = tuple(shape)

    if len(shape) == 2:
        flat_shape = shape[1], shape[0]
    elif len(shape) == 4:
        flat_shape = (np.prod(shape[1:]), shape[0])
    else:
        raise NotImplementedError

    a = np.random.normal(0., 1., flat_shape)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    q = u if u.shape == flat_shape else v
    q = q.transpose().copy().reshape(shape)

    if len(shape) == 2:
        return torch.from_numpy((scale * q).astype(np.float32))
    if len(shape) == 4:
        return torch.from_numpy((scale * q[:, :shape[1], :shape[2]]).astype(np.float32))
Example 20
Project: HardRLWithYoutube Author: MaxSobolMark File: utils.py License: MIT License

def ortho_init(scale=1.0):
    def _ortho_init(shape, dtype, partition_info=None):
        # lasagne ortho init for tf
        shape = tuple(shape)
        if len(shape) == 2:
            flat_shape = shape
        elif len(shape) == 4:  # assumes NHWC
            flat_shape = (np.prod(shape[:-1]), shape[-1])
        else:
            raise NotImplementedError
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v  # pick the one with the correct shape
        q = q.reshape(shape)
        return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
    return _ortho_init
Example 21
Project: HardRLWithYoutube Author: MaxSobolMark File: mpi_adam_optimizer.py License: MIT License

def compute_gradients(self, loss, var_list, **kwargs):
    grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs)
    grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
    flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0)
    shapes = [v.shape.as_list() for g, v in grads_and_vars]
    sizes = [int(np.prod(s)) for s in shapes]

    num_tasks = self.comm.Get_size()
    buf = np.zeros(sum(sizes), np.float32)

    def _collect_grads(flat_grad):
        self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
        np.divide(buf, float(num_tasks), out=buf)
        return buf

    avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32)
    avg_flat_grad.set_shape(flat_grad.shape)
    avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
    avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                          for g, (_, v) in zip(avg_grads, grads_and_vars)]

    return avg_grads_and_vars
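The optimizer flattens every gradient into one vector, averages it across workers, then splits it back using the per-variable sizes computed with np.prod. The NumPy-only sketch below (hypothetical shapes, no TensorFlow or MPI) shows the bookkeeping:

import numpy as np

shapes = [[2, 3], [3], [3, 4]]
sizes = [int(np.prod(s)) for s in shapes]            # [6, 3, 12]
flat = np.concatenate([np.random.rand(*s).ravel() for s in shapes])
pieces = np.split(flat, np.cumsum(sizes)[:-1])       # undo the flattening
print(sizes, [p.reshape(s).shape for p, s in zip(pieces, shapes)])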
Example 22
Project: Att-ChemdNER Author: lingluodlut File: initializations.py License: Apache License 2.0

def orthogonal(shape, scale=1.1, name=None):
    '''
    From Lasagne. Reference: Saxe et al., http://arxiv.org/abs/1312.6120
    '''
    flat_shape = (shape[0], np.prod(shape[1:]))
    a = np.random.normal(0.0, 1.0, flat_shape)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    # pick the one with the correct shape
    q = u if u.shape == flat_shape else v
    q = q.reshape(shape)
    return K.variable(scale * q[:shape[0], :shape[1]], name=name)
Example 23
Project: Att-ChemdNER Author: lingluodlut File: theano_backend.py License: Apache License 2.0

def count_params(x):
    '''Returns the number of scalars in a tensor.

    Return: numpy integer.
    '''
    return np.prod(x.shape.eval())
Example 24
Project: Att-ChemdNER Author: lingluodlut File: theano_backend.py License: Apache License 2.0

def prod(x, axis=None, keepdims=False):
    '''Multiply the values in a tensor, alongside the specified axis.
    '''
    return T.prod(x, axis=axis, keepdims=keepdims)
Example 25
Project: Att-ChemdNER Author: lingluodlut File: theano_backend.py License: Apache License 2.0

def batch_flatten(x):
    '''Turn a n-D tensor into a 2D tensor where the first dimension is conserved.
    '''
    # TODO: `keras_shape` inference.
    x = T.reshape(x, (x.shape[0], T.prod(x.shape) // x.shape[0]))
    return x
Example 26
Project: Traffic_sign_detection_YOLO Author: AmeyaWagh File: layer.py License: MIT License

def __init__(self, *args):
    self._signature = list(args)
    self.type = list(args)[0]
    self.number = list(args)[1]

    self.w = dict()       # weights
    self.h = dict()       # placeholders
    self.wshape = dict()  # weight shape
    self.wsize = dict()   # weight size

    self.setup(*args[2:])  # set attr up
    self.present()
    for var in self.wshape:
        shp = self.wshape[var]
        size = np.prod(shp)
        self.wsize[var] = size
Example 27
Project: disentangling_conditional_gans Author: zalandoresearch File: networks.py License: MIT License

def get_weight(shape, gain=np.sqrt(2), use_wscale=False, fan_in=None):
    if fan_in is None:
        fan_in = np.prod(shape[:-1])
    std = gain / np.sqrt(fan_in)  # He init
    if use_wscale:
        wscale = tf.constant(np.float32(std), name='wscale')
        return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale
    else:
        return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std))

#----------------------------------------------------------------------------
# Fully-connected layer.
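When fan_in is not given, it is the product of all weight dimensions except the last (the output feature maps), and the He-style standard deviation follows from it. A quick check of the numbers with plain NumPy (illustrative shape, no TensorFlow):

import numpy as np

shape = (3, 3, 64, 128)               # (kh, kw, fmaps_in, fmaps_out), assumed for illustration
fan_in = np.prod(shape[:-1])          # 3 * 3 * 64 = 576
std = np.sqrt(2) / np.sqrt(fan_in)    # He init: gain / sqrt(fan_in)
print(fan_in, round(float(std), 4))   # 576 0.0589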
Example 28
Project: disentangling_conditional_gans Author: zalandoresearch File: networks.py License: MIT License

def dense(x, fmaps, gain=np.sqrt(2), use_wscale=False):
    if len(x.shape) > 2:
        x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])])
    w = get_weight([x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale)
    w = tf.cast(w, x.dtype)
    return tf.matmul(x, w)

#----------------------------------------------------------------------------
# Convolutional layer.
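Before the matmul, any tensor with more than two dimensions is flattened so that each sample becomes one row of length np.prod(shape[1:]). The equivalent NumPy reshape (illustrative shapes, outside TensorFlow):

import numpy as np

x = np.random.rand(8, 4, 4, 256)                  # e.g. a minibatch of feature maps
flat = x.reshape(-1, int(np.prod(x.shape[1:])))   # (8, 4096)
print(flat.shape)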
Example 29
Project: disentangling_conditional_gans Author: zalandoresearch File: util_scripts.py License: MIT License

def generate_interpolation_video(run_id, snapshot=None, grid_size=[1,1], image_shrink=1, image_zoom=1, duration_sec=60.0, smoothing_sec=1.0, mp4=None, mp4_fps=30, mp4_codec='libx265', mp4_bitrate='16M', random_seed=1000, minibatch_size=8):
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if mp4 is None:
        mp4 = misc.get_id_string_for_network_pkl(network_pkl) + '-lerp.mp4'
    num_frames = int(np.rint(duration_sec * mp4_fps))
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    print('Generating latent vectors...')
    shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:]  # [frame, image, channel, component]
    all_latents = random_state.randn(*shape).astype(np.float32)
    all_latents = scipy.ndimage.gaussian_filter(all_latents, [smoothing_sec * mp4_fps] + [0] * len(Gs.input_shape), mode='wrap')
    all_latents /= np.sqrt(np.mean(np.square(all_latents)))

    # Frame generation func for moviepy.
    def make_frame(t):
        frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
        latents = all_latents[frame_idx]
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents, labels, minibatch_size=minibatch_size, num_gpus=config.num_gpus, out_mul=127.5, out_add=127.5, out_shrink=image_shrink, out_dtype=np.uint8)
        grid = misc.create_image_grid(images, grid_size).transpose(1, 2, 0)  # HWC
        if image_zoom > 1:
            grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1], order=0)
        if grid.shape[2] == 1:
            grid = grid.repeat(3, 2)  # grayscale => RGB
        return grid

    # Generate video.
    import moviepy.editor  # pip install moviepy
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    moviepy.editor.VideoClip(make_frame, duration=duration_sec).write_videofile(os.path.join(result_subdir, mp4), fps=mp4_fps, codec='libx264', bitrate=mp4_bitrate)
    open(os.path.join(result_subdir, '_done.txt'), 'wt').close()

#----------------------------------------------------------------------------
# Generate MP4 video of training progress for a previous training run.
# To run, uncomment the appropriate line in config.py and launch train.py.
Example 30
Project: dustmaps Author: gregreen File: bayestar.py License: GNU General Public License v2.0

def get_query_size(self, coords, mode='random_sample', return_flags=False, pct=None):
    # Check that the query mode is supported
    self._raise_on_mode(mode)

    # Validate percentile specification
    pct, scalar_pct = self._interpret_percentile(mode, pct)

    n_coords = np.prod(coords.shape, dtype=int)
    if mode == 'samples':
        n_samples = self._n_samples
    elif mode == 'percentile':
        if scalar_pct:
            n_samples = 1
        else:
            n_samples = len(pct)
    else:
        n_samples = 1

    if hasattr(coords.distance, 'kpc'):
        n_dists = 1
    else:
        n_dists = self._n_distances

    return n_coords * n_samples * n_dists