Python numpy.imag() Examples

The following are 30 code examples of numpy.imag(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module numpy, or try the search function.
Example #1
Source File: matrixtools.py    From pyGSTi with Apache License 2.0 6 votes vote down vote up
def is_hermitian(mx, TOL=1e-9):
    """
    Test whether mx is a hermitian matrix.

    Parameters
    ----------
    mx : numpy array
        Two-dimensional matrix to test.

    TOL : float, optional
        Tolerance on absolute magnitude of elements.

    Returns
    -------
    bool
        True if mx is hermitian, otherwise False.  A non-square matrix is
        never hermitian, so False is returned for it.
    """
    (m, n) = mx.shape
    if m != n:
        # Without this guard the mirrored lookups below index out of bounds
        # and raise IndexError for rectangular input.
        return False
    for i in range(m):
        # Diagonal entries of a hermitian matrix must be real.
        if abs(mx[i, i].imag) > TOL: return False
        # Only the upper triangle is visited; each entry is compared with
        # the conjugate of its mirror image below the diagonal.
        for j in range(i + 1, n):
            if abs(mx[i, j] - mx[j, i].conjugate()) > TOL: return False
    return True
Example #2
Source File: __init__.py    From pyGSTi with Apache License 2.0 6 votes vote down vote up
def safe_bulk_eval_compact_polys(vtape, ctape, paramvec, dest_shape):
    """Typechecking wrapper for :function:`bulk_eval_compact_polys`.

    Two implementations exist: `bulk_eval_compact_polys` and
    `bulk_eval_compact_polys_complex`.  This wrapper inspects `ctape` with
    `numpy.iscomplexobj` and forwards all four arguments, unchanged, to the
    matching one.  Callers that already know the type of `ctape` can call
    the appropriate implementation directly and skip this check.

    The real part of the evaluated result is returned in both cases.  For
    a complex `ctape`, a warning is printed when the norm of the discarded
    imaginary part exceeds 1e-6.
    """
    if not _np.iscomplexobj(ctape):
        real_result = bulk_eval_compact_polys(vtape, ctape, paramvec, dest_shape)
        return _np.real(real_result)

    complex_result = bulk_eval_compact_polys_complex(vtape, ctape, paramvec, dest_shape)
    discarded = _np.linalg.norm(_np.imag(complex_result))
    if discarded > 1e-6:
        print("WARNING: norm(Im part) = {:g}".format(discarded))
    return _np.real(complex_result)
Example #3
Source File: photonics.py    From nevergrad with MIT License 6 votes vote down vote up
def creneau(k0: float, a0: float, pol: float, e1: float, e2: float, a: float, n: int, x0: float) -> tp.Tuple[np.ndarray, np.ndarray]:
    """Eigen-decompose the operator built from `marche` and stack the modes.

    A diagonal matrix ``alpha`` with entries ``a0 + 2*pi*k`` for
    ``k = -int(n/2) .. int(n/2)`` is combined with the matrices returned
    by `marche` (defined elsewhere) into an operator ``M``; which
    combination is used depends on whether ``pol == 0``.  The eigenvalues
    of ``M`` are mapped to ``sqrt(-eigenvalue)``, and any root whose
    imaginary part lies below -1e-15 is negated.

    Returns
    -------
    (P, L)
        ``P`` stacks the eigenvector matrix on top of the eigenvector
        matrix (left-multiplied by ``marche(1/e1, 1/e2, ...)`` when
        ``pol != 0``) times ``diag(L)``; ``L`` holds the selected roots.
    """
    half_width = int(n / 2)
    alpha = np.diag(a0 + 2 * np.pi * np.arange(-half_width, half_width + 1))
    k0_squared = k0 * k0

    def _select_roots(eigenvalues):
        # sqrt is two-valued: flip the sign of roots with a negative
        # imaginary part (below the -1e-15 threshold).
        roots = np.sqrt(-eigenvalues + 0j)
        flip = np.imag(roots) < -1e-15
        return (1 - 2 * flip) * roots

    if pol == 0:
        M = alpha * alpha - k0_squared * marche(e1, e2, a, n, x0)
        eigenvalues, E = np.linalg.eig(M)
        L = _select_roots(eigenvalues)
        lower = np.matmul(E, np.diag(L))
        return np.block([[E], [lower]]), L

    U = marche(1 / e1, 1 / e2, a, n, x0)
    T = np.linalg.inv(U)
    t_alpha = np.matmul(T, alpha)
    inv_step = np.linalg.inv(marche(e1, e2, a, n, x0))
    M = np.matmul(np.matmul(t_alpha, inv_step), alpha) - k0_squared * T
    eigenvalues, E = np.linalg.eig(M)
    L = _select_roots(eigenvalues)
    lower = np.matmul(np.matmul(U, E), np.diag(L))
    return np.block([[E], [lower]]), L
Example #4
Source File: reportableqty.py    From pyGSTi with Apache License 2.0 6 votes vote down vote up
def infidelity_diff(self, constant_value):
        """
        Returns a ReportableQty that is the (element-wise in the vector case)
        difference between `constant_value` and this one given by:

        `1.0 - Re(conjugate(constant_value) * self )`

        When this quantity has an error bar (`self.has_eb()`), the returned
        quantity carries the linearly propagated error bar and this
        object's `nonMarkovianEBs` flag; otherwise only the value is
        returned.
        """
        # let diff(x) = 1.0 - Re(const.C * x) = 1.0 - (const.re * x.re + const.im * x.im)
        # so d(diff)/dx.re = -const.re, d(diff)/dx.im = -const.im
        # diff(x + dx) = diff(x) + d(diff)/dx * dx
        # diff(x + dx) - diff(x) =  - (const.re * dx.re + const.im * dx.im)
        v = 1.0 - _np.real(_np.conjugate(constant_value) * self.value)
        if self.has_eb():
            # The imaginary part of the constant must be paired with the
            # *imaginary* part of the error bar (dx.im above); pairing it
            # with the real part counted the real error bar twice.
            eb = abs(_np.real(constant_value) * _np.real(self.errbar)
                     + _np.imag(constant_value) * _np.imag(self.errbar))
            return ReportableQty(v, eb, self.nonMarkovianEBs)
        else:
            return ReportableQty(v)
Example #5
Source File: circuit.py    From pyGSTi with Apache License 2.0 6 votes vote down vote up
def _num_to_rqc_str(num):
    """Convert float to string to be included in RQC quil DEFGATE block
    (as written by _np_to_quil_def_str)."""
    num = _np.complex(_np.real_if_close(num))
    if _np.imag(num) == 0:
        output = str(_np.real(num))
        return output
    else:
        real_part = _np.real(num)
        imag_part = _np.imag(num)
        if imag_part < 0:
            sgn = '-'
            imag_part = imag_part * -1
        elif imag_part > 0:
            sgn = '+'
        else:
            assert False
        return '{}{}{}i'.format(real_part, sgn, imag_part) 
Example #6
Source File: _ode.py    From lambda-packs with MIT License 6 votes vote down vote up
def _wrap_jac(self, t, y, *jac_args):
        # jac is the complex Jacobian computed by the user-defined function.
        jac = self.cjac(*((t, y[::2] + 1j * y[1::2]) + jac_args))

        # jac_tmp is the real version of the complex Jacobian.  Each complex
        # entry in jac, say 2+3j, becomes a 2x2 block of the form
        #     [2 -3]
        #     [3  2]
        jac_tmp = zeros((2 * jac.shape[0], 2 * jac.shape[1]))
        jac_tmp[1::2, 1::2] = jac_tmp[::2, ::2] = real(jac)
        jac_tmp[1::2, ::2] = imag(jac)
        jac_tmp[::2, 1::2] = -jac_tmp[1::2, ::2]

        ml = getattr(self._integrator, 'ml', None)
        mu = getattr(self._integrator, 'mu', None)
        if ml is not None or mu is not None:
            # Jacobian is banded.  The user's Jacobian function has computed
            # the complex Jacobian in packed format.  The corresponding
            # real-valued version has every other column shifted up.
            jac_tmp = _transform_banded_jac(jac_tmp)

        return jac_tmp 
Example #7
Source File: slowreplib.py    From pyGSTi with Apache License 2.0 6 votes vote down vote up
def __str__(self):
        def fmt(x):
            if abs(_np.imag(x)) > 1e-6:
                if abs(_np.real(x)) > 1e-6: return "(%.3f+%.3fj)" % (x.real, x.imag)
                else: return "(%.3fj)" % x.imag
            else: return "%.3f" % x.real

        termstrs = []
        sorted_keys = sorted(list(self.keys()))
        for k in sorted_keys:
            vinds = self._int_to_vinds(k)
            varstr = ""; last_i = None; n = 0
            for i in sorted(vinds):
                if i == last_i: n += 1
                elif last_i is not None:
                    varstr += "x%d%s" % (last_i, ("^%d" % n) if n > 1 else "")
                last_i = i
            if last_i is not None:
                varstr += "x%d%s" % (last_i, ("^%d" % n) if n > 1 else "")
            #print("DB: vinds = ",vinds, " varstr = ",varstr)
            if abs(self[k]) > 1e-4:
                termstrs.append("%s%s" % (fmt(self[k]), varstr))
        if len(termstrs) > 0:
            return " + ".join(termstrs)
        else: return "0" 
Example #8
Source File: oplessmodel.py    From pyGSTi with Apache License 2.0 6 votes vote down vote up
def bulk_fill_probs(self, mxToFill, evalTree, clipTo=None, check=False, comm=None):
        """
        Fill `mxToFill` with the outcome probabilities of every circuit in
        `evalTree`.

        For circuit number `i`, `self.probs` is called with the circuit,
        `clipTo` and the matching entry of `evalTree.cache` (or None when
        the cache is empty).  Each outcome listed in `evalTree.outcomes[i]`
        is then written to the position given by
        `evalTree.element_indices[i]`.

        `check` and `comm` are accepted for interface compatibility and are
        not used here.  Nothing is returned; `mxToFill` is modified in place.
        """
        # The former bulk-evaluation branch was guarded by `if False and ...`
        # and could never run; it has been removed.
        for i, c in enumerate(evalTree):
            cache = evalTree.cache[i] if evalTree.cache else None
            probs = self.probs(c, clipTo, cache)

            # Element indices may be stored as a slice; expand it into
            # explicit indices so it can be zipped with the outcomes.
            elInds = evalTree.element_indices[i]
            if isinstance(elInds, slice):
                elInds = _slct.indices(elInds)

            for k, outcome in zip(elInds, evalTree.outcomes[i]):
                mxToFill[k] = probs[outcome]
Example #9
Source File: ltisys.py    From lambda-packs with MIT License 6 votes vote down vote up
def _order_complex_poles(poles):
    """
    Check we have complex conjugates pairs and reorder P according to YT, ie
    real_poles, complex_i, conjugate complex_i, ....
    The lexicographic sort on the complex poles is added to help the user to
    compare sets of poles.
    """
    ordered_poles = np.sort(poles[np.isreal(poles)])
    im_poles = []
    for p in np.sort(poles[np.imag(poles) < 0]):
        if np.conj(p) in poles:
            im_poles.extend((p, np.conj(p)))

    ordered_poles = np.hstack((ordered_poles, im_poles))

    if poles.shape[0] != len(ordered_poles):
        raise ValueError("Complex poles must come with their conjugates")
    return ordered_poles 
Example #10
Source File: matrixtools.py    From pyGSTi with Apache License 2.0 6 votes vote down vote up
def complex_compare(a, b):
    """
    Three-way comparison of two complex numbers: ordered by real part
    first, with the imaginary part used only to break ties.

    Parameters
    ----------
    a,b : complex

    Returns
    -------
    -1 if a < b
     0 if a == b
    +1 if a > b
    """
    for component in ("real", "imag"):
        lhs = getattr(a, component)
        rhs = getattr(b, component)
        if lhs < rhs:
            return -1
        if lhs > rhs:
            return 1
    return 0
Example #11
Source File: matrixtools.py    From pyGSTi with Apache License 2.0 6 votes vote down vote up
def safereal(A, inplace=False, check=False):
    """
    Real part of `A`, for `A` given either as a dense array or as a sparse
    CSR matrix.

    With `check` set, it is first asserted that the 'imag' norm of `A` (as
    computed by `safenorm`) is below 1e-6.  A sparse CSR input yields a new
    CSR matrix of dtype 'd' with explicit zeros eliminated; with `inplace`
    the index arrays of `A` are passed on without copying.  Any other
    sparse format raises NotImplementedError.
    """
    if check:
        assert(safenorm(A, 'imag') < 1e-6), "Check failed: taking real-part of matrix w/nonzero imaginary part"

    if not _sps.issparse(A):
        return _np.real(A)

    if not _sps.isspmatrix_csr(A):
        raise NotImplementedError("safereal() doesn't work with %s matrices yet" % str(type(A)))

    if inplace:
        csr_parts = (_np.real(A.data), A.indices, A.indptr)
    else:  # work on private copies of all three arrays
        csr_parts = (_np.real(A.data).copy(), A.indices.copy(), A.indptr.copy())
    ret = _sps.csr_matrix(csr_parts, shape=A.shape, dtype='d')
    ret.eliminate_zeros()
    return ret
Example #12
Source File: matrixtools.py    From pyGSTi with Apache License 2.0 6 votes vote down vote up
def safeimag(A, inplace=False, check=False):
    """
    Returns the imaginary-part of `A` correctly when `A` is either a dense array
    or a sparse matrix

    Parameters
    ----------
    A : numpy array or scipy sparse matrix
        Only the CSR sparse format is supported.

    inplace : bool, optional
        For CSR input, reuse the index arrays of `A` instead of copying
        them.

    check : bool, optional
        Assert first that the 'real' norm of `A` (via `safenorm`) is below
        1e-6.

    Returns
    -------
    numpy array or scipy.sparse.csr_matrix
        For CSR input, a dtype 'd' matrix with explicit zeros removed.

    Raises
    ------
    NotImplementedError
        If `A` is sparse but not in CSR format.
    """
    if check:
        assert(safenorm(A, 'real') < 1e-6), "Check failed: taking imag-part of matrix w/nonzero real part"
    if _sps.issparse(A):
        if _sps.isspmatrix_csr(A):
            if inplace:
                ret = _sps.csr_matrix((_np.imag(A.data), A.indices, A.indptr), shape=A.shape, dtype='d')
            else:  # copy
                ret = _sps.csr_matrix((_np.imag(A.data).copy(), A.indices.copy(),
                                       A.indptr.copy()), shape=A.shape, dtype='d')
            ret.eliminate_zeros()
            return ret
        else:
            # The message used to name safereal(), pointing at the wrong
            # function.
            raise NotImplementedError("safeimag() doesn't work with %s matrices yet" % str(type(A)))
    else:
        return _np.imag(A)
Example #13
Source File: Kaiser 1962 - CaF2.py    From refractiveindex.info-scripts with GNU General Public License v3.0 6 votes vote down vote up
def SaveYML(w_um, RefInd, filename, references='', comments=''):
    """Write a three-column table (w_um, Re(RefInd), Im(RefInd)) to
    `filename`, preceded by a fixed header that embeds `references` and
    `comments` and announces the data as 'tabulated nk'.

    Every line written after the first header line break is followed by
    eight spaces of indentation (the `newline` passed to `np.savetxt`).
    """
    header_lines = [
        '# this file is part of refractiveindex.info database',
        '# refractiveindex.info database is in the public domain',
        '# copyright and related rights waived via CC0 1.0',
        '',
        'REFERENCES:' + references,
        'COMMENTS:' + comments,
        'DATA:',
        '  - type: tabulated nk',
        '    data: |',
    ]
    table = np.column_stack((w_um, np.real(RefInd), np.imag(RefInd)))
    np.savetxt(
        filename,
        table,
        fmt='%4.2f %#.4g %#.4g',
        delimiter=' ',
        header='\n'.join(header_lines),
        comments='',
        newline='\n        ',
    )

###############################################################################

## Wavelengths to sample ## 
Example #14
Source File: Tsuda 2018 - PMMA (BB model).py    From refractiveindex.info-scripts with GNU General Public License v3.0 6 votes vote down vote up
def SaveYML(w_um, RefInd, filename, references='', comments=''):
    """Write a three-column table (w_um, Re(RefInd), Im(RefInd)) to
    `filename`, preceded by a fixed header that embeds `references` and
    `comments` and announces the data as 'tabulated nk'.

    The imaginary column is written in exponent notation.  Every line
    written after the first header line break is followed by eight spaces
    of indentation (the `newline` passed to `np.savetxt`).
    """
    header_lines = [
        '# this file is part of refractiveindex.info database',
        '# refractiveindex.info database is in the public domain',
        '# copyright and related rights waived via CC0 1.0',
        '',
        'REFERENCES:' + references,
        'COMMENTS:' + comments,
        'DATA:',
        '  - type: tabulated nk',
        '    data: |',
    ]
    table = np.column_stack((w_um, np.real(RefInd), np.imag(RefInd)))
    np.savetxt(
        filename,
        table,
        fmt='%4.2f %#.4g %#.3e',
        delimiter=' ',
        header='\n'.join(header_lines),
        comments='',
        newline='\n        ',
    )

###############################################################################

## Wavelengths to sample ## 
Example #15
Source File: Zhang 1998 - Kapton.py    From refractiveindex.info-scripts with GNU General Public License v3.0 6 votes vote down vote up
def SaveYML(w_um, RefInd, filename, references='', comments=''):
    """Write a three-column table (w_um, Re(RefInd), Im(RefInd)) to
    `filename`, preceded by a fixed header that embeds `references` and
    `comments` and announces the data as 'tabulated nk'.

    The first column is written with three decimals and the imaginary
    column in exponent notation.  Every line written after the first
    header line break is followed by eight spaces of indentation (the
    `newline` passed to `np.savetxt`).
    """
    header_lines = [
        '# this file is part of refractiveindex.info database',
        '# refractiveindex.info database is in the public domain',
        '# copyright and related rights waived via CC0 1.0',
        '',
        'REFERENCES:' + references,
        'COMMENTS:' + comments,
        'DATA:',
        '  - type: tabulated nk',
        '    data: |',
    ]
    table = np.column_stack((w_um, np.real(RefInd), np.imag(RefInd)))
    np.savetxt(
        filename,
        table,
        fmt='%4.3f %#.4g %#.3e',
        delimiter=' ',
        header='\n'.join(header_lines),
        comments='',
        newline='\n        ',
    )

###############################################################################

## Wavelengths to sample ## 
Example #16
Source File: Tsuda 2018 - PMMA (LD model).py    From refractiveindex.info-scripts with GNU General Public License v3.0 6 votes vote down vote up
def SaveYML(w_um, RefInd, filename, references='', comments=''):
    """Write a three-column table (w_um, Re(RefInd), Im(RefInd)) to
    `filename`, preceded by a fixed header that embeds `references` and
    `comments` and announces the data as 'tabulated nk'.

    The imaginary column is written in exponent notation.  Every line
    written after the first header line break is followed by eight spaces
    of indentation (the `newline` passed to `np.savetxt`).
    """
    header_lines = [
        '# this file is part of refractiveindex.info database',
        '# refractiveindex.info database is in the public domain',
        '# copyright and related rights waived via CC0 1.0',
        '',
        'REFERENCES:' + references,
        'COMMENTS:' + comments,
        'DATA:',
        '  - type: tabulated nk',
        '    data: |',
    ]
    table = np.column_stack((w_um, np.real(RefInd), np.imag(RefInd)))
    np.savetxt(
        filename,
        table,
        fmt='%4.2f %#.4g %#.3e',
        delimiter=' ',
        header='\n'.join(header_lines),
        comments='',
        newline='\n        ',
    )

###############################################################################

## Wavelengths to sample ## 
Example #17
Source File: Kaiser 1962 - BaF2.py    From refractiveindex.info-scripts with GNU General Public License v3.0 6 votes vote down vote up
def SaveYML(w_um, RefInd, filename, references='', comments=''):
    """Write a three-column table (w_um, Re(RefInd), Im(RefInd)) to
    `filename`, preceded by a fixed header that embeds `references` and
    `comments` and announces the data as 'tabulated nk'.

    Every line written after the first header line break is followed by
    eight spaces of indentation (the `newline` passed to `np.savetxt`).
    """
    header_lines = [
        '# this file is part of refractiveindex.info database',
        '# refractiveindex.info database is in the public domain',
        '# copyright and related rights waived via CC0 1.0',
        '',
        'REFERENCES:' + references,
        'COMMENTS:' + comments,
        'DATA:',
        '  - type: tabulated nk',
        '    data: |',
    ]
    table = np.column_stack((w_um, np.real(RefInd), np.imag(RefInd)))
    np.savetxt(
        filename,
        table,
        fmt='%4.2f %#.4g %#.4g',
        delimiter=' ',
        header='\n'.join(header_lines),
        comments='',
        newline='\n        ',
    )

###############################################################################

## Wavelengths to sample ## 
Example #18
Source File: misc.py    From tenpy with GNU General Public License v3.0 6 votes vote down vote up
def zero_if_close(a, tol=1.e-15):
    """Replace entries of `a` whose magnitude is below `tol` by zero.

    For ``complex128``/``complex64`` arrays the real and imaginary parts
    are thresholded independently; for any other dtype the entries are
    thresholded as a whole.

    Parameters
    ----------
    a : ndarray
        numpy array to be rounded
    tol : float
        the threshold below which absolute values are considered '0'.
    """
    def _suppress(values, zeros):
        # Pick the zero wherever |value| < tol, the original value elsewhere.
        return np.choose(np.abs(values) < tol, [values, zeros])

    is_complex = a.dtype == np.complex128 or a.dtype == np.complex64
    if not is_complex:
        return _suppress(a, np.zeros_like(a))

    real_part = _suppress(a.real, np.zeros(a.shape))
    imag_part = _suppress(a.imag, np.zeros(a.shape))
    return real_part + 1j * imag_part
Example #19
Source File: audio_utils.py    From ludwig with Apache License 2.0 6 votes vote down vote up
def get_group_delay(raw_data, sampling_rate_in_hz, window_length_in_s,
                    window_shift_in_s, num_fft_points, window_type):
    """Compute a group-delay feature from two STFTs of `raw_data`.

    `_get_stft` is called twice with identical settings, the second time
    with ``data_transformation='group_delay'``.  With X the first result
    and Y the second, the value returned is the transpose of

        (Re(X) * Re(Y) + Im(X) * Im(Y)) / (|X|**2 + 1e-10)

    where the small constant keeps the division finite.  An AssertionError
    is raised if the result contains NaN.
    """
    stft_kwargs = dict(window_type=window_type)
    plain_stft = _get_stft(raw_data, sampling_rate_in_hz,
                           window_length_in_s, window_shift_in_s,
                           num_fft_points, **stft_kwargs)
    weighted_stft = _get_stft(raw_data, sampling_rate_in_hz,
                              window_length_in_s, window_shift_in_s,
                              num_fft_points,
                              data_transformation='group_delay',
                              **stft_kwargs)

    plain_re, plain_im = np.real(plain_stft), np.imag(plain_stft)
    weighted_re, weighted_im = np.real(weighted_stft), np.imag(weighted_stft)

    real_product = np.multiply(plain_re, weighted_re)
    imag_product = np.multiply(plain_im, weighted_im)
    numerator = real_product + imag_product
    power = np.square(np.abs(plain_stft))

    group_delay = np.divide(numerator, power + 1e-10)
    assert not np.isnan(
        group_delay).any(), 'There are NaN values in group delay'
    return np.transpose(group_delay)
Example #20
Source File: test_type_check.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def test_cmplx(self):
        # np.imag must agree with the .imag attribute on a random complex
        # vector.
        real_part = np.random.rand(10,)
        imag_part = np.random.rand(10,)
        y = real_part + 1j * imag_part
        assert_array_equal(y.imag, np.imag(y))
Example #21
Source File: mosse.py    From OpenCV-Python-Tutorial with MIT License 5 votes vote down vote up
def divSpec(A, B):
    """Divide two spectra stored as real arrays whose last axis holds
    (real, imaginary) pairs.

    The pairs of `A` and `B` are combined into complex numbers, divided
    element-wise, and the quotient is returned as a fresh array with its
    real and imaginary parts stacked along the third axis.
    """
    numerator = A[..., 0] + 1j * A[..., 1]
    denominator = B[..., 0] + 1j * B[..., 1]
    quotient = numerator / denominator
    stacked = np.dstack([np.real(quotient), np.imag(quotient)])
    return stacked.copy()
Example #22
Source File: basic.py    From D-VAE with MIT License 5 votes vote down vote up
def grad(self, inputs, gout):
        (gz,) = gout

        # Everything below assumes that a complex output comes with a
        # complex output gradient, so reject anything else up front.
        output_type = self.output_types([i.type for i in inputs])[0]
        if output_type in complex_types and gz.type not in complex_types:
            raise TypeError(
                'Mul with output_type ' + str(output_type) +
                ' expected gz type to be complex, got gz with type ' +
                str(gz.type))

        # Discrete outputs get an all-zero gradient for every input.
        if output_type in discrete_types:
            return [ipt.zeros_like().astype(theano.config.floatX)
                    for ipt in inputs]

        gz_is_complex = gz.type in complex_types
        gradients = []
        for current in inputs:
            others = utils.difference(inputs, [current])

            if not gz_is_complex:
                # gradient = gz times the product of all other inputs
                gradients.append(mul(*([gz] + others)))
                continue

            # zr+zi = (xr + xi)(yr + yi)
            # zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr )
            otherprod = mul(*others)
            yr = real(otherprod)
            yi = imag(otherprod)
            if current.type in complex_types:
                gradients.append(complex(yr * real(gz) + yi * imag(gz),
                                         yr * imag(gz) - yi * real(gz)))
            else:
                gradients.append(yr * real(gz) + yi * imag(gz))
        return gradients
Example #23
Source File: basic.py    From D-VAE with MIT License 5 votes vote down vote up
def cast(x, dtype):
    """
    Symbolically cast `x` to a Scalar of given `dtype`.

    The alias 'floatX' is resolved to `config.floatX`.  When `x` already
    has the requested dtype it is returned as a scalar without any cast.
    A cast from a complex dtype to a non-complex one raises TypeError.

    """
    target = config.floatX if dtype == 'floatX' else dtype

    scalar = as_scalar(x)
    source = scalar.type.dtype
    if source == target:
        return scalar

    drops_imaginary = source.startswith('complex') and not target.startswith('complex')
    if drops_imaginary:
        raise TypeError('Casting from complex to real is ambiguous: consider'
                        ' real(), imag(), angle() or abs()')
    return _cast_mapping[target](scalar)
Example #24
Source File: basic.py    From D-VAE with MIT License 5 votes vote down vote up
def c_code(self, node, name, inputs, outputs, sub):
        # Emit the C statement computing the absolute value of the single
        # input into the single output, chosen by the input's type.
        (x,) = inputs
        (z,) = outputs
        input_type = node.inputs[0].type
        names = {'x': x, 'z': z}

        if input_type in int_types:
            template = "%(z)s = abs(%(x)s);"
        elif input_type in float_types:
            template = "%(z)s = fabs(%(x)s);"
        elif input_type in complex_types:
            template = "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);"
        else:
            raise NotImplementedError('type not supported', input_type)
        return template % names
Example #25
Source File: test_type_check.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def test_complex_bad2(self):
        # Dividing by zero yields non-finite components; floating-point
        # warnings are silenced while the value is built.
        with np.errstate(divide='ignore', invalid='ignore'):
            v = (1 + 1j) + np.array(-1 + 1.j) / 0.
        vals = nan_to_num(v)
        assert_all(np.isfinite(vals))
        # Fixme: the stricter sign checks below stay disabled.
        #   assert_all(vals.imag > 1e10)  and assert_all(np.isfinite(vals))
        # !! The value is actually (unexpectedly) positive inf, so this is
        # !! commented out for now to see whether it changes:
        #   assert_all(vals.real < -1e10) and assert_all(np.isfinite(vals))
Example #26
Source File: imag.py    From mars with Apache License 2.0 5 votes vote down vote up
def imag(val, **kwargs):
    """
    Return the imaginary part of the complex argument.

    Parameters
    ----------
    val : array_like
        Input tensor.
    **kwargs
        Forwarded unchanged to the ``TensorImag`` operand constructor.

    Returns
    -------
    out : Tensor or scalar
        The imaginary component of the complex argument. If `val` is real,
        the type of `val` is used for the output.  If `val` has complex
        elements, the returned type is float.

    See Also
    --------
    real, angle, real_if_close

    Examples
    --------
    >>> import mars.tensor as mt

    >>> a = mt.array([1+2j, 3+4j, 5+6j])
    >>> a.imag.execute()
    array([ 2.,  4.,  6.])
    >>> a.imag = mt.array([8, 10, 12])
    >>> a.execute()
    array([ 1. +8.j,  3.+10.j,  5.+12.j])
    >>> mt.imag(1 + 1j).execute()
    1.0

    """
    # Build the operand, then apply it to the input.
    return TensorImag(**kwargs)(val)
Example #27
Source File: yellowfin.py    From YellowFin_MXNet with Apache License 2.0 5 votes vote down vote up
def single_step_mu_lr(self, C, D, h_min, h_max):
    """Compute the momentum / learning-rate pair for a single step.

    The cubic ``-x**3 + 3*x**2 - (3 + p)*x + 1 = 0`` with
    ``p = D**2 * h_min**2 / (2*C)`` is solved with `np.roots`, and its
    unique real root in the open interval (0, 1) is selected.

    Returns
    -------
    (mu_t, lr_t)
        ``mu_t`` is the larger of the squared root and
        ``((sqrt(h_max/h_min) - 1) / (sqrt(h_max/h_min) + 1))**2``;
        ``lr_t`` is ``(1 - sqrt(mu_t))**2 / h_min``.

    Raises
    ------
    AssertionError
        If the cubic does not have exactly one real root in (0, 1).
    """
    coef = np.array([-1.0, 3.0, 0.0, 1.0])
    coef[2] = -(3 + D ** 2 * h_min ** 2 / 2 / C)
    roots = np.roots(coef)

    # A root counts as real only when the *magnitude* of its imaginary part
    # is negligible; the former one-sided test `imag < 1e-5` also accepted
    # roots with a large negative imaginary part.
    is_real = np.abs(np.imag(roots)) < 1e-5
    in_unit_interval = np.logical_and(np.real(roots) > 0.0, np.real(roots) < 1.0)
    root = roots[np.logical_and(in_unit_interval, is_real)]
    assert root.size == 1

    dr = h_max / h_min
    mu_t = max(np.real(root)[0] ** 2, ((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1)) ** 2)
    lr_t = (1.0 - math.sqrt(mu_t)) ** 2 / h_min
    return mu_t, lr_t
Example #28
Source File: cls_fe_dft.py    From signaltrain with GNU General Public License v3.0 5 votes vote down vote up
def initialize(self):
        """
        Initialize the analysis convolution weights with a Hamming-windowed,
        orthonormally scaled DFT basis of size `self.sz`.

        The real part of the basis is copied into
        `self.conv_analysis_real.weight` and the imaginary part into
        `self.conv_analysis_imag.weight`, each as float32 with a singleton
        axis inserted in the middle.
        """
        f_matrix = np.fft.fft(np.eye(self.sz), norm='ortho')
        # `hamming` is no longer exported from the top level of
        # scipy.signal; it lives in the `windows` submodule.
        # NOTE(review): assumes `sig` is `scipy.signal` - confirm against
        # the file's imports.
        w = sig.windows.hamming(self.sz)

        f_matrix_real = (np.real(f_matrix) * w).astype(np.float32, copy=False)
        f_matrix_imag = (np.imag(f_matrix) * w).astype(np.float32, copy=False)

        # copy_() moves data to the device of the destination tensor itself,
        # so no explicit .cuda() is needed.  The former guard,
        # `torch.has_cudnn`, only reports how torch was built and could
        # trigger .cuda() on a machine without a GPU.
        self.conv_analysis_real.weight.data.copy_(torch.from_numpy(f_matrix_real[:, None, :]))
        self.conv_analysis_imag.weight.data.copy_(torch.from_numpy(f_matrix_imag[:, None, :]))
Example #29
Source File: cls_fe_dft.py    From signaltrain with GNU General Public License v3.0 5 votes vote down vote up
def initialize(self):
        """
        Initialize the synthesis layer weights with an orthonormally scaled
        DFT basis of size `self.sz`.

        The transposed real part of the basis is copied into
        `self.fnn_synthesis_real.weight` and the transposed imaginary part
        into `self.fnn_synthesis_imag.weight`, both as float32.
        """
        print('Initializing with Fourier bases')
        f_matrix = np.fft.fft(np.eye(self.sz), norm='ortho')

        f_matrix_real = (np.real(f_matrix)).astype(np.float32, copy=False)
        f_matrix_imag = (np.imag(f_matrix)).astype(np.float32, copy=False)

        # copy_() moves data to the device of the destination tensor itself,
        # so no explicit .cuda() is needed.  The former guard,
        # `torch.has_cudnn`, only reports how torch was built and could
        # trigger .cuda() on a machine without a GPU.
        self.fnn_synthesis_real.weight.data.copy_(torch.from_numpy(f_matrix_real.T))
        self.fnn_synthesis_imag.weight.data.copy_(torch.from_numpy(f_matrix_imag.T))
Example #30
Source File: cls_fe_dft.py    From signaltrain with GNU General Public License v3.0 5 votes vote down vote up
def initialize(self):
        """
        Initialize the synthesis convolution weights with an orthonormally
        scaled DFT basis of size `self.sz`, multiplied by the window
        returned by `Synthesis.GLA(self.sz, self.hop, self.sz)`.

        The real part of the windowed basis is copied into
        `self.conv_synthesis_real.weight` and the imaginary part into
        `self.conv_synthesis_imag.weight`, each as float32 with a singleton
        axis inserted in the middle.
        """
        f_matrix = np.fft.fft(np.eye(self.sz), norm='ortho')
        w = Synthesis.GLA(self.sz, self.hop, self.sz)

        f_matrix_real = (np.real(f_matrix) * w).astype(np.float32, copy=False)
        f_matrix_imag = (np.imag(f_matrix) * w).astype(np.float32, copy=False)

        # copy_() moves data to the device of the destination tensor itself,
        # so no explicit .cuda() is needed.  The former guard,
        # `torch.has_cudnn`, only reports how torch was built and could
        # trigger .cuda() on a machine without a GPU.
        self.conv_synthesis_real.weight.data.copy_(torch.from_numpy(f_matrix_real[:, None, :]))
        self.conv_synthesis_imag.weight.data.copy_(torch.from_numpy(f_matrix_imag[:, None, :]))