Python torch.ifft() Examples
The following are 22 code examples of torch.ifft().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: transforms.py From fastMRI with MIT License | 6 votes |
def ifft2(data):
    """
    Compute the centered, orthonormal 2D inverse FFT of real-pair complex data.

    Args:
        data (torch.Tensor): Complex input with at least 3 dimensions.
            Dimensions -3 and -2 are the spatial axes and dimension -1
            (size 2) stores the real and imaginary parts. Any leading
            dimensions are treated as batch dimensions.

    Returns:
        torch.Tensor: The inverse FFT of ``data``.

    Raises:
        AssertionError: If the last dimension of ``data`` is not of size 2.
    """
    assert data.size(-1) == 2
    spatial_dims = (-3, -2)
    # "Centered" transform: shift, transform, shift back. ifftshift/fftshift
    # are helpers defined elsewhere in this file.
    uncentered = ifftshift(data, dim=spatial_dims)
    image = torch.ifft(uncentered, 2, normalized=True)
    return fftshift(image, dim=spatial_dims)
Example #2
Source File: transforms.py From fastMRI with MIT License | 6 votes |
def ifft2(data):
    """
    Apply a centered 2-dimensional inverse FFT with orthonormal scaling.

    Args:
        data (torch.Tensor): Complex-valued input stored as real pairs. It has
            at least 3 dimensions: -3 and -2 are spatial, -1 has size 2
            (real, imaginary). All other dimensions are batch dimensions.

    Returns:
        torch.Tensor: The IFFT of the input.

    Raises:
        AssertionError: If ``data.size(-1)`` is not 2.
    """
    assert data.size(-1) == 2
    axes = (-3, -2)
    # Shift before and after the transform so the result follows the
    # "centered" convention; both shift helpers live elsewhere in the file.
    pre_shifted = ifftshift(data, dim=axes)
    transformed = torch.ifft(pre_shifted, 2, normalized=True)
    recentred = fftshift(transformed, dim=axes)
    return recentred
Example #3
Source File: utils_deblur.py From KAIR with MIT License | 5 votes |
def ifft(t):
    """Compute the unnormalized 2D inverse FFT of a real-pair complex tensor.

    Args:
        t (torch.Tensor): Complex data laid out as ``[..., H, W, 2]``, where
            the last dimension holds the real and imaginary parts.

    Returns:
        torch.Tensor: Inverse transform over the two dimensions preceding the
        last one, in the same ``[..., H, W, 2]`` layout.
    """
    if hasattr(torch, "ifft"):
        # Legacy function API (PyTorch < 1.8).
        return torch.ifft(t, 2)
    # torch.ifft was removed in PyTorch 1.8. Reproduce it with torch.fft.ifftn
    # on a complex view; view_as_complex needs a contiguous real/imag pair.
    spectrum = torch.view_as_complex(t.contiguous())
    return torch.view_as_real(torch.fft.ifftn(spectrum, dim=(-2, -1)))
Example #4
Source File: mcmri.py From deepinpy with MIT License | 5 votes |
def fft_adj(x, ndim=2):
    """Apply the orthonormal inverse FFT over the last ``ndim`` signal dimensions.

    Args:
        x (torch.Tensor): Complex data stored as real pairs; the last
            dimension has size 2 (real, imaginary) and the ``ndim``
            dimensions before it are transformed.
        ndim (int): Number of signal dimensions to transform.

    Returns:
        torch.Tensor: The normalized inverse transform, same layout as ``x``.
    """
    if hasattr(torch, "ifft"):
        # Legacy function API (PyTorch < 1.8).
        return torch.ifft(x, signal_ndim=ndim, normalized=True)
    # torch.ifft was removed in PyTorch 1.8; normalized=True corresponds to
    # norm="ortho" in the torch.fft module.
    spectrum = torch.view_as_complex(x.contiguous())
    signal_dims = tuple(range(-ndim, 0))
    return torch.view_as_real(
        torch.fft.ifftn(spectrum, dim=signal_dims, norm="ortho")
    )
Example #5
Source File: fft.py From sigmanet with MIT License | 5 votes |
def ifft2c(data):
    """
    Centered 2D inverse FFT (orthonormal) for real-pair complex tensors.

    Args:
        data (torch.Tensor): Complex-valued input with at least 3 dimensions.
            Dimensions -3 and -2 are spatial and dimension -1 has size 2.
            Remaining dimensions are treated as batch dimensions.

    Returns:
        torch.Tensor: The IFFT of the input.

    Raises:
        AssertionError: If the last dimension does not have size 2.
    """
    assert data.size(-1) == 2
    spatial = (-3, -2)
    # Shift helpers are defined elsewhere in the file; they bracket the
    # transform to implement the "centered" convention.
    kspace = ifftshift(data, dim=spatial)
    image = torch.ifft(kspace, 2, normalized=True)
    return fftshift(image, dim=spatial)
Example #6
Source File: fft.py From sigmanet with MIT License | 5 votes |
def ifft2(data):
    """Apply the orthonormal 2D inverse FFT (no centering shifts).

    Args:
        data (torch.Tensor): Complex data stored as real pairs; dimensions -3
            and -2 are transformed and dimension -1 must have size 2.

    Returns:
        torch.Tensor: The normalized inverse transform, same layout as ``data``.

    Raises:
        AssertionError: If the last dimension of ``data`` is not of size 2.
    """
    assert data.size(-1) == 2
    if hasattr(torch, "ifft"):
        # Legacy function API (PyTorch < 1.8).
        return torch.ifft(data, 2, normalized=True)
    # torch.ifft was removed in PyTorch 1.8; normalized=True maps to
    # norm="ortho" on the complex view of the same data.
    spectrum = torch.view_as_complex(data.contiguous())
    return torch.view_as_real(
        torch.fft.ifftn(spectrum, dim=(-2, -1), norm="ortho")
    )
Example #7
Source File: utils.py From sigmanet with MIT License | 5 votes |
def torch_ifft2c(x, normalized=True):
    """ Centered ifft2 on the last 2 dims of a complex numpy array.

    The array is shifted with ``np.fft.ifftshift``, converted to a torch
    tensor, inverse-transformed with ``torch.ifft``, converted back to a
    complex numpy array and shifted with ``np.fft.fftshift``.

    :param x: complex numpy array; the last two axes are transformed
    :param normalized: forwarded to ``torch.ifft`` (previously ignored and
        always treated as True)
    :return: complex numpy array holding the centered inverse transform
    """
    shifted = np.fft.ifftshift(x, axes=(-2, -1))
    xt = numpy_to_torch(shifted)
    # Honour the caller's choice instead of hard-coding normalized=True.
    kt = torch.ifft(xt, 2, normalized=normalized)
    k = torch_to_complex_numpy(kt)
    return np.fft.fftshift(k, axes=(-2, -1))
Example #8
Source File: utils.py From sigmanet with MIT License | 5 votes |
def torch_ifft2(k, normalized=True):
    """ Uncentered ifft on the last 2 dims, computed through torch.

    :param k: input array, converted with ``numpy_to_torch`` before the transform
    :param normalized: passed to ``torch.ifft`` as its ``normalized`` argument
    :return: result of ``torch_to_complex_numpy`` applied to the transformed tensor
    """
    k_tensor = numpy_to_torch(k)
    return torch_to_complex_numpy(torch.ifft(k_tensor, 2, normalized))
Example #9
Source File: torch_backend.py From kymatio with BSD 3-Clause "New" or "Revised" License | 5 votes |
def fft(input, inverse=False):
    """Interface with torch FFT routines for 3D signals.

    Computes the forward or inverse FFT over the three signal dimensions of
    a complex tensor stored as real pairs.

    Example
    -------
    x = torch.randn(128, 32, 32, 32, 2)
    x_fft = fft(x)
    x_ifft = fft(x, inverse=True)

    Parameters
    ----------
    input : tensor
        Complex input for the FFT (last dimension of size 2).
    inverse : bool
        True for computing the inverse FFT.

    Raises
    ------
    TypeError
        In the event that input does not pass the ``_is_complex`` check,
        i.e. is not complex.

    Returns
    -------
    output : tensor
        Result of FFT or IFFT.
    """
    if not _is_complex(input):
        raise TypeError('The input should be complex (e.g. last dimension is 2)')

    # Pick the legacy torch routine matching the requested direction.
    transform = torch.ifft if inverse else torch.fft
    return transform(input, 3)
Example #10
Source File: CBP.py From DBCNN-PyTorch with MIT License | 5 votes |
def forward(self, x):
    """Compact bilinear pooling of a 4-D feature map, one image at a time.

    Each spatial location is projected with the two matrices in
    ``self.sparseM``; the two sketches are multiplied in the Fourier domain
    (complex product of their FFTs), inverse-transformed, and the real part
    is summed over all spatial locations of the image. The pooled vectors
    are then passed through ``self._signed_sqrt`` and ``self._l2norm``.

    :param x: tensor unpacked as (batchSize, dim, h, w)
    :return: tensor of shape (batchSize, self.output_dim) after normalization
    """
    bsn = 1  # number of images handled per loop iteration
    batchSize, dim, h, w = x.data.shape
    # Channels last, one row per spatial location: (batchSize * h * w, dim).
    pixels = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)
    y = torch.ones(batchSize, self.output_dim, device=x.device)

    rows_per_chunk = bsn * h * w
    total_rows = batchSize * h * w

    def sketch_spectrum(rows, projection):
        # Project, append a zero imaginary part, and take the 1-D FFT.
        sketch = rows.mm(projection.to(x.device)).unsqueeze(2)
        zero_imag = torch.zeros(sketch.size(), device=x.device)
        return torch.fft(torch.cat((sketch, zero_imag), dim=2), 1)

    for chunk in range(batchSize // bsn):
        row_idx = torch.arange(chunk * rows_per_chunk,
                               min(total_rows, (chunk + 1) * rows_per_chunk),
                               dtype=torch.long)
        img_idx = torch.arange(chunk * bsn,
                               min(total_rows, (chunk + 1) * bsn),
                               dtype=torch.long)
        rows = pixels[row_idx, :]

        spec1 = sketch_spectrum(rows, self.sparseM[0])
        spec2 = sketch_spectrum(rows, self.sparseM[1])

        # Complex multiplication written out on (real, imag) pairs.
        re1, im1 = spec1[:, :, 0], spec1[:, :, 1]
        re2, im2 = spec2[:, :, 0], spec2[:, :, 1]
        prod_re = re1.mul(re2) - im1.mul(im2)
        prod_im = re1.mul(im2) + im1.mul(re2)

        # Back to the signal domain; keep only the real component.
        pooled = torch.ifft(
            torch.cat((prod_re.unsqueeze(2), prod_im.unsqueeze(2)), dim=2), 1
        )[:, :, 0]

        # Sum over the h and w axes of every image in the chunk.
        y[img_idx, :] = pooled.view(
            torch.numel(img_idx), h, w, self.output_dim
        ).sum(dim=1).sum(dim=1)

    return self._l2norm(self._signed_sqrt(y))
Example #11
Source File: learning_fft_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Build the model, optimizer and training targets from ``config``.

    Reads the keys 'seed', 'size', 'lr' and 'n_steps_per_epoch'. The target
    is ``torch.fft`` applied to the complexified identity matrix, and the
    model input is the complexified identity itself.
    """
    # Seed first so that model initialization is reproducible.
    torch.manual_seed(config['seed'])

    permutation = BlockPermProduct(size=config['size'], complex=True, share_logit=False)
    butterfly = Block2x2DiagProduct(size=config['size'], complex=True)
    self.model = nn.Sequential(permutation, butterfly)

    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']

    size = config['size']
    # NOTE(review): the legacy torch.fft signature takes a signal_ndim
    # argument that is not passed here - confirm against the PyTorch
    # version this file targets.
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)))
    # Alternative target (inverse transform, rescaled by size):
    # self.target_matrix = size * torch.ifft(real_to_complex(torch.eye(size)))
    self.input = real_to_complex(torch.eye(size))
Example #12
Source File: pooling.py From Landmark2019-1st-and-3rd-Place-Solution with Apache License 2.0 | 5 votes |
def forward(self, bottom):
    """Compact bilinear pooling of a 4-D feature map.

    Every spatial location is projected with the two sketch matrices; the
    sketches are multiplied in the Fourier domain (complex product of their
    1-D FFTs) and the real part of the inverse FFT is kept.

    :param bottom: tensor whose size unpacks as (batch, channels, height, width)
    :return: tensor of shape (batch, height, width, self.output_dim), or
        (batch, self.output_dim) when ``self.sum_pool`` is truthy
    """
    batch_size, _, height, width = bottom.size()
    # Channels last, one row per spatial location.
    locations = bottom.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)

    def to_real_pair(real_part):
        # Append a zero imaginary component as a trailing axis of size 2.
        imag_part = torch.zeros(real_part.size()).to(real_part.device)
        return torch.cat([real_part.unsqueeze(-1), imag_part.unsqueeze(-1)], dim=-1)

    spectrum_a = torch.fft(to_real_pair(locations.mm(self.sparse_sketch_matrix1)), 1)
    spectrum_b = torch.fft(to_real_pair(locations.mm(self.sparse_sketch_matrix2)), 1)

    # Complex product written out on (real, imag) pairs.
    a_re, a_im = spectrum_a[..., 0], spectrum_a[..., 1]
    b_re, b_im = spectrum_b[..., 0], spectrum_b[..., 1]
    product_re = a_re.mul(b_re) - a_im.mul(b_im)
    product_im = a_re.mul(b_im) + a_im.mul(b_re)

    product = torch.cat([product_re.unsqueeze(-1), product_im.unsqueeze(-1)], dim=-1)
    pooled_flat = torch.ifft(product, 1)[..., 0]

    pooled = pooled_flat.view(batch_size, height, width, self.output_dim)
    if self.sum_pool:
        pooled = pooled.sum(dim=[1, 2])
    return pooled
Example #13
Source File: so3_fft.py From s2cnn with MIT License | 5 votes |
def backward(self, grad_output):  # pylint: disable=W
    """Return the gradient for the first input and ``None`` for the second.

    The gradient is the real component (index 0 of the last axis) of
    ``so3_ifft`` applied to ``grad_output`` with ``for_grad=True`` and
    ``b_out=self.b_in``.
    """
    # The ifft of grad_output is not necessarily real, so the full complex
    # inverse transform is used (not rifft) and its real part sliced out.
    complex_grad = so3_ifft(grad_output, for_grad=True, b_out=self.b_in)
    real_grad = complex_grad[..., 0]
    return real_grad, None
Example #14
Source File: so3_fft.py From s2cnn with MIT License | 5 votes |
def so3_ifft(x, for_grad=False, b_out=None):
    '''
    Inverse SO(3) Fourier transform with complex output.

    :param x: [l * m * n, ..., complex]
    :param for_grad: forwarded to ``_setup_wigner`` as ``weighted``
    :param b_out: output bandwidth; defaults to the input bandwidth inferred
        from ``x.size(0)``
    :return: [..., beta, alpha, gamma, complex] with the three angular axes of
        size ``2 * b_out``
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    # Invert nspec = b_in * (4 * b_in ** 2 - 1) // 3 and verify the result.
    b_in = round((3 / 4 * nspec) ** (1 / 3))
    assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[1:-1]

    # Collapse all batch axes: [l * m * n, batch, complex] (nspec, nbatch, 2).
    x = x.view(nspec, -1, 2)
    nbatch = x.size(1)

    # [beta, l * m * n] (2 * b_out, nspec)
    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)

    grid = 2 * b_out
    output = x.new_empty((nbatch, grid, grid, grid, 2))
    if x.is_cuda and x.dtype == torch.float32:
        cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch,
                                                 real_output=False, device=x.device.index)
        cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    else:
        output.fill_(0)
        for degree in range(min(b_in, b_out)):
            side = 2 * degree + 1
            first = degree * (4 * degree ** 2 - 1) // 3
            block = slice(first, first + side ** 2)
            out = torch.einsum("mnzc,bmn->zbmnc", (x[block].view(side, side, -1, 2),
                                                   wigner[:, block].view(-1, side, side)))
            l1 = min(degree, b_out - 1)  # if b_out < b_in

            # Non-negative orders go to the front of each frequency axis,
            # negative orders wrap around to the back.
            dst_front, dst_back = slice(None, l1 + 1), slice(-l1, None)
            src_front, src_back = slice(degree, degree + l1 + 1), slice(degree - l1, degree)

            output[:, :, dst_front, dst_front] += out[:, :, src_front, src_front]
            if degree > 0:
                output[:, :, dst_back, dst_front] += out[:, :, src_back, src_front]
                output[:, :, dst_front, dst_back] += out[:, :, src_front, src_back]
                output[:, :, dst_back, dst_back] += out[:, :, src_back, src_back]

    # 2-D inverse FFT over the last two signal axes, rescaled by their size:
    # [batch, beta, alpha, gamma, complex]
    output = torch.ifft(output, 2) * output.size(-2) ** 2
    return output.view(*batch_size, grid, grid, grid, 2)
Example #15
Source File: torch_spec_operator.py From space_time_pde with MIT License | 5 votes |
def pad_ifft2(F):
    """
    batch 2-D inverse complex fft (unnormalized) over dimensions -3 and -2

    :param F: real-pair complex tensor of shape [..., res0, res1, 2]
    :return: tensor of the same shape holding the inverse transform
    """
    if hasattr(torch, "ifft"):
        # Legacy function API (PyTorch < 1.8): one 1-D pass per axis. The
        # first pass transposes so that res0 sits next to the complex axis.
        f0 = torch.ifft(F.transpose(-3, -2), signal_ndim=1).transpose(-2, -3)
        f1 = torch.ifft(f0, signal_ndim=1)
        # The original returned the undefined name ``f2`` (NameError).
        return f1
    # torch.ifft was removed in PyTorch 1.8; do both passes at once on a
    # complex view (view_as_complex needs a contiguous real/imag pair).
    spectrum = torch.view_as_complex(F.contiguous())
    return torch.view_as_real(torch.fft.ifftn(spectrum, dim=(-2, -1)))
Example #16
Source File: torch_spec_operator.py From space_time_pde with MIT License | 5 votes |
def pad_irfft3(F):
    """
    padded batch inverse real fft

    Two complex inverse passes over the res0 and res1 axes are followed by a
    one-sided complex-to-real inverse pass over the last frequency axis.

    :param F: tensor of shape [..., res0, res1, res2/2+1, 2]
    :return: real tensor of shape [..., res0, res1, res], where ``res`` is
        taken from ``F.shape[-3]`` (i.e. res1 is reused as the output length
        of the last axis)
    """
    res = F.shape[-3]
    if hasattr(torch, "ifft"):
        # Legacy function API (PyTorch < 1.8): each pass transposes the axis
        # being transformed next to the complex axis and back again.
        f0 = torch.ifft(F.transpose(-4, -2), signal_ndim=1).transpose(-2, -4)
        f1 = torch.ifft(f0.transpose(-3, -2), signal_ndim=1).transpose(-2, -3)
        f2 = torch.irfft(f1, signal_ndim=1, signal_sizes=[res])  # [..., res0, res1, res2]
        return f2
    # torch.ifft / torch.irfft were removed in PyTorch 1.8; use the torch.fft
    # module on a complex view (needs a contiguous real/imag pair).
    spectrum = torch.view_as_complex(F.contiguous())
    spectrum = torch.fft.ifftn(spectrum, dim=(-3, -2))
    return torch.fft.irfft(spectrum, n=res, dim=-1)  # [..., res0, res1, res2]
Example #17
Source File: network_usrnet.py From KAIR with MIT License | 5 votes |
def ifft(t):
    """Complex-to-complex 2D inverse discrete Fourier transform (unnormalized).

    Args:
        t (torch.Tensor): Complex data stored as real pairs, ``[..., H, W, 2]``.

    Returns:
        torch.Tensor: Inverse transform over the H and W dimensions, in the
        same real-pair layout.
    """
    if hasattr(torch, "ifft"):
        # Legacy function API (PyTorch < 1.8).
        return torch.ifft(t, 2)
    # torch.ifft was removed in PyTorch 1.8; fall back to torch.fft.ifftn on
    # a complex view, which requires a contiguous real/imag pair dimension.
    as_complex = torch.view_as_complex(t.contiguous())
    inverse = torch.fft.ifftn(as_complex, dim=(-2, -1))
    return torch.view_as_real(inverse)
Example #18
Source File: utils_sisr.py From KAIR with MIT License | 5 votes |
def ifft(t):
    """Unnormalized 2D inverse FFT of a tensor in real-pair complex layout.

    Args:
        t (torch.Tensor): Input of shape ``[..., H, W, 2]``; the trailing
            dimension stores (real, imaginary).

    Returns:
        torch.Tensor: Inverse transform over H and W, same layout as ``t``.
    """
    if hasattr(torch, "ifft"):
        # Legacy function API (PyTorch < 1.8).
        return torch.ifft(t, 2)
    # Removed in PyTorch 1.8: emulate with the torch.fft module. The complex
    # view requires the real/imag pair to be contiguous in memory.
    complex_view = torch.view_as_complex(t.contiguous())
    return torch.view_as_real(torch.fft.ifftn(complex_view, dim=(-2, -1)))
Example #19
Source File: CBP.py From fast-MPN-COV with MIT License | 5 votes |
def forward(self, x):
    """Compact bilinear pooling followed by signed-sqrt and L2 normalization.

    For every image, each spatial location is projected with the two
    matrices in ``self.sparseM``. The two projections are combined by a
    complex product of their 1-D FFTs, the inverse FFT's real part is taken,
    and the result is summed over the image's spatial locations.

    :param x: tensor unpacked as (batchSize, dim, h, w)
    :return: ``self._l2norm(self._signed_sqrt(y))`` where ``y`` has shape
        (batchSize, self.output_dim)
    """
    bsn = 1  # images per iteration
    batchSize, dim, h, w = x.data.shape
    # batchsize, h, w, dim flattened to (batchsize * h * w, dim)
    flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)
    y = torch.ones(batchSize, self.output_dim, device=x.device)

    segLen = bsn * h * w
    upper = batchSize * h * w

    for img in range(batchSize // bsn):
        loc_idx = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)
        out_idx = torch.arange(img * bsn, min(upper, (img + 1) * bsn), dtype=torch.long)
        feats = flat[loc_idx, :]

        spectra = []
        for projection in (self.sparseM[0], self.sparseM[1]):
            # Real sketch plus a zero imaginary channel, then 1-D FFT.
            sketch = feats.mm(projection.to(x.device)).unsqueeze(2)
            padded = torch.cat((sketch, torch.zeros(sketch.size(), device=x.device)), dim=2)
            spectra.append(torch.fft(padded, 1))
        first, second = spectra

        # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        real = first[:, :, 0].mul(second[:, :, 0]) - first[:, :, 1].mul(second[:, :, 1])
        imag = first[:, :, 0].mul(second[:, :, 1]) + first[:, :, 1].mul(second[:, :, 0])

        product = torch.cat((real.unsqueeze(2), imag.unsqueeze(2)), dim=2)
        per_location = torch.ifft(product, 1)[:, :, 0]

        summed = per_location.view(torch.numel(out_idx), h, w, self.output_dim)
        y[out_idx, :] = summed.sum(dim=1).sum(dim=1)

    y = self._signed_sqrt(y)
    return self._l2norm(y)
Example #20
Source File: so3_fft.py From s2cnn with MIT License | 4 votes |
def so3_rifft(x, for_grad=False, b_out=None):
    '''
    Inverse SO(3) Fourier transform keeping only the real part of the result.

    :param x: [l * m * n, ..., complex]
    :param for_grad: forwarded to ``_setup_wigner`` as ``weighted``
    :param b_out: output bandwidth; defaults to the input bandwidth inferred
        from ``x.size(0)``
    :return: [..., beta, alpha, gamma] with the three angular axes of size
        ``2 * b_out``
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    # Invert nspec = b_in * (4 * b_in ** 2 - 1) // 3 and verify the result.
    b_in = round((3 / 4 * nspec) ** (1 / 3))
    assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[1:-1]

    # Collapse all batch axes: [l * m * n, batch, complex] (nspec, nbatch, 2).
    x = x.view(nspec, -1, 2)
    nbatch = x.size(1)

    # [beta, l * m * n] (2 * b_out, nspec)
    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)

    grid = 2 * b_out
    output = x.new_empty((nbatch, grid, grid, grid, 2))
    if x.is_cuda and x.dtype == torch.float32:
        cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch,
                                                 real_output=True, device=x.device.index)
        cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    else:
        # TODO can be optimized knowing that the output is real, like in _setup_so3ifft_cuda_kernel(real_output=True)
        output.fill_(0)
        for degree in range(min(b_in, b_out)):
            side = 2 * degree + 1
            first = degree * (4 * degree ** 2 - 1) // 3
            block = slice(first, first + side ** 2)
            out = torch.einsum("mnzc,bmn->zbmnc", (x[block].view(side, side, -1, 2),
                                                   wigner[:, block].view(-1, side, side)))
            l1 = min(degree, b_out - 1)  # if b_out < b_in

            # Non-negative orders go to the front of each frequency axis,
            # negative orders wrap around to the back.
            dst_front, dst_back = slice(None, l1 + 1), slice(-l1, None)
            src_front, src_back = slice(degree, degree + l1 + 1), slice(degree - l1, degree)

            output[:, :, dst_front, dst_front] += out[:, :, src_front, src_front]
            if degree > 0:
                output[:, :, dst_back, dst_front] += out[:, :, src_back, src_front]
                output[:, :, dst_front, dst_back] += out[:, :, src_front, src_back]
                output[:, :, dst_back, dst_back] += out[:, :, src_back, src_back]

    # 2-D inverse FFT over the last two signal axes, rescaled by their size:
    # [batch, beta, alpha, gamma, complex]
    output = torch.ifft(output, 2) * output.size(-2) ** 2
    # Drop the imaginary component: [batch, beta, alpha, gamma]
    real_part = output[..., 0].contiguous()
    return real_part.view(*batch_size, grid, grid, grid)
Example #21
Source File: s2_fft.py From s2cnn with MIT License | 4 votes |
def s2_ifft(x, for_grad=False, b_out=None):
    '''
    Inverse spherical (S2) Fourier transform with complex output.

    :param x: [l * m, ..., complex]
    :param for_grad: forwarded to ``_setup_wigner`` as ``weighted``
    :param b_out: output bandwidth (must be >= the input bandwidth); defaults
        to the input bandwidth inferred from ``x.size(0)``
    :return: [..., beta, alpha, complex] with both angular axes of size
        ``2 * b_out``
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    # nspec must be a perfect square: nspec == b_in ** 2.
    b_in = round(nspec ** 0.5)
    assert nspec == b_in ** 2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]

    # Collapse all batch axes: [l * m, batch, complex] (nspec, nbatch, 2).
    x = x.view(nspec, -1, 2)
    nbatch = x.size(1)

    grid = 2 * b_out
    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(grid, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, grid, grid, 2))
        n_blocks = cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(n_blocks, 1, 1),
                    args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, grid, grid, 2))
        for degree in range(b_in):
            # Coefficients of this degree occupy [degree**2, (degree + 1)**2).
            block = slice(degree ** 2, degree ** 2 + 2 * degree + 1)
            out = torch.einsum("mzc,bm->zbmc", (x[block], wigner[:, block]))
            # Last degree + 1 orders go to the front of the frequency axis,
            # the first `degree` orders wrap around to the back.
            output[:, :, :degree + 1] += out[:, :, -degree - 1:]
            if degree > 0:
                output[:, :, -degree:] += out[:, :, :degree]

    # 1-D inverse FFT over the last signal axis, rescaled by its size:
    # [batch, beta, alpha, complex]
    output = torch.ifft(output, 1) * output.size(-2)
    return output.view(*batch_size, grid, grid, 2)
Example #22
Source File: fft_functions.py From torchkbnufft with MIT License | 4 votes |
def fft_filter(x, kern, norm=None):
    """FFT-based filtering on a 2-size oversampled grid.

    The input is zero-padded to twice its size along every image dimension,
    transformed, multiplied by ``kern`` with ``complex_mult`` along
    dimension 2, inverse-transformed and cropped back to the input size.

    Args:
        x (tensor): Input whose dimension 2 is the complex (real/imag)
            dimension and whose dimensions 3 and onward are the image
            dimensions.
        kern (tensor): Filter passed to ``complex_mult`` together with the
            transformed, padded input.
        norm (str, optional): If ``'ortho'``, the result is divided by the
            square root of the grid size after both the forward and the
            inverse transform.

    Returns:
        tensor: The filtered data, cropped to the shape of ``x``.
    """
    im_size = torch.tensor(x.shape).to(torch.long)[3:]
    grid_size = im_size * 2

    # set up n-dimensional zero pad; F.pad expects (before, after) pairs
    # starting from the last dimension
    pad_sizes = []
    permute_dims = [0, 1]
    inv_permute_dims = [0, 1, 2 + grid_size.shape[0]]
    for i in range(grid_size.shape[0]):
        pad_sizes.append(0)
        pad_sizes.append(int(grid_size[-1 - i] - im_size[-1 - i]))
        permute_dims.append(3 + i)
        inv_permute_dims.append(2 + i)
    # the legacy FFT routines expect the complex dimension last
    permute_dims.append(2)
    pad_sizes = tuple(pad_sizes)
    permute_dims = tuple(permute_dims)
    inv_permute_dims = tuple(inv_permute_dims)

    # zero pad and fft
    x = F.pad(x, pad_sizes)
    x = x.permute(permute_dims)
    x = torch.fft(x, grid_size.numel())
    if norm == 'ortho':
        x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))
    x = x.permute(inv_permute_dims)

    # apply the filter
    x = complex_mult(x, kern, dim=2)

    # inverse fft
    x = x.permute(permute_dims)
    x = torch.ifft(x, grid_size.numel())
    x = x.permute(inv_permute_dims)

    # crop to input size; every slice starts at 0. (The original built the
    # zeros through np.int, an alias removed in NumPy 1.24.)
    crop_starts = (0,) * len(x.shape)
    crop_ends = [x.shape[0], x.shape[1], x.shape[2]]
    crop_ends.extend(int(dim) for dim in im_size)
    x = x[tuple(map(slice, crop_starts, crop_ends))]

    # scaling, assume user handled adjoint scaling with their kernel
    if norm == 'ortho':
        x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))

    return x