Python numpy.isreal() Examples

The following are 30 code examples showing how to use numpy.isreal(). These examples are extracted from open-source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You can check out related API usage in the sidebar.

You may also want to check out all available functions/classes of the module numpy, or try the search function.

Example 1
Project: pyberny   Author: jhrmnn   File: Math.py    License: Mozilla Public License 2.0 6 votes vote down vote up
def fit_cubic(y0, y1, g0, g1):
    """Fit cubic polynomial to function values and derivatives at x = 0, 1.

    Returns position and function value of minimum if fit succeeds, and
    ``(None, None)`` otherwise. Fit does not succeed if

    1. polynomial doesn't have a minimum (e.g. stationary points are
       complex), or
    2. maximum lies in (0, 1) and is closer to 0.5 than the minimum.

    If the cubic coefficient vanishes, the fit degenerates to a parabola (or
    lower); an upward-opening parabola still yields its vertex as minimum.
    """
    a = 2 * (y0 - y1) + g0 + g1
    b = -3 * (y0 - y1) - 2 * g0 - g1
    p = np.array([a, b, g0, y0])
    # np.roots drops leading zero coefficients, so a degenerate fit (a == 0)
    # yields fewer than two stationary points.
    r = np.roots(np.polyder(p))
    if not np.isreal(r).all():
        return None, None
    r = sorted(x.real for x in r)
    if len(r) == 1:
        # Parabola b*x**2 + g0*x + y0: its vertex is a minimum iff b > 0.
        if b > 0:
            return r[0], np.polyval(p, r[0])
        return None, None
    if len(r) != 2:
        # Linear or constant polynomial: no extrema at all.
        return None, None
    # For a > 0 the smaller stationary point is the local maximum.
    if p[0] > 0:
        maxim, minim = r
    else:
        minim, maxim = r
    if 0 < maxim < 1 and abs(minim - 0.5) > abs(maxim - 0.5):
        return None, None
    return minim, np.polyval(p, minim)
Example 2
Project: lambda-packs   Author: ryfeus   File: ltisys.py    License: MIT License 6 votes vote down vote up
def _order_complex_poles(poles):
    """
    Check we have complex conjugates pairs and reorder P according to YT, ie
    real_poles, complex_i, conjugate complex_i, ....
    The lexicographic sort on the complex poles is added to help the user to
    compare sets of poles.
    """
    ordered_poles = np.sort(poles[np.isreal(poles)])
    im_poles = []
    for p in np.sort(poles[np.imag(poles) < 0]):
        if np.conj(p) in poles:
            im_poles.extend((p, np.conj(p)))

    ordered_poles = np.hstack((ordered_poles, im_poles))

    if poles.shape[0] != len(ordered_poles):
        raise ValueError("Complex poles must come with their conjugates")
    return ordered_poles 
Example 3
Project: safelife   Author: PartnershipOnAI   File: safelife_logger.py    License: Apache License 2.0 6 votes vote down vote up
def log_scalars(self, data, global_step=None, tag=None):
        """
        Log scalar values to tensorboard.

        Only entries whose value is a real scalar are written; arrays, lists,
        complex numbers and other non-scalar values are skipped.

        Parameters
        ----------
        data : dict
            Dictionary of key/value pairs to log to tensorboard.
        global_step : int or None
            Step to record the values at. Defaults to
            ``self.cumulative_stats['training_steps']``.
        tag : str or None
            Optional prefix; keys are logged as ``"<tag>/<key>"``.
        """
        self.init_logdir()  # init if needed

        if not self.summary_writer:
            return
        tag = "" if tag is None else tag + '/'
        if global_step is None:
            global_step = self.cumulative_stats['training_steps']
        for key, val in data.items():
            # Test isscalar first: np.isreal on an array returns an array, whose
            # truth value is ambiguous and would raise ValueError.
            if np.isscalar(val) and np.isreal(val):
                self.summary_writer.add_scalar(tag + key, val, global_step)
        self.summary_writer.flush()
Example 4
Project: sonata   Author: AllenInstitute   File: test_types.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_edge_types(net):
    """Check the edge-types table of the ``net`` fixture and its row caching."""
    edge_types = net.edges.edge_types_table
    assert edge_types is not None
    assert len(edge_types.edge_type_ids) == 11
    assert len(edge_types.columns) == 5
    assert 'template' in edge_types.columns
    assert 'delay' in edge_types.columns
    assert edge_types.to_dataframe().shape == (11, 5)
    # np.isreal() applied to a dtype object does not inspect the dtype and is
    # always truthy, so check that 'delay' has a real numeric dtype instead.
    delay_dtype = edge_types.column('delay').dtype
    assert np.issubdtype(delay_dtype, np.number)
    assert not np.issubdtype(delay_dtype, np.complexfloating)

    assert 1 in edge_types
    edge_type1 = edge_types[1]
    assert edge_type1['dynamics_params'] == 'instanteneousInh.json'
    assert edge_type1['delay'] == 2.0

    # check that row is being cached.
    mem_id = id(edge_type1)
    del edge_type1
    assert mem_id == id(edge_types[1])
Example 5
Project: GraphicDesignPatternByPython   Author: Relph1119   File: ltisys.py    License: MIT License 6 votes vote down vote up
def _order_complex_poles(poles):
    """
    Check we have complex conjugates pairs and reorder P according to YT, ie
    real_poles, complex_i, conjugate complex_i, ....
    The lexicographic sort on the complex poles is added to help the user to
    compare sets of poles.
    """
    ordered_poles = np.sort(poles[np.isreal(poles)])
    im_poles = []
    for p in np.sort(poles[np.imag(poles) < 0]):
        if np.conj(p) in poles:
            im_poles.extend((p, np.conj(p)))

    ordered_poles = np.hstack((ordered_poles, im_poles))

    if poles.shape[0] != len(ordered_poles):
        raise ValueError("Complex poles must come with their conjugates")
    return ordered_poles 
Example 6
def get_fixed_point(I=0., eps=0.1, a=2.0):
    """Compute the fixed point of the FitzHugh-Nagumo model for input current I.

    The v-coordinate of the fixed point is the real root of the cubic

        v**3 + v + (a - I) = 0

    and the w-coordinate follows as w = 2 * v + a.

    Args:
        I: Constant input [mV]
        eps: Inverse time constant of the recovery variable w [1/ms];
            not used by this computation.
        a: Offset of the w-nullcline [mV]

    Returns:
        tuple: (v_fp, w_fp) fixed point of the equations
    """
    roots = np.roots([1, 0, 1, a - I])
    # v**3 + v + c is strictly increasing (derivative 3*v**2 + 1 > 0), so
    # exactly one root is real; keep that one.
    real_roots = roots[np.isreal(roots)]
    v_fp = np.real(real_roots)[0]
    return (v_fp, 2. * v_fp + a)
Example 7
Project: PyBloqs   Author: man-group   File: table_formatters.py    License: GNU Lesser General Public License v2.1 6 votes vote down vote up
def _modify_dataframe(self, df):
        """Add row to dataframe, containing numbers aggregated with self.operator.

        The aggregated columns are ``self.total_columns``, or all columns when
        that is an empty list. Cells that are not real numbers are masked out
        before ``self.operator`` is applied, and missing results become 0.
        Every other column gets an empty string in the new row, which is
        labelled ``self.row_name``. With ``OP_NONE`` the row is all empty
        strings.
        """
        if self.total_columns == []:
            columns = df.columns
        else:
            columns = self.total_columns
        if self.operator is not OP_NONE:
            df_calculated = df[columns]
            # DataFrame.applymap was renamed to DataFrame.map in pandas 2.1.
            if hasattr(pd.DataFrame, 'map'):
                is_real = df_calculated.map(np.isreal)
            else:
                is_real = df_calculated.applymap(np.isreal)
            last_row = self.operator(df_calculated[is_real])
            last_row = last_row.fillna(0.)
            # Series.append was removed in pandas 2.0; concat is the replacement.
            filler = pd.Series('', index=df.columns.difference(last_row.index))
            last_row = pd.concat([last_row, filler])
        else:
            last_row = pd.Series('', index=df.columns)
        last_row.name = self.row_name
        # Appending kills index name, save now and restore after appending
        index_name = df.index.name
        # DataFrame.append was removed in pandas 2.0; append the row as a
        # one-row frame (labelled by last_row.name) via concat instead.
        df = pd.concat([df, last_row.to_frame().T.infer_objects()])
        df.index.name = index_name
        return df
Example 8
def validate_gibbs_parameters(alpha1, alpha2, beta, restarts,
                              draws_per_restart, burnin, delay):
    '''Return `True` if params numerically acceptable. See `gibbs` for docs.

    ``alpha1``, ``alpha2`` and ``beta`` must be real and non-negative;
    ``restarts``, ``draws_per_restart``, ``burnin`` and ``delay`` must be
    positive integers (Python ``int`` or any NumPy integer type). Returns
    ``False`` otherwise.
    '''
    real_vals = [alpha1, alpha2, beta]
    int_vals = [restarts, draws_per_restart, burnin, delay]
    # Check everything is real.
    if not all(np.isreal(val) for val in real_vals + int_vals):
        return False  # Failed to be all numeric values.
    # Check that integer values are some type of int.
    int_check = all(isinstance(val, (int, np.integer)) for val in int_vals)
    # All integer values must be > 0.
    pos_int = all(val > 0 for val in int_vals)
    # All real values must be non-negative.
    non_neg = all(val >= 0 for val in real_vals)
    return int_check and pos_int and non_neg
Example 9
Project: aitom   Author: xulabs   File: ang_loc.py    License: GNU General Public License v3.0 6 votes vote down vote up
def rotation_matrix_zyz_normalized_angle(rm):
    """Convert a 3x3 rotation matrix into ZYZ angles ``[phi, theta, psi_t]``.

    ``theta`` is recovered from ``rm[2, 2]`` (clamped to [-1, 1]). In the
    degenerate case ``|cos(theta)| >= 1 - 1e-10``, theta and phi are set to 0
    and only ``psi_t`` is recovered.

    Returns:
        float ndarray of shape (3,).

    Raises:
        AssertionError: if ``rm`` has non-real entries or is not 3x3.
    """
    assert all(N.isreal(rm.flatten()))
    assert rm.shape == (3, 3)

    cos_theta = rm[2, 2]
    if N.abs(cos_theta) > 1.0:
        # warning(sprintf('cos_theta %g', cos_theta));
        # Clamp round-off overshoot so that sqrt(1 - cos^2) stays real.
        cos_theta = N.sign(cos_theta)

    theta = N.arctan2(N.sqrt(1.0 - (cos_theta * cos_theta)), cos_theta)

    # Use a small epsilon to increase numerical stability when abs(cos_theta)
    # is very close to 1.
    if N.abs(cos_theta) < (1.0 - 1e-10):
        phi = N.arctan2(rm[2, 1], rm[2, 0])
        psi_t = N.arctan2(rm[1, 2], -rm[0, 2])
    else:
        # NOTE(review): theta is also forced to 0 when cos_theta is close to
        # -1 (where arctan2 above gives pi) -- confirm this is intended.
        theta = 0.0
        phi = 0.0
        psi_t = N.arctan2(rm[0, 1], rm[1, 1])

    # N.float was merely an alias of the builtin float and was removed in
    # NumPy 1.24, so use float directly.
    ang = N.array([phi, theta, psi_t], dtype=float)

    return ang
Example 10
Project: ray   Author: ray-project   File: vector_env.py    License: Apache License 2.0 6 votes vote down vote up
def vector_step(self, actions):
        """Step each of the ``self.num_envs`` sub-environments and batch results.

        Returns:
            A tuple ``(obs_batch, rew_batch, done_batch, info_batch)`` of lists,
            one entry per sub-environment in index order.

        Raises:
            ValueError: if an environment returns a reward that is not a
                finite real scalar, or an info object whose type is not
                exactly ``dict``.
        """
        batches = ([], [], [], [])
        for env_idx in range(self.num_envs):
            action = actions[env_idx]
            obs, r, done, info = self.envs[env_idx].step(action)
            reward_ok = np.isscalar(r) and np.isreal(r) and np.isfinite(r)
            if not reward_ok:
                raise ValueError(
                    "Reward should be finite scalar, got {} ({}). "
                    "Actions={}.".format(r, type(r), action))
            if type(info) is not dict:
                raise ValueError("Info should be a dict, got {} ({})".format(
                    info, type(info)))
            for batch, item in zip(batches, (obs, r, done, info)):
                batch.append(item)
        return batches
Example 11
Project: ray   Author: ray-project   File: vector_env.py    License: Apache License 2.0 6 votes vote down vote up
def vector_step(self, actions):
        """Advance every sub-environment one step with its matching action.

        The outputs of ``self.envs[idx].step(actions[idx])`` for
        ``idx in range(self.num_envs)`` are collected column-wise into four
        lists (observations, rewards, dones, infos) and returned as a tuple.

        Raises:
            ValueError: on a reward that is not a finite real scalar, or on an
                info object whose type is not exactly ``dict``.
        """
        observations, rewards, dones, infos = [], [], [], []
        for idx in range(self.num_envs):
            result = self.envs[idx].step(actions[idx])
            obs, r, done, info = result
            finite_scalar = np.isscalar(r) and np.isreal(r) and np.isfinite(r)
            if not finite_scalar:
                raise ValueError(
                    "Reward should be finite scalar, got {} ({}). "
                    "Actions={}.".format(r, type(r), actions[idx]))
            if type(info) is not dict:
                raise ValueError("Info should be a dict, got {} ({})".format(
                    info, type(info)))
            observations.append(obs)
            rewards.append(r)
            dones.append(done)
            infos.append(info)
        return observations, rewards, dones, infos
Example 12
Project: pytorch-complex-tensor   Author: williamFalcon   File: complex_tensor.py    License: MIT License 6 votes vote down vote up
def __truediv__(self, other):
        """Divide this complex tensor by a real scalar.

        Only divisors judged real by ``np.isreal`` are supported. Dividing by
        a ``torch.Tensor``, a ``ComplexTensor`` or a complex scalar raises
        ``NotImplementedError``.
        """
        real_part = self.real.clone()
        imag_part = self.imag.clone()

        # Tensor divisors (real or complex) are not supported yet.
        if isinstance(other, torch.Tensor) or type(other) is ComplexTensor:
            raise NotImplementedError

        # Complex scalars are not supported yet.
        if not np.isreal(other):
            raise NotImplementedError

        # (a + bi) / r == a / r + (b / r) i for a real scalar r.
        return self.__graph_copy__(real_part / other, imag_part / other)
Example 13
Project: Splunking-Crime   Author: nccgroup   File: ltisys.py    License: GNU Affero General Public License v3.0 6 votes vote down vote up
def _order_complex_poles(poles):
    """
    Check we have complex conjugates pairs and reorder P according to YT, ie
    real_poles, complex_i, conjugate complex_i, ....
    The lexicographic sort on the complex poles is added to help the user to
    compare sets of poles.
    """
    ordered_poles = np.sort(poles[np.isreal(poles)])
    im_poles = []
    for p in np.sort(poles[np.imag(poles) < 0]):
        if np.conj(p) in poles:
            im_poles.extend((p, np.conj(p)))

    ordered_poles = np.hstack((ordered_poles, im_poles))

    if poles.shape[0] != len(ordered_poles):
        raise ValueError("Complex poles must come with their conjugates")
    return ordered_poles 
Example 14
Project: treetime   Author: neherlab   File: distribution.py    License: MIT License 6 votes vote down vote up
def __call__(self, x):
        """Evaluate the distribution at ``x`` (a real scalar or an iterable array).

        Scalars outside ``[self._xmin, self._xmax]`` get
        ``BIG_NUMBER + self.peak_val``. Inside, the result is
        ``self._peak_val`` when ``self._delta == True``, otherwise
        ``self._peak_val + self._func(x)``.

        Arrays use a range widened by ``TINY_NUMBER``. Retained points are
        clamped ``TINY_NUMBER`` inside the range before ``self._func`` is
        applied, and all other points get ``BIG_NUMBER + self.peak_val``.

        Raises:
            TypeError: if ``x`` is neither iterable nor a real scalar.
        """
        if not isinstance(x, Iterable):
            if not np.isreal(x):
                raise TypeError("Wrong type: should be float or array")
            if x < self._xmin or x > self._xmax:
                return BIG_NUMBER+self.peak_val
            # x is within interpolation range
            if self._delta == True:
                return self._peak_val
            return self._peak_val + self._func(x)

        # Array input: mask of points within the slightly widened range.
        inside = (x > self._xmin-TINY_NUMBER) & (x < self._xmax+TINY_NUMBER)
        out = np.ones_like(x, dtype=float) * (BIG_NUMBER+self.peak_val)
        lo = self._xmin + TINY_NUMBER
        hi = self._xmax - TINY_NUMBER
        # Clamp the retained points slightly inside the range before
        # evaluating self._func on them.
        clamped = np.copy(x[inside])
        clamped[clamped < lo] = lo
        clamped[clamped > hi] = hi
        out[inside] = self._peak_val + self._func(clamped)
        return out
Example 15
Project: pyrpl   Author: lneuhaus   File: attribute_widgets.py    License: GNU General Public License v3.0 6 votes vote down vote up
def _set_widget_value(self, new_value, transform_magnitude=lambda data :
    20. * np.log10(np.abs(data) + sys.float_info.epsilon)):
        if new_value is None:
            return
        x, y = new_value
        shape = np.shape(y)
        if len(shape) > 2:
            raise ValueError("Data cannot be larger than 2 "
                             "dimensional")
        if len(shape) == 1:
            y = [y]
        self._set_real(np.isreal(y).all())
        for i, values in enumerate(y):
            self._display_curve_index(x, values, i, transform_magnitude=transform_magnitude)
        while (i + 1 < len(self.curves)):  # delete remaining curves
            i += 1
            self.curves[i].hide() 
Example 16
Project: pyABC   Author: ICB-DCM   File: local_transition.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def fit(self, X, w):
        """Compute per-particle covariances and Gaussian normalization constants.

        For every row ``n`` of ``X``, ``self._cov_and_inv(n, indices)``
        supplies a covariance, its inverse and its determinant. ``indices``
        holds the ``self.k + 1`` nearest neighbours of every row, capped at
        the number of rows. The constant ``sqrt((2*pi)**dim * det)`` is then
        cached in ``self.normalization``.

        Parameters
        ----------
        X : pandas.DataFrame
            Particles, one per row.
        w : array-like
            Particle weights (not used by this method).

        Raises
        ------
        NotEnoughParticles
            If ``X`` is empty.
        Exception
            If a normalization constant is not real, e.g. because a
            determinant is negative.
        """
        if len(X) == 0:
            raise NotEnoughParticles("Fitting not possible.")
        self.X_arr = X.values

        ctree = cKDTree(X)
        _, indices = ctree.query(X, k=min(self.k + 1, X.shape[0]))

        covs, inv_covs, dets = list(zip(*[self._cov_and_inv(n, indices)
                                    for n in range(X.shape[0])]))
        self.covs = np.array(covs)
        self.inv_covs = np.array(inv_covs)
        self.determinants = np.array(dets)

        # np.emath.sqrt returns complex values for negative arguments, whereas
        # np.sqrt would silently produce NaN (which np.isreal accepts), so the
        # realness check below can actually detect a negative determinant.
        self.normalization = np.emath.sqrt(
            (2 * np.pi) ** self.X_arr.shape[1] * self.determinants)

        if not np.isreal(self.normalization).all():
            raise Exception("Normalization not real")
        self.normalization = np.real(self.normalization)
Example 17
Project: twitter-stock-recommendation   Author: alvarobartt   File: _base.py    License: MIT License 6 votes vote down vote up
def _validate_converted_limits(self, limit, convert):
        """
        Raise ValueError if converted limits are non-finite.

        Note that this function also accepts None as a limit argument.

        Returns
        -------
        The limit value after call to convert(), or None if limit is None.

        """
        if limit is not None:
            converted_limit = convert(limit)
            if (isinstance(converted_limit, float) and
                    (not np.isreal(converted_limit) or
                        not np.isfinite(converted_limit))):
                raise ValueError("Axis limits cannot be NaN or Inf")
            return converted_limit 
Example 18
Project: scaper   Author: justinsalamon   File: util.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def is_real_number(num):
    '''
    Check if a value is a real scalar by aggregating several numpy checks.

    Parameters
    ----------
    num : any type
        The parameter to check

    Returns
    -------
    check : bool
        True if ``num`` is a real scalar, False otherwise (including for
        arrays and other non-scalar input).

    '''
    # Test scalar-ness first: for array-like input np.isreal returns an
    # array, whose truth value is ambiguous and would raise ValueError.
    if not np.isscalar(num):
        return False
    return bool(np.isreal(num) and np.isrealobj(num))
Example 19
Project: pyberny   Author: jhrmnn   File: Math.py    License: Mozilla Public License 2.0 5 votes vote down vote up
def fit_quartic(y0, y1, g0, g1):
    """Fit constrained quartic polynomial to function values and derivatives at x = 0,1.

    Returns position and function value of minimum, or (None, None) if the fit
    fails or has a maximum. Quartic polynomial is constrained such that its 2nd
    derivative is zero at just one point. This ensures that it has just one
    local extremum. Either no such or two such quartic polynomials exist. From
    the two, the one with lower minimum is chosen.
    """

    def g(y0, y1, g0, g1, c):
        """Coefficients (highest power first) of the quartic
        a*x**4 + b*x**3 + c*x**2 + g0*x + y0 matching value/derivative y0, g0
        at x = 0 and y1, g1 at x = 1, for a given x**2 coefficient c."""
        a = c + 3 * (y0 - y1) + 2 * g0 + g1
        b = -2 * c - 4 * (y0 - y1) - 3 * g0 - g1
        return np.array([a, b, c, g0, y0])

    def quart_min(p):
        """Return (x, p(x)) at a stationary point of quartic p.

        If exactly one root of p' is real it is used. Otherwise the root
        equal to +/- the smallest root magnitude is used; only a real root can
        compare equal there. NOTE(review): if the smallest-magnitude root is
        complex the mask is empty and [0] raises IndexError -- confirm this
        cannot occur for the constrained quartics built by g().
        """
        r = np.roots(np.polyder(p))
        is_real = np.isreal(r)
        if is_real.sum() == 1:
            minim = r[is_real][0].real
        else:
            minim = r[(r == max(-abs(r))) | (r == -max(-abs(r)))][0].real
        return minim, np.polyval(p, minim)

    # discriminant of d^2y/dx^2=0
    # (D decides whether real values of c exist that satisfy the constraint;
    # below a small tolerance the fit is rejected.)
    D = -((g0 + g1) ** 2) - 2 * g0 * g1 + 6 * (y1 - y0) * (g0 + g1) - 6 * (y1 - y0) ** 2
    if D < 1e-11:
        return None, None
    else:
        # The two admissible values of c are 0.5 * (m +/- sqrt(2 * D)).
        m = -5 * g0 - g1 - 6 * y0 + 6 * y1
        p1 = g(y0, y1, g0, g1, 0.5 * (m + np.sqrt(2 * D)))
        p2 = g(y0, y1, g0, g1, 0.5 * (m - np.sqrt(2 * D)))
        # Negative leading coefficient: the single extremum is a maximum.
        # NOTE(review): if only one of p1/p2 opens downwards, its extremum (a
        # maximum) still takes part in the comparison below -- confirm intent.
        if p1[0] < 0 and p2[0] < 0:
            return None, None
        [minim1, minval1] = quart_min(p1)
        [minim2, minval2] = quart_min(p2)
        if minval1 < minval2:
            return minim1, minval1
        else:
            return minim2, minval2
Example 20
Project: python-control   Author: python-control   File: margins.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def phase_crossover_frequencies(sys):
    """Compute frequencies and gains at intersections with real axis
    in Nyquist plot.

    Call as:
        omega, gain = phase_crossover_frequencies(sys)

    Only the (0, 0) input/output channel of ``sys`` is considered.

    Returns
    -------
    omega: 1d array of (non-negative) frequencies where Nyquist plot
    intersects the real axis

    gain: 1d array of corresponding gains

    Examples
    --------
    >>> tf = TransferFunction([1], [1, 2, 3, 4])
    >>> phase_crossover_frequencies(tf)
    (array([ 1.73205081,  0.        ]), array([-0.5 ,  0.25]))
    """
    tf = xferfcn._convert_to_transfer_function(sys)

    # if not siso, fall back to (0,0) element
    #! TODO: should add a check and warning here
    num, den = tf.num[0][0], tf.den[0][0]

    # numj / denj are the coefficients (in w) of N(jw) and D(-jw). For real
    # num/den, G(jw) is real exactly where Im(N(jw) * D(-jw)) == 0.
    num_powers = np.arange(len(num) - 1, -1, -1)
    den_powers = np.arange(len(den) - 1, -1, -1)
    numj = (1.j) ** num_powers * num
    denj = (-1.j) ** den_powers * den
    candidates = np.roots(np.imag(np.polymul(numj, denj)))
    real_candidates = np.real(candidates[np.isreal(candidates)])
    omega = real_candidates[real_candidates >= 0.]

    # using real() to avoid rounding errors and results like 1+0j
    # it would be nice to have a vectorized version of self.evalfr here
    gain = np.real(np.asarray([tf._evalfr(w)[0][0] for w in omega]))

    return omega, gain
Example 21
Project: lambda-packs   Author: ryfeus   File: filter_design.py    License: MIT License 5 votes vote down vote up
def _nearest_real_complex_idx(fro, to, which):
    """Get the next closest real or complex element based on distance"""
    assert which in ('real', 'complex')
    order = np.argsort(np.abs(fro - to))
    mask = np.isreal(fro[order])
    if which == 'complex':
        mask = ~mask
    return order[np.where(mask)[0][0]] 
Example 22
Project: k-means-plus-plus-pandas   Author: jackmaney   File: cluster.py    License: MIT License 5 votes vote down vote up
def _is_numeric(self, col):
        return all(np.isreal(self.data_frame[col])) and not any(np.isnan(self.data_frame[col])) 
Example 23
Project: k-means-plus-plus-pandas   Author: jackmaney   File: cluster.py    License: MIT License 5 votes vote down vote up
def _is_numeric(self, col):
        return all(np.isreal(self.data_frame[col])) and not any(np.isnan(self.data_frame[col])) 
Example 24
Project: causalimpact   Author: dafiti   File: main.py    License: Apache License 2.0 5 votes vote down vote up
def _format_input_data(self, data):
        """
        Validates and formats input data.

        Args
        ----
          data: `numpy.array` or `pandas.DataFrame`.

        Returns
        -------
          data: pandas DataFrame.
              Validated data to be used in Causal Impact algorithm.

        Raises
        ------
          ValueError: if input `data` is non-convertible to pandas DataFrame.
                      if input `data` has non-numeric values.
                      if input `data` has less than 3 points.
                      if input covariates have NAN values.
        """
        if not isinstance(data, pd.DataFrame):
            try:
                data = pd.DataFrame(data)
            except ValueError:
                raise ValueError(
                    'Could not transform input data to pandas DataFrame.'
                )
        self._validate_y(data.iloc[:, 0])
        # Must contain only numeric values
        if not data.applymap(np.isreal).values.all():
            raise ValueError('Input data must contain only numeric values.')
        # Covariates cannot have NAN values
        if data.shape[1] > 1:
            if data.iloc[:, 1:].isna().values.any():
                raise ValueError('Input data cannot have NAN values.')
        # If index is a string of dates, try to convert it to datetimes which helps
        # in plotting.
        data = self._convert_index_to_datetime(data)
        return data 
Example 25
Project: Computable   Author: ktraunmueller   File: math.py    License: MIT License 5 votes vote down vote up
def is_psd(m):
    """Return whether every eigenvalue of ``m`` is real and non-negative.

    NOTE(review): the comparisons are exact, so eigenvalues carrying a tiny
    imaginary part or a tiny negative round-off make this return False.
    """
    spectrum = linalg.eigvals(m)
    all_real = np.isreal(spectrum).all()
    if not all_real:
        return all_real
    return (spectrum >= 0).all()
Example 26
Project: trax   Author: google   File: math_ops.py    License: Apache License 2.0 5 votes vote down vote up
def isreal(x):
  """Return ``array_ops.imag(x) == 0``: True where ``x`` has no imaginary part."""
  imaginary_part = array_ops.imag(x)
  return imaginary_part == 0
Example 27
Project: sonata   Author: AllenInstitute   File: test_compartment_writer.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_one_compartment_report():
    """Write a single-compartment report and check the resulting HDF5 layout."""
    population = 'p1'
    # mkstemp also returns an open OS-level file descriptor; close it so it is
    # not leaked. The suffix needs a leading dot to form a real extension.
    fd, output_file = tempfile.mkstemp(suffix='.h5')
    os.close(fd)

    try:
        cr = CompartmentReport(output_file, mode='w', default_population=population,
                               tstart=0.0, tstop=100.0, dt=0.1)
        cr.add_cell(node_id=0, element_ids=[0], element_pos=[0.0])
        for i in range(1000):
            cr.record_cell(0, [i/100.0], tstep=i)

        cr.close()

        # Context manager releases the HDF5 handle before the file is removed.
        with h5py.File(output_file, 'r') as report_h5:
            report_grp = report_h5['/report/{}'.format(population)]
            assert 'data' in report_grp
            data_ds = report_grp['data'][()]
            assert report_grp['data'].size == 1000
            # np.isreal(<dtype>) is always truthy; check for a real numeric dtype.
            assert np.issubdtype(data_ds.dtype, np.number)
            assert not np.issubdtype(data_ds.dtype, np.complexfloating)
            assert data_ds[0] == 0.00
            assert data_ds[-1] == 9.99

            assert 'mapping' in report_grp
            mapping_grp = report_grp['mapping']
            assert all(mapping_grp['element_ids'][()] == [0])
            assert mapping_grp['element_pos'][()] == [0.0]
            assert mapping_grp['index_pointer'][()].size == 2
            assert mapping_grp['node_ids'][()] == [0]
            assert np.allclose(mapping_grp['time'][()], [0.0, 100.0, 0.1])
    finally:
        # Remove the temporary report even if an assertion above fails.
        os.remove(output_file)
Example 28
Project: sonata   Author: AllenInstitute   File: test_compartment_writer.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_multi_compartment_report():
    """Write a 50-compartment report for one cell and check the HDF5 layout."""
    population = 'cortical'
    # mkstemp also returns an open OS-level file descriptor; close it so it is
    # not leaked. The suffix needs a leading dot to form a real extension.
    fd, output_file = tempfile.mkstemp(suffix='.h5')
    os.close(fd)
    n_elements = 50

    try:
        cr = CompartmentReport(output_file, mode='w', default_population=population,
                               tstart=0.0, tstop=100.0, dt=0.1)
        cr.add_cell(node_id=0, element_ids=np.arange(n_elements), element_pos=[0.5]*n_elements)
        cr.initialize()
        for i in range(1000):
            cr.record_cell(0, [i+j for j in range(n_elements)], tstep=i)

        cr.close()

        # Context manager releases the HDF5 handle before the file is removed.
        with h5py.File(output_file, 'r') as report_h5:
            report_grp = report_h5['/report/{}'.format(population)]
            assert 'data' in report_grp
            data_ds = report_grp['data'][()]
            assert report_grp['data'].shape == (1000, n_elements)
            # np.isreal(<dtype>) is always truthy; check for a real numeric dtype.
            assert np.issubdtype(data_ds.dtype, np.number)
            assert not np.issubdtype(data_ds.dtype, np.complexfloating)
            assert data_ds[0, 0] == 0.0
            assert data_ds[999, n_elements-1] == 999.0+n_elements-1

            assert 'mapping' in report_grp
            mapping_grp = report_grp['mapping']
            assert np.allclose(mapping_grp['element_ids'][()], np.arange(n_elements))
            assert np.allclose(mapping_grp['element_pos'][()], [0.5]*n_elements)
            assert mapping_grp['index_pointer'][()].size == 2
            assert mapping_grp['node_ids'][()] == [0]
            assert np.allclose(mapping_grp['time'][()], [0.0, 100.0, 0.1])
    finally:
        # Remove the temporary report even if an assertion above fails.
        os.remove(output_file)
Example 29
Project: sonata   Author: AllenInstitute   File: test_compartment_writer.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_multi_cell_report(buffer_size=0):
    """Write a report for five cells split across ranks and check it on rank 0.

    ``rank``, ``nhosts``, ``cpath`` and ``barrier`` are module-level names
    defined elsewhere (presumably MPI helpers and the test directory).
    """
    cells = [(0, 10), (1, 50), (2, 100), (3, 1), (4, 200)]
    total_elements = sum(n_elements for _, n_elements in cells)
    # Round-robin split: this rank takes every nhosts-th cell from `rank` on.
    rank_cells = cells[rank::nhosts]
    output_file = os.path.join(cpath, 'output/multi_compartment_report.h5')
    population = 'cortical'

    cr = CompartmentReport(output_file, mode='w', default_population=population,
                           tstart=0.0, tstop=100.0, dt=0.1, variable='mebrane_potential', units='mV',
                           buffer_size=buffer_size)
    for node_id, n_elements in rank_cells:
        cr.add_cell(node_id=node_id, element_ids=np.arange(n_elements), element_pos=np.zeros(n_elements))

    for i in range(1000):
        for node_id, n_elements in rank_cells:
            cr.record_cell(node_id, [node_id+i/1000.0]*n_elements, tstep=i)
    cr.close()

    try:
        if rank == 0:
            # Context manager releases the HDF5 handle before removal.
            with h5py.File(output_file, 'r') as report_h5:
                report_grp = report_h5['/report/{}'.format(population)]
                assert 'data' in report_grp
                data_ds = report_grp['data'][()]
                assert report_grp['data'].shape == (1000, total_elements)
                # np.isreal(<dtype>) is always truthy; check for a real numeric dtype.
                assert np.issubdtype(data_ds.dtype, np.number)
                assert not np.issubdtype(data_ds.dtype, np.complexfloating)

                assert 'mapping' in report_grp
                mapping_grp = report_grp['mapping']
                assert mapping_grp['element_ids'].size == total_elements
                assert mapping_grp['element_pos'].size == total_elements
                assert mapping_grp['index_pointer'].size == 6
                assert np.all(np.sort(mapping_grp['node_ids'][()]) == np.arange(5))
                assert np.allclose(mapping_grp['time'][()], [0.0, 100.0, 0.1])
    finally:
        if rank == 0 and os.path.exists(output_file):
            os.remove(output_file)
        # Reached even when an assertion fails on rank 0, so rank 0 still
        # joins the barrier (presumably collective) the other ranks wait in.
        barrier()
Example 30
Project: sonata   Author: AllenInstitute   File: test_compartment_writer.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_block_record():
    """Record whole per-cell blocks across ranks and check the report on rank 0.

    ``rank``, ``nhosts``, ``cpath`` and ``barrier`` are module-level names
    defined elsewhere (presumably MPI helpers and the test directory).
    """
    cells = [(0, 10), (1, 50), (2, 100), (3, 1), (4, 200)]
    total_elements = sum(n_elements for _, n_elements in cells)
    # Round-robin split: this rank takes every nhosts-th cell from `rank` on.
    rank_cells = cells[rank::nhosts]
    output_file = os.path.join(cpath, 'output/multi_compartment_report.h5')
    population = 'cortical'

    cr = CompartmentReport(output_file, mode='w', default_population=population,
                           tstart=0.0, tstop=100.0, dt=0.1, variable='mebrane_potential', units='mV')
    for node_id, n_elements in rank_cells:
        cr.add_cell(node_id=node_id, element_ids=np.arange(n_elements), element_pos=np.zeros(n_elements))

    for node_id, n_elements in rank_cells:
        cr.record_cell_block(node_id, np.full((1000, n_elements), fill_value=node_id+1), beg_step=0, end_step=1000)

    cr.close()

    try:
        if rank == 0:
            # Context manager releases the HDF5 handle before removal.
            with h5py.File(output_file, 'r') as report_h5:
                report_grp = report_h5['/report/{}'.format(population)]
                assert 'data' in report_grp
                data_ds = report_grp['data'][()]
                assert report_grp['data'].shape == (1000, total_elements)
                # np.isreal(<dtype>) is always truthy; check for a real numeric dtype.
                assert np.issubdtype(data_ds.dtype, np.number)
                assert not np.issubdtype(data_ds.dtype, np.complexfloating)

                assert 'mapping' in report_grp
                mapping_grp = report_grp['mapping']
                assert mapping_grp['element_ids'].size == total_elements
                assert mapping_grp['element_pos'].size == total_elements
                assert mapping_grp['index_pointer'].size == 6
                assert np.all(np.sort(mapping_grp['node_ids'][()]) == np.arange(5))
                assert np.allclose(mapping_grp['time'][()], [0.0, 100.0, 0.1])
    finally:
        if rank == 0 and os.path.exists(output_file):
            os.remove(output_file)
        # Reached even when an assertion fails on rank 0, so rank 0 still
        # joins the barrier (presumably collective) the other ranks wait in.
        barrier()