Python torch.ifft() Examples

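The following are 22 code examples of `torch.ifft()`, each preceded by its source file and originating project. Note that `torch.ifft()` (like `torch.fft()`, `torch.rfft()` and `torch.irfft()`) represents complex data as a trailing dimension of size 2. It was deprecated in PyTorch 1.7 and removed in 1.8 in favor of the `torch.fft` module (e.g. `torch.fft.ifftn`), so these examples only run on older PyTorch releases. You may also want to check out all available functions and classes of the `torch` module.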
Example #1
Source File: transforms.py, from fastMRI (MIT License), 6 votes
def ifft2(data):
    """
    Centered 2-D inverse FFT over the two spatial dimensions.

    The helpers ``ifftshift`` and ``fftshift`` (defined elsewhere in this
    module) are applied over the spatial dims before and after the transform.

    Args:
        data (torch.Tensor): Complex input with at least 3 dimensions. Dims -3
            and -2 are spatial, dim -1 has size 2 (real and imaginary parts),
            and any leading dims are batch dims.

    Returns:
        torch.Tensor: The inverse FFT computed with ``torch.ifft(..., 2,
        normalized=True)`` (orthonormal 1/sqrt(N) scaling).
    """
    assert data.size(-1) == 2
    spatial_dims = (-3, -2)
    centered = ifftshift(data, dim=spatial_dims)
    image = torch.ifft(centered, 2, normalized=True)
    return fftshift(image, dim=spatial_dims)
Example #2
Source File: transforms.py, from fastMRI (MIT License), 6 votes
def ifft2(data):
    """
    Centered 2-D inverse FFT with orthonormal scaling.

    Computes ``fftshift(ifft(ifftshift(data)))`` over dims -3 and -2, where the
    shift helpers are defined elsewhere in this module.

    Args:
        data (torch.Tensor): Complex input with at least 3 dimensions: dims -3
            and -2 are spatial, dim -1 has size 2, and every other dim is a
            batch dim.

    Returns:
        torch.Tensor: The centered IFFT of ``data``.
    """
    assert data.size(-1) == 2
    dims = (-3, -2)
    return fftshift(torch.ifft(ifftshift(data, dim=dims), 2, normalized=True), dim=dims)
Example #3
Source File: utils_deblur.py, from KAIR (MIT License), 5 votes
def ifft(t):
    """Complex-to-complex 2-D inverse FFT of ``t`` (legacy ``torch.ifft`` API).

    ``t`` stores complex values in a trailing dimension of size 2. The
    transform runs over the two dimensions before it, using the default
    scaling (``normalized=False``).
    """
    return torch.ifft(t, signal_ndim=2)
Example #4
Source File: mcmri.py, from deepinpy (MIT License), 5 votes
def fft_adj(x, ndim=2):
    """Adjoint of the orthonormal FFT, i.e. the inverse FFT with ``normalized=True``.

    Args:
        x: complex tensor in the legacy layout (trailing dimension of size 2).
        ndim: number of signal dimensions to transform (those immediately
            preceding the trailing complex dimension).
    """
    return torch.ifft(x, ndim, normalized=True)
Example #5
Source File: fft.py, from sigmanet (MIT License), 5 votes
def ifft2c(data):
    """
    Centered 2-D inverse FFT with orthonormal scaling.

    Args:
        data (torch.Tensor): Complex tensor with at least 3 dimensions. Dims -3
            and -2 are the spatial axes, dim -1 has size 2 (real, imaginary),
            and any leading dims are treated as batch dims.
    Returns:
        torch.Tensor: ``fftshift`` of ``torch.ifft(..., 2, normalized=True)``
        applied to ``ifftshift(data)``, with both shifts over the spatial axes
        (the shift helpers are defined elsewhere in this module).
    """
    assert data.size(-1) == 2
    axes = (-3, -2)
    out = ifftshift(data, dim=axes)
    out = torch.ifft(out, 2, normalized=True)
    out = fftshift(out, dim=axes)
    return out
Example #6
Source File: fft.py, from sigmanet (MIT License), 5 votes
def ifft2(data):
    """Uncentered 2-D inverse FFT with orthonormal (1/sqrt(N)) scaling.

    ``data`` must store complex values in a trailing dimension of size 2; the
    transform runs over dims -3 and -2.
    """
    assert data.size(-1) == 2
    return torch.ifft(data, signal_ndim=2, normalized=True)
Example #7
Source File: utils.py, from sigmanet (MIT License), 5 votes
def torch_ifft2c(x, normalized=True):
    """Centered 2-D inverse FFT over the last two axes of ``x``, via ``torch.ifft``.

    ``np.fft.ifftshift`` is applied over the last two axes before the
    transform and ``np.fft.fftshift`` after it. Conversion to and from torch
    is done by ``numpy_to_torch`` / ``torch_to_complex_numpy`` (defined
    elsewhere; presumably they map a complex array to the legacy
    trailing-size-2 layout and back).

    Args:
        x: array to transform (presumably complex NumPy).
        normalized: forwarded to ``torch.ifft``. ``True`` (default) gives
            orthonormal 1/sqrt(N) scaling; ``False`` gives the standard 1/N.

    Returns:
        The result of ``torch_to_complex_numpy`` on the transformed tensor,
        ``fftshift``-ed over the last two axes.
    """
    x = np.fft.ifftshift(x, axes=(-2, -1))
    xt = numpy_to_torch(x)
    # Forward the caller's flag instead of hard-coding True.
    kt = torch.ifft(xt, 2, normalized=normalized)
    k = torch_to_complex_numpy(kt)
    return np.fft.fftshift(k, axes=(-2, -1))
Example #8
Source File: utils.py, from sigmanet (MIT License), 5 votes
def torch_ifft2(k, normalized=True):
    """Uncentered 2-D inverse FFT over the last two axes of ``k``, via ``torch.ifft``.

    ``k`` is converted with ``numpy_to_torch`` and the result is converted back
    with ``torch_to_complex_numpy`` (both defined elsewhere). ``normalized`` is
    forwarded to ``torch.ifft``: ``True`` gives 1/sqrt(N) scaling, ``False``
    gives 1/N.
    """
    return torch_to_complex_numpy(torch.ifft(numpy_to_torch(k), 2, normalized=normalized))
Example #9
Source File: torch_backend.py, from kymatio (BSD 3-Clause "New" or "Revised" License), 5 votes
def fft(input, inverse=False):
    """Interface with torch FFT routines for 3D signals.

    Computes the 3-D FFT (or inverse FFT) over the three dimensions that
    precede the trailing complex dimension, using the default (unnormalized
    forward, 1/N inverse) scaling.

    Example
    -------
    x = torch.randn(128, 32, 32, 32, 2)

    x_fft = fft(x)
    x_ifft = fft(x, inverse=True)

    Parameters
    ----------
    input : tensor
        Complex input for the FFT (complex values in a trailing dimension of
        size 2).
    inverse : bool
        True for computing the inverse FFT.

    Raises
    ------
    TypeError
        If ``_is_complex(input)`` is false, i.e. ``input`` is not in the
        complex (trailing dimension of size 2) layout.

    Returns
    -------
    output : tensor
        Result of FFT or IFFT.
    """
    if not _is_complex(input):
        raise TypeError('The input should be complex (e.g. last dimension is 2)')
    transform = torch.ifft if inverse else torch.fft
    return transform(input, 3)
Example #10
Source File: CBP.py, from DBCNN-PyTorch (MIT License), 5 votes
def forward(self, x):
    """Compact bilinear pooling of a feature map via FFT-based circular convolution.

    For each chunk of ``bsn`` images, the flattened per-position features are
    projected by ``self.sparseM[0]`` and ``self.sparseM[1]``. The two
    projections are multiplied in the frequency domain (``torch.fft`` /
    ``torch.ifft`` with ``signal_ndim=1``), and the real part of the result is
    summed over all spatial positions of each image.

    Args:
        x: tensor unpacked as (batch, dim, h, w).

    Returns:
        The (batch, self.output_dim) pooled features after
        ``self._signed_sqrt`` and ``self._l2norm`` (defined on the class).
    """
    bsn = 1  # images processed per chunk
    batchSize, dim, h, w = x.shape
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # (batch*h*w, dim)
    y = torch.ones(batchSize, self.output_dim, device=x.device)

    segLen = bsn * h * w
    upper = batchSize * h * w
    # Ceil division so a trailing partial chunk is also processed
    # (identical to batchSize // bsn when bsn divides batchSize, e.g. bsn=1).
    for img in range((batchSize + bsn - 1) // bsn):
        interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)
        # Image indices are bounded by batchSize, not by the row count ``upper``.
        interSmall = torch.arange(img * bsn, min(batchSize, (img + 1) * bsn), dtype=torch.long)
        batch_x = x_flat[interLarge, :]

        # Real sketches get a zero imaginary part with matching dtype/device.
        sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
        sketch1 = torch.fft(torch.cat((sketch1, torch.zeros_like(sketch1)), dim=2), 1)

        sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
        sketch2 = torch.fft(torch.cat((sketch2, torch.zeros_like(sketch2)), dim=2), 1)

        # Complex product of the two spectra.
        Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1])
        Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0])

        # Real part of the inverse transform = circular convolution of the sketches.
        tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]

        y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)

    y = self._signed_sqrt(y)
    y = self._l2norm(y)
    return y
Example #11
Source File: learning_fft_old.py, from learning-circuits (Apache License 2.0), 5 votes
def _setup(self, config):
    """Build the model, optimizer, DFT target and identity input for training.

    Reads ``seed``, ``size``, ``lr`` and ``n_steps_per_epoch`` from ``config``.
    The target is ``torch.fft`` applied row-wise (``signal_ndim=1``) to the
    complexified identity, i.e. the size x size DFT matrix. This assumes
    ``real_to_complex`` appends a zero imaginary part as a trailing dimension
    of size 2 (TODO confirm).
    """
    torch.manual_seed(config['seed'])
    self.model = nn.Sequential(
        BlockPermProduct(size=config['size'], complex=True, share_logit=False),
        Block2x2DiagProduct(size=config['size'], complex=True)
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # torch.fft requires signal_ndim. A value of 1 transforms each row of the
    # identity independently, which yields the DFT matrix.
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    # Alternative target (rescaled inverse DFT):
    # self.target_matrix = size * torch.ifft(real_to_complex(torch.eye(size)), 1)
    self.input = real_to_complex(torch.eye(size))
Example #12
Source File: pooling.py, from Landmark2019-1st-and-3rd-Place-Solution (Apache License 2.0), 5 votes
def forward(self, bottom):
    """Compact bilinear pooling via FFT-based circular convolution of two sketches.

    Args:
        bottom: tensor unpacked as (batch_size, channels, height, width).
            ``channels`` presumably equals ``self.input_dim1``, since the two
            ``view`` calls together require it.

    Returns:
        Tensor of shape (batch_size, height, width, self.output_dim), summed
        over height and width to (batch_size, self.output_dim) when
        ``self.sum_pool`` is truthy.
    """
    batch_size, _, height, width = bottom.size()

    bottom_flat = bottom.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)

    sketch_1 = bottom_flat.mm(self.sparse_sketch_matrix1)
    sketch_2 = bottom_flat.mm(self.sparse_sketch_matrix2)

    # Zero imaginary parts are created directly with each sketch's dtype and
    # device. This avoids a host allocation plus copy and a dtype mismatch.
    fft1 = torch.fft(torch.stack([sketch_1, torch.zeros_like(sketch_1)], dim=-1), 1)
    fft2 = torch.fft(torch.stack([sketch_2, torch.zeros_like(sketch_2)], dim=-1), 1)

    # Complex product of the two spectra.
    fft_product_real = fft1[..., 0].mul(fft2[..., 0]) - fft1[..., 1].mul(fft2[..., 1])
    fft_product_imag = fft1[..., 0].mul(fft2[..., 1]) + fft1[..., 1].mul(fft2[..., 0])

    # Real part of the inverse transform = circular convolution of the sketches.
    cbp_flat = torch.ifft(torch.stack([fft_product_real, fft_product_imag], dim=-1), 1)[..., 0]

    cbp = cbp_flat.view(batch_size, height, width, self.output_dim)

    if self.sum_pool:
        cbp = cbp.sum(dim=[1, 2])

    return cbp
Example #13
Source File: so3_fft.py, from s2cnn (MIT License), 5 votes
def backward(self, grad_output):  # pylint: disable=W
    """Backward pass: inverse SO(3) transform of ``grad_output``, real part only.

    The inverse transform of ``grad_output`` is not necessarily real, so the
    complex ``so3_ifft`` is used rather than the real-output variant, and the
    real component is returned. The second returned value is ``None``.
    """
    full = so3_ifft(grad_output, for_grad=True, b_out=self.b_in)
    return full[..., 0], None
Example #14
Source File: so3_fft.py, from s2cnn (MIT License), 5 votes
def so3_ifft(x, for_grad=False, b_out=None):
    '''
    Complex inverse SO(3) Fourier transform.

    :param x: [l * m * n, ..., complex] spectral coefficients. The first
        dimension must have size b_in * (4 * b_in**2 - 1) // 3 for some
        integer b_in (asserted below). The last dimension has size 2
        (real, imaginary).
    :param for_grad: forwarded as ``weighted`` to ``_setup_wigner``.
    :param b_out: output bandwidth; defaults to the input bandwidth b_in.
    :return: complex tensor of shape (*batch, 2 b_out, 2 b_out, 2 b_out, 2),
        ordered [..., beta, alpha, gamma, complex].
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    # Invert nspec = b_in (4 b_in**2 - 1) / 3 to recover the input bandwidth.
    b_in = round((3 / 4 * nspec) ** (1 / 3))
    assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m * n, batch, complex] (nspec, nbatch, 2)

    '''
    :param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2)
    :return: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_out, 2 b_out, 2 b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)  # [beta, l * m * n] (2 * b_out, nspec)

    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
    if x.is_cuda and x.dtype == torch.float32:
        # Fast path (CUDA, float32 only): a custom kernel writes ``output``.
        cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=False, device=x.device.index)
        cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    else:
        # Reference path: contract each degree l with its Wigner-d block and
        # scatter the result into the (m, n) frequency grid.
        output.fill_(0)
        for l in range(min(b_in, b_out)):
            # Degree-l coefficients: (2l+1)**2 entries starting at l (4 l**2 - 1) / 3.
            s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
            out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1)))
            # Given the loop bound l <= b_out - 1, l1 always equals l here.
            l1 = min(l, b_out - 1)  # if b_out < b_in
            # Assuming index i of the (2l+1) axes is order i - l: non-negative
            # orders go to the front of the m/n axes and negative orders wrap
            # around to the back (standard FFT index layout).
            output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1]
            if l > 0:
                output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1]
                output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l]
                output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l]

    # 2-D inverse FFT over the last two grid axes. The factor (2 b_out)**2
    # cancels torch.ifft's 1/N scaling, giving an unnormalized inverse sum.
    output = torch.ifft(output, 2) * output.size(-2) ** 2  # [batch, beta, alpha, gamma, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out, 2)
    return output
Example #15
Source File: torch_spec_operator.py, from space_time_pde (MIT License), 5 votes
def pad_ifft2(F):
    """
    Batched 2-D complex inverse FFT, computed as two 1-D inverse FFTs.

    :param F: complex tensor in the legacy layout [..., res0, res1, 2]. The
        trailing dimension holds real and imaginary parts, and the transform
        runs over dims -3 (res0) and -2 (res1) with the default 1/N scaling.
    :return: tensor of the same shape as ``F`` holding the inverse FFT.
    """
    # 1-D inverse FFT along dim -3: swap it into the signal position (-2) and back.
    f0 = torch.ifft(F.transpose(-3, -2), signal_ndim=1).transpose(-2, -3)
    # 1-D inverse FFT along dim -2 completes the separable 2-D transform.
    f1 = torch.ifft(f0, signal_ndim=1)
    return f1
Example #16
Source File: torch_spec_operator.py, from space_time_pde (MIT License), 5 votes
def pad_irfft3(F):
    """
    Batched 3-D inverse real FFT, computed as two complex 1-D inverse FFTs
    followed by a 1-D inverse real FFT.

    :param F: half-spectrum tensor of shape [..., res0, res1, res2/2+1, 2].
        The length of the last real output axis is taken from ``F.shape[-3]``
        (res1), so this presumably assumes res2 == res1 (confirm with callers).
    :return: real tensor of shape [..., res0, res1, res1].
    """
    n = F.shape[-3]
    # Complex inverse along the res0 axis (dim -4), swapped into position -2.
    out = torch.ifft(F.transpose(-4, -2), signal_ndim=1).transpose(-2, -4)
    # Complex inverse along the res1 axis (dim -3).
    out = torch.ifft(out.transpose(-3, -2), signal_ndim=1).transpose(-2, -3)
    # Real inverse along the half-spectrum axis.
    return torch.irfft(out, signal_ndim=1, signal_sizes=[n])
Example #17
Source File: network_usrnet.py, from KAIR (MIT License), 5 votes
def ifft(t):
    """Complex-to-complex inverse DFT over the two dims before ``t``'s trailing complex dim."""
    signal_ndim = 2
    return torch.ifft(t, signal_ndim)
Example #18
Source File: utils_sisr.py, from KAIR (MIT License), 5 votes
def ifft(t):
    """2-D complex inverse FFT of ``t`` (legacy layout, trailing dim of size 2), 1/N scaling."""
    return torch.ifft(t, signal_ndim=2, normalized=False)
Example #19
Source File: CBP.py, from fast-MPN-COV (MIT License), 5 votes
def forward(self, x):
    """Compact bilinear pooling of a feature map via FFT-based circular convolution.

    For each chunk of ``bsn`` images, the flattened per-position features are
    projected by ``self.sparseM[0]`` and ``self.sparseM[1]``. The two
    projections are multiplied in the frequency domain (``torch.fft`` /
    ``torch.ifft`` with ``signal_ndim=1``), and the real part of the result is
    summed over all spatial positions of each image.

    Args:
        x: tensor unpacked as (batch, dim, h, w).

    Returns:
        The (batch, self.output_dim) pooled features after
        ``self._signed_sqrt`` and ``self._l2norm`` (defined on the class).
    """
    bsn = 1  # images processed per chunk
    batchSize, dim, h, w = x.shape
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # (batch*h*w, dim)
    y = torch.ones(batchSize, self.output_dim, device=x.device)

    segLen = bsn * h * w
    upper = batchSize * h * w
    # Ceil division so a trailing partial chunk is also processed
    # (identical to batchSize // bsn when bsn divides batchSize, e.g. bsn=1).
    for img in range((batchSize + bsn - 1) // bsn):
        interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)
        # Image indices are bounded by batchSize, not by the row count ``upper``.
        interSmall = torch.arange(img * bsn, min(batchSize, (img + 1) * bsn), dtype=torch.long)
        batch_x = x_flat[interLarge, :]

        # Real sketches get a zero imaginary part with matching dtype/device.
        sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
        sketch1 = torch.fft(torch.cat((sketch1, torch.zeros_like(sketch1)), dim=2), 1)

        sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
        sketch2 = torch.fft(torch.cat((sketch2, torch.zeros_like(sketch2)), dim=2), 1)

        # Complex product of the two spectra.
        Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1])
        Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0])

        # Real part of the inverse transform = circular convolution of the sketches.
        tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]

        y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)

    y = self._signed_sqrt(y)
    y = self._l2norm(y)
    return y
Example #20
Source File: so3_fft.py, from s2cnn (MIT License), 4 votes
def so3_rifft(x, for_grad=False, b_out=None):
    '''
    Inverse SO(3) Fourier transform that keeps only the real part of the result.

    :param x: [l * m * n, ..., complex] spectral coefficients. The first
        dimension must have size b_in * (4 * b_in**2 - 1) // 3 for some
        integer b_in (asserted below). The last dimension has size 2
        (real, imaginary).
    :param for_grad: forwarded as ``weighted`` to ``_setup_wigner``.
    :param b_out: output bandwidth; defaults to the input bandwidth b_in.
    :return: real tensor of shape (*batch, 2 b_out, 2 b_out, 2 b_out),
        ordered [..., beta, alpha, gamma].
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    # Invert nspec = b_in (4 b_in**2 - 1) / 3 to recover the input bandwidth.
    b_in = round((3 / 4 * nspec) ** (1 / 3))
    assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m * n, batch, complex] (nspec, nbatch, 2)

    '''
    :param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2)
    :return: [batch, beta, alpha, gamma] (nbatch, 2 b_out, 2 b_out, 2 b_out)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)  # [beta, l * m * n] (2 * b_out, nspec)

    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
    if x.is_cuda and x.dtype == torch.float32:
        # Fast path (CUDA, float32 only): kernel built with real_output=True.
        cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=True, device=x.device.index)
        cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    else:
        # TODO can be optimized knowing that the output is real, like in _setup_so3ifft_cuda_kernel(real_output=True)
        output.fill_(0)
        for l in range(min(b_in, b_out)):
            # Degree-l coefficients: (2l+1)**2 entries starting at l (4 l**2 - 1) / 3.
            s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
            out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1)))
            # Given the loop bound l <= b_out - 1, l1 always equals l here.
            l1 = min(l, b_out - 1)  # if b_out < b_in
            # Assuming index i of the (2l+1) axes is order i - l: non-negative
            # orders go to the front of the m/n axes, negative orders wrap to the back.
            output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1]
            if l > 0:
                output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1]
                output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l]
                output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l]

    # 2-D inverse FFT over the last two grid axes; the factor cancels torch.ifft's 1/N.
    output = torch.ifft(output, 2) * output.size(-2) ** 2  # [batch, beta, alpha, gamma, complex]
    output = output[..., 0]  # [batch, beta, alpha, gamma]
    # Selecting the real component yields a strided view; make it contiguous for .view below.
    output = output.contiguous()

    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out)
    return output
Example #21
Source File: s2_fft.py, from s2cnn (MIT License), 4 votes
def s2_ifft(x, for_grad=False, b_out=None):
    '''
    Complex inverse Fourier transform on the sphere S2.

    :param x: [l * m, ..., complex] spectral coefficients. The first dimension
        must be a perfect square b_in**2 (asserted below), and the last
        dimension has size 2 (real, imaginary).
    :param for_grad: forwarded as ``weighted`` to ``_setup_wigner``.
    :param b_out: output bandwidth; defaults to b_in and must be >= b_in.
    :return: complex tensor of shape (*batch, 2 b_out, 2 b_out, 2),
        ordered [..., beta, alpha, complex].
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    # nspec = sum_{l < b_in} (2l + 1) = b_in**2.
    b_in = round(nspec ** 0.5)
    assert nspec == b_in ** 2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)

    '''
    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :return: [batch, beta, alpha, complex] (nbatch, 2 b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        # Fast path (CUDA, float32 only): launch the custom kernel on the
        # current torch CUDA stream, one thread per (batch, beta, m) element.
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1),
                    args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            # Degree-l coefficients: 2l+1 entries starting at l**2.
            s = slice(l ** 2, l ** 2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
            # Assuming index i of the (2l+1) axis is order i - l: orders 0..l go
            # to the front of the m axis, orders -l..-1 wrap around to the back.
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    # 1-D inverse FFT over the m axis; the factor 2 b_out cancels torch.ifft's 1/N.
    output = torch.ifft(output, 1) * output.size(-2)  # [batch, beta, alpha, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
    return output
Example #22
Source File: fft_functions.py, from torchkbnufft (MIT License), 4 votes
def fft_filter(x, kern, norm=None):
    """FFT-based filtering on a 2x oversampled grid.

    ``x`` is zero-padded to twice its size along every dimension from 3 on,
    Fourier transformed, and multiplied by ``kern`` with
    ``complex_mult(..., dim=2)``. The product is then inverse transformed and
    cropped back to the original spatial size.

    Args:
        x: tensor whose dims 3 and beyond are the spatial dims. Dim 2 is the
            real/imaginary dim, which is moved last around each
            ``torch.fft``/``torch.ifft`` call (presumably of size 2; confirm
            against callers).
        kern: frequency-domain kernel on the oversampled grid, combined with
            ``complex_mult`` along dim 2. It is presumably broadcastable to
            the padded ``x``.
        norm: ``'ortho'`` divides by sqrt(prod(grid_size)) after the forward
            FFT and again after the final crop. Any other value applies no
            extra scaling.

    Returns:
        The filtered tensor, cropped so dims 3 and beyond match ``x``.
    """
    im_size = torch.tensor(x.shape).to(torch.long)[3:]
    grid_size = im_size * 2

    # set up n-dimensional zero pad (F.pad lists (left, right) pairs starting
    # from the last dim, hence the reversed indexing)
    pad_sizes = []
    permute_dims = [0, 1]
    inv_permute_dims = [0, 1, 2 + grid_size.shape[0]]
    for i in range(grid_size.shape[0]):
        pad_sizes.append(0)
        pad_sizes.append(int(grid_size[-1 - i] - im_size[-1 - i]))
        permute_dims.append(3 + i)
        inv_permute_dims.append(2 + i)
    # the real/imaginary dim (2) goes last, as torch.fft/torch.ifft expect
    permute_dims.append(2)
    pad_sizes = tuple(pad_sizes)
    permute_dims = tuple(permute_dims)
    inv_permute_dims = tuple(inv_permute_dims)

    # zero pad and fft
    x = F.pad(x, pad_sizes)
    x = x.permute(permute_dims)
    x = torch.fft(x, grid_size.numel())
    if norm == 'ortho':
        x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))
    x = x.permute(inv_permute_dims)

    # apply the filter
    x = complex_mult(x, kern, dim=2)

    # inverse fft (torch.ifft already includes the 1/prod(grid_size) factor)
    x = x.permute(permute_dims)
    x = torch.ifft(x, grid_size.numel())
    x = x.permute(inv_permute_dims)

    # crop to input size: keep dims 0-2 whole and take the first im_size
    # entries of each spatial dim
    crop = (slice(None),) * 3 + tuple(slice(0, int(dim)) for dim in im_size)
    x = x[crop]

    # scaling, assume user handled adjoint scaling with their kernel
    # NOTE(review): with norm='ortho' this second division stacks on top of
    # torch.ifft's 1/N factor; confirm the kernel accounts for it.
    if norm == 'ortho':
        x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))

    return x