Python torch.fft() Examples
The following are 30 code examples of torch.fft().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: torch_spec_operator.py From space_time_pde with MIT License | 8 votes |
def pad_rfft3(f, onesided=True):
    """
    padded batch real fft
    :param f: tensor of shape [..., res0, res1, res2]
    """
    # NOTE(review): uses the legacy torch.rfft/torch.fft API (removed in
    # PyTorch 1.8); this code requires torch < 1.8 to run.
    n0, n1, n2 = f.shape[-3:]
    # Nyquist-mode indices of each spatial axis; these modes are zeroed
    # after each 1-D transform (presumably to de-alias — confirm intent).
    h0, h1, h2 = int(n0/2), int(n1/2), int(n2/2)
    # Real FFT along the last spatial axis produces a complex tensor with
    # real/imag parts in the trailing dimension.
    F2 = torch.rfft(f, signal_ndim=1, onesided=onesided)  # [..., res0, res1, res2/2+1, 2]
    F2[..., h2, :] = 0
    # Transform axis res1: swap it next to the complex dim, run a 1-D FFT,
    # zero its Nyquist mode, then restore the original layout.
    F1 = torch.fft(F2.transpose(-3,-2), signal_ndim=1)
    F1[..., h1,:] = 0
    F1 = F1.transpose(-2,-3)
    # Same treatment for axis res0.
    F0 = torch.fft(F1.transpose(-4,-2), signal_ndim=1)
    F0[..., h0,:] = 0
    F0 = F0.transpose(-2,-4)
    return F0
Example #2
Source File: transforms.py From fastMRI with MIT License | 6 votes |
def fft2(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size 2.
            All other dimensions are assumed to be batch dimensions.

    Returns:
        torch.Tensor: The FFT of the input.
    """
    assert data.size(-1) == 2
    data = ifftshift(data, dim=(-3, -2))
    # FIX: torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8.
    # The equivalent is an orthonormal fft2 over the two spatial dims of a
    # complex view of the (..., 2) real/imag representation.
    data = torch.view_as_real(
        torch.fft.fft2(torch.view_as_complex(data.contiguous()), norm='ortho')
    )
    data = fftshift(data, dim=(-3, -2))
    return data
Example #3
Source File: butterfly_old.py From learning-circuits with Apache License 2.0 | 6 votes |
def test_butterfly_fft():
    """Check that two hand-built Butterfly factors (plus a permutation)
    reproduce the order-4 DFT matrix in the factored forms used here."""
    # DFT matrix for n = 4 (legacy torch.fft API, removed in PyTorch 1.8).
    size = 4
    DFT = torch.fft(real_to_complex(torch.eye(size)), 1)
    # Permutation matrix swapping rows 1 and 2 (bit-reversal for n = 4).
    P = real_to_complex(torch.tensor([[1., 0., 0., 0.], [0., 0., 1., 0.], [0., 1., 0., 0.], [0., 0., 0., 1.]]))
    # Two butterfly factors with hand-picked complex (re, im) entries.
    M0 = Butterfly(size, diagonal=2, complex=True, diag=torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]], requires_grad=True), subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True), superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True))
    M1 = Butterfly(size, diagonal=1, complex=True, diag=torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], requires_grad=True), subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True), superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True))
    # M0 @ M1 @ P must equal the DFT matrix exactly.
    assert torch.allclose(complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    # Equivalently: applying the bit-reversal permutation to the columns.
    br_perm = torch.tensor(bitreversal_permutation(size))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    # And the unpermuted product equals DFT with permuted columns.
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D)
Example #4
Source File: transforms.py From fastMRI with MIT License | 6 votes |
def fft2(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size 2.
            All other dimensions are assumed to be batch dimensions.

    Returns:
        torch.Tensor: The FFT of the input.
    """
    assert data.size(-1) == 2
    data = ifftshift(data, dim=(-3, -2))
    # FIX: torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8.
    # The equivalent is an orthonormal fft2 over the two spatial dims of a
    # complex view of the (..., 2) real/imag representation.
    data = torch.view_as_real(
        torch.fft.fft2(torch.view_as_complex(data.contiguous()), norm='ortho')
    )
    data = fftshift(data, dim=(-3, -2))
    return data
Example #5
Source File: torch_spec_operator.py From space_time_pde with MIT License | 6 votes |
def rfftfreqs(res, dtype=torch.float32, exact=True):
    """
    Helper function to return frequency tensors
    :param res: n_dims int tuple of number of frequency modes
    :return: frequency tensor of shape [dim, res, res, res/2+1]
    """
    # Two-sided frequencies for every axis except the last one.
    axes = [
        torch.tensor(np.fft.fftfreq(n, d=1 / n), dtype=dtype)
        for n in res[:-1]
    ]
    # The last axis carries the one-sided (real-FFT) frequencies; drop the
    # Nyquist entry when an inexact grid is requested.
    last = res[-1]
    half = np.fft.rfftfreq(last, d=1 / last)
    if not exact:
        half = half[:-1]
    axes.append(torch.tensor(half, dtype=dtype))
    # Broadcast the per-axis frequencies onto the full grid, one leading
    # channel per dimension.
    return torch.stack(list(torch.meshgrid(axes)), dim=0)
Example #6
Source File: torch_spec_operator.py From space_time_pde with MIT License | 6 votes |
def pad_fft2(f):
    """
    padded batch real fft
    :param f: tensor of shape [..., res0, res1]
    """
    # NOTE(review): legacy torch.fft API (removed in PyTorch 1.8).
    n0, n1 = f.shape[-2:]
    # Nyquist-mode indices; zeroed after each transform.
    h0, h1 = int(n0/2), int(n1/2)
    # turn f into complex signal (zero imaginary channel appended)
    f = torch.stack((f, torch.zeros_like(f)), dim=-1)  # [..., res0, res1, 2]
    # 1-D FFT along the last spatial axis (res1).
    F1 = torch.fft(f, signal_ndim=1)  # [..., res0, res1, 2]
    F1[..., h1,:] = 0  # [..., res0, res1, 2]
    # Transform res0 by swapping it next to the complex dim, then restore.
    F0 = torch.fft(F1.transpose(-3,-2), signal_ndim=1)
    F0[..., h0,:] = 0
    F0 = F0.transpose(-2,-3)
    return F0
Example #7
Source File: learning_fft_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Build a fixed-order ButterflyProduct whose training target is the
    DFT matrix, plus the optimizer state for this trial."""
    torch.manual_seed(config['seed'])
    size = config['size']
    self.model = ButterflyProduct(size=size, complex=True, fixed_order=True)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Complex DFT matrix of order `size` (legacy torch.fft API).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #8
Source File: utils_deblur.py From KAIR with MIT License | 5 votes |
def get_uperleft_denominator(img, kernel):
    '''
    img: HxWxC
    kernel: hxw
    denominator: HxWx1
    upperleft: HxWxC
    '''
    # Optical transfer function of the blur kernel at the image's spatial size.
    otf = psf2otf(kernel, img.shape[:2])
    # |V|^2 with a trailing channel axis so it broadcasts over C.
    denominator = (np.abs(otf) ** 2)[:, :, None]
    # conj(V) * FFT(img), applied channel-wise.
    img_freq = np.fft.fft2(img, axes=[0, 1])
    upperleft = np.conj(otf)[:, :, None] * img_freq
    return upperleft, denominator
Example #9
Source File: learning_fft_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Build a learned-order ButterflyProduct (softmax ordering) targeting
    the DFT matrix, with its semantic-loss weight and optimizer."""
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=size, complex=True, fixed_order=False)
    self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Complex DFT matrix of order `size` (legacy torch.fft API).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #10
Source File: learning_fft_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Build a learned-order ButterflyProduct using sparsemax ordering,
    targeting the DFT matrix."""
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(
        size=size, complex=True, fixed_order=False, softmax_fn='sparsemax'
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Complex DFT matrix of order `size` (legacy torch.fft API).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #11
Source File: learning_fft_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Build a ButterflyProduct (order/softmax configurable) whose training
    target is the DFT matrix with bit-reversed columns."""
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'], complex=True, fixed_order=config['fixed_order'], softmax_fn=config['softmax_fn'])
    # Semantic loss only applies when the factor order is learned via softmax.
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # Complex DFT matrix of order `size` (legacy torch.fft API, removed in PyTorch 1.8).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.br_perm = torch.tensor(bitreversal_permutation(size))
    # --- experiment notes: alternative permutations tried for size = 8 ---
    # br_perm = bitreversal_permutation(size)
    # br_reverse = torch.tensor(list(br_perm[::-1]))
    # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//2][::-1])), torch.tensor(list(br_perm[size//2:][::-1]))))  # Same as [6, 2, 4, 0, 7, 3, 5, 1], which is [0, 1]^4 * [0, 2, 1, 3]^2 * [6, 4, 2, 0, 7, 5, 3, 1]
    # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//4][::-1])), torch.tensor(list(br_perm[size//4:size//2][::-1])), torch.tensor(list(br_perm[size//2:3*size//4][::-1])), torch.tensor(list(br_perm[3*size//4:][::-1]))))
    # self.br_perm = br_reverse
    # self.br_perm = torch.tensor([0, 7, 4, 3, 2, 5, 6, 1])  # Doesn't work
    # self.br_perm = torch.tensor([7, 3, 0, 4, 2, 6, 5, 1])  # Doesn't work
    # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 7, 3])  # This works, [0, 1]^4 * [2, 0, 3, 1]^2 * [0, 2, 4, 6, 1, 3, 5, 7] or [1, 0]^4 * [0, 2, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
    # self.br_perm = torch.tensor([4, 0, 2, 6, 5, 1, 3, 7])  # Doesn't work, [0, 1]^4 * [2, 0, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
    # self.br_perm = torch.tensor([1, 5, 3, 7, 0, 4, 2, 6])  # This works, [0, 1]^4 * [4, 6, 5, 7, 0, 4, 2, 6]
    # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 3, 7])  # Doesn't work
    # self.br_perm = torch.tensor([4, 0, 6, 2, 1, 5, 3, 7])  # Doesn't work
    # self.br_perm = torch.tensor([0, 4, 6, 2, 1, 5, 7, 3])  # Doesn't work
    # self.br_perm = torch.tensor([4, 1, 6, 2, 5, 0, 7, 3])  # This works, since it's just swapping 0 and 1
    # self.br_perm = torch.tensor([5, 1, 6, 2, 4, 0, 7, 3])  # This works, since it's swapping 4 and 5
Example #12
Source File: learning_fft_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Build a permutation-then-butterfly factorization trained toward the
    DFT matrix, plus the optimizer and fixed evaluation input."""
    torch.manual_seed(config['seed'])
    self.model = nn.Sequential(
        BlockPermProduct(size=config['size'], complex=True, share_logit=False),
        Block2x2DiagProduct(size=config['size'], complex=True)
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # BUG FIX: legacy torch.fft requires the signal_ndim argument — every
    # sibling _setup passes 1 here; without it this line raises TypeError.
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    # self.target_matrix = size * torch.ifft(real_to_complex(torch.eye(size)), 1)
    self.input = real_to_complex(torch.eye(size))
Example #13
Source File: learning_fft_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Butterfly-then-permutation factorization trained toward the DFT
    matrix, with optimizer state and a fixed evaluation input."""
    torch.manual_seed(config['seed'])
    size = config['size']
    self.model = nn.Sequential(
        Block2x2DiagProduct(size=size, complex=True, decreasing_size=False),
        BlockPermProduct(size=size, complex=True, share_logit=False, increasing_size=True),
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Target is the complex DFT matrix (legacy torch.fft API).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.input = real_to_complex(torch.eye(size))
Example #14
Source File: learning_fft_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """ButterflyProduct with a learnable permutation, trained toward the
    DFT matrix."""
    torch.manual_seed(config['seed'])
    size = config['size']
    self.model = ButterflyProduct(
        size=size,
        complex=True,
        fixed_order=config['fixed_order'],
        softmax_fn=config['softmax_fn'],
        learn_perm=True,
    )
    # Semantic loss only applies when the factor order is learned via softmax.
    if not config['fixed_order'] and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Complex DFT matrix of order `size` (legacy torch.fft API).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
Example #15
Source File: butterfly_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def test_block2x2diagproduct():
    """Check a hand-built Block2x2DiagProduct factorization of the order-4
    DFT matrix against the legacy torch.fft reference."""
    # Factorization of the DFT matrix
    size = 4
    model = Block2x2DiagProduct(size, complex=True)
    # Hand-set the two factors' 2x2 complex blocks (entries are (re, im) pairs).
    model.factors[1].ABCD = nn.Parameter(torch.tensor([[[[1.0, 0.0]], [[1.0, 0.0]]], [[[1.0, 0.0]], [[-1.0, 0.0]]]]))
    model.factors[0].ABCD = nn.Parameter(torch.tensor([[[[1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, -1.0]]], [[[1.0, 0.0], [1.0, 0.0]], [[-1.0, 0.0], [0.0, 1.0]]]]))
    # Identity as a complex input (zero imaginary channel).
    input = torch.stack((torch.eye(size), torch.zeros(size, size)), dim=-1)
    # Applying the model to bit-reversal-permuted rows must give the DFT.
    assert torch.allclose(model(input[:, [0, 2, 1, 3]]), torch.fft(input, 1))
Example #16
Source File: learning_fft.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Build a learned-order ButterflyProduct (softmax ordering) targeting
    the DFT matrix, with its semantic-loss weight and optimizer."""
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=size, complex=True, fixed_order=False)
    self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Complex DFT matrix of order `size` (legacy torch.fft API).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #17
Source File: learning_fft.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Build a learned-order ButterflyProduct using sparsemax ordering,
    targeting the DFT matrix."""
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(
        size=size, complex=True, fixed_order=False, softmax_fn='sparsemax'
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Complex DFT matrix of order `size` (legacy torch.fft API).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #18
Source File: learning_fft.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """Build a ButterflyProduct (order/softmax configurable) whose training
    target is the DFT matrix with bit-reversed columns."""
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'], complex=True, fixed_order=config['fixed_order'], softmax_fn=config['softmax_fn'])
    # Semantic loss only applies when the factor order is learned via softmax.
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # Complex DFT matrix of order `size` (legacy torch.fft API, removed in PyTorch 1.8).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.br_perm = torch.tensor(bitreversal_permutation(size))
    # --- experiment notes: alternative permutations tried for size = 8 ---
    # br_perm = bitreversal_permutation(size)
    # br_reverse = torch.tensor(list(br_perm[::-1]))
    # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//2][::-1])), torch.tensor(list(br_perm[size//2:][::-1]))))  # Same as [6, 2, 4, 0, 7, 3, 5, 1], which is [0, 1]^4 * [0, 2, 1, 3]^2 * [6, 4, 2, 0, 7, 5, 3, 1]
    # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//4][::-1])), torch.tensor(list(br_perm[size//4:size//2][::-1])), torch.tensor(list(br_perm[size//2:3*size//4][::-1])), torch.tensor(list(br_perm[3*size//4:][::-1]))))
    # self.br_perm = br_reverse
    # self.br_perm = torch.tensor([0, 7, 4, 3, 2, 5, 6, 1])  # Doesn't work
    # self.br_perm = torch.tensor([7, 3, 0, 4, 2, 6, 5, 1])  # Doesn't work
    # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 7, 3])  # This works, [0, 1]^4 * [2, 0, 3, 1]^2 * [0, 2, 4, 6, 1, 3, 5, 7] or [1, 0]^4 * [0, 2, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
    # self.br_perm = torch.tensor([4, 0, 2, 6, 5, 1, 3, 7])  # Doesn't work, [0, 1]^4 * [2, 0, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
    # self.br_perm = torch.tensor([1, 5, 3, 7, 0, 4, 2, 6])  # This works, [0, 1]^4 * [4, 6, 5, 7, 0, 4, 2, 6]
    # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 3, 7])  # Doesn't work
    # self.br_perm = torch.tensor([4, 0, 6, 2, 1, 5, 3, 7])  # Doesn't work
    # self.br_perm = torch.tensor([0, 4, 6, 2, 1, 5, 7, 3])  # Doesn't work
    # self.br_perm = torch.tensor([4, 1, 6, 2, 5, 0, 7, 3])  # This works, since it's just swapping 0 and 1
    # self.br_perm = torch.tensor([5, 1, 6, 2, 4, 0, 7, 3])  # This works, since it's swapping 4 and 5
Example #19
Source File: learning_fft.py From learning-circuits with Apache License 2.0 | 5 votes |
def _setup(self, config):
    """ButterflyProduct with a learnable permutation, trained toward the
    DFT matrix."""
    torch.manual_seed(config['seed'])
    size = config['size']
    self.model = ButterflyProduct(
        size=size,
        complex=True,
        fixed_order=config['fixed_order'],
        softmax_fn=config['softmax_fn'],
        learn_perm=True,
    )
    # Semantic loss only applies when the factor order is learned via softmax.
    if not config['fixed_order'] and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Complex DFT matrix of order `size` (legacy torch.fft API).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
Example #20
Source File: CBP.py From DBCNN-PyTorch with MIT License | 5 votes |
def forward(self, x):
    """Compact bilinear pooling of a convolutional feature map.

    Args:
        x: feature tensor of shape [batch, dim, h, w].

    Returns:
        Tensor of shape [batch, self.output_dim]: sum-pooled sketch
        product, signed-sqrt'd and L2-normalized.
    """
    # Per-chunk batch size: one image at a time.
    bsn = 1
    batchSize, dim, h, w = x.data.shape
    # Flatten spatial positions to rows: [batch*h*w, dim].
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # batchsize,h, w, dim,
    y = torch.ones(batchSize, self.output_dim, device=x.device)
    for img in range(batchSize // bsn):
        segLen = bsn * h * w
        upper = batchSize * h * w
        # Row indices of this chunk in x_flat and in y respectively.
        interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)
        interSmall = torch.arange(img * bsn, min(upper, (img + 1) * bsn), dtype=torch.long)
        batch_x = x_flat[interLarge, :]
        # Two sketch projections (self.sparseM[0/1]); each gets a zero
        # imaginary channel appended and a 1-D FFT via the legacy
        # torch.fft API (removed in PyTorch 1.8).
        sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
        sketch1 = torch.fft(torch.cat((sketch1, torch.zeros(sketch1.size(), device=x.device)), dim=2), 1)
        sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
        sketch2 = torch.fft(torch.cat((sketch2, torch.zeros(sketch2.size(), device=x.device)), dim=2), 1)
        # Elementwise complex product in the frequency domain
        # (equals circular convolution of the two sketches).
        Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1])
        Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0])
        # Inverse FFT; keep only the real part.
        tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]
        # Sum-pool over the h and w axes.
        y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)
    # Signed square root and L2 normalization of the pooled descriptor.
    y = self._signed_sqrt(y)
    y = self._l2norm(y)
    return y
Example #21
Source File: torch_backend.py From kymatio with BSD 3-Clause "New" or "Revised" License | 5 votes |
def fft(input, inverse=False):
    """Forward or inverse FFT of a 3D complex signal.

    Example
    -------
    x = torch.randn(128, 32, 32, 32, 2)
    x_fft = fft(x)
    x_ifft = fft(x, inverse=True)

    Parameters
    ----------
    input : tensor
        Complex input for the FFT (last dimension holds real/imag parts).
    inverse : bool
        True for computing the inverse FFT.

    Raises
    ------
    TypeError
        In the event that input does not have a final dimension 2,
        i.e. is not complex.

    Returns
    -------
    output : tensor
        Result of FFT or IFFT.
    """
    if not _is_complex(input):
        raise TypeError('The input should be complex (e.g. last dimension is 2)')
    # Select the legacy 3-D transform to apply.
    transform = torch.ifft if inverse else torch.fft
    return transform(input, 3)
Example #22
Source File: utils.py From sigmanet with MIT License | 5 votes |
def torch_fft2(x, normalized=True):
    """ fft on last 2 dim """
    # Round-trip through torch: numpy -> tensor, legacy 2-D FFT, -> numpy.
    tensor_in = numpy_to_torch(x)
    tensor_k = torch.fft(tensor_in, 2, normalized=normalized)
    return torch_to_complex_numpy(tensor_k)
Example #23
Source File: utils.py From sigmanet with MIT License | 5 votes |
def torch_fft2c(x, normalized=True):
    """ Centered fft2 on the last 2 dims.

    Shifts the input, runs the (legacy) torch 2-D FFT, and shifts back,
    so the zero-frequency component ends up centered.
    """
    x = np.fft.ifftshift(x, axes=(-2,-1))
    xt = numpy_to_torch(x)
    # BUG FIX: honor the `normalized` argument — it was silently ignored
    # (hard-coded to normalized=True).
    kt = torch.fft(xt, 2, normalized=normalized)
    k = torch_to_complex_numpy(kt)
    return np.fft.fftshift(k, axes=(-2,-1))
Example #24
Source File: utils.py From sigmanet with MIT License | 5 votes |
def torch_ifft2c(x, normalized=True):
    """ Centered ifft2 on the last 2 dims.

    Shifts the input, runs the (legacy) torch 2-D inverse FFT, and shifts
    back, so the zero-frequency component ends up centered.
    """
    x = np.fft.ifftshift(x, axes=(-2,-1))
    xt = numpy_to_torch(x)
    # BUG FIX: honor the `normalized` argument — it was silently ignored
    # (hard-coded to normalized=True).
    kt = torch.ifft(xt, 2, normalized=normalized)
    k = torch_to_complex_numpy(kt)
    return np.fft.fftshift(k, axes=(-2,-1))
Example #25
Source File: fft.py From sigmanet with MIT License | 5 votes |
def fft2(data):
    """Apply an orthonormal 2-dimensional Fast Fourier Transform.

    Args:
        data (torch.Tensor): tensor whose last dimension has size 2
            (real/imaginary parts); dims -3 and -2 are the spatial dims.

    Returns:
        torch.Tensor: the normalized 2D FFT, same shape as the input.
    """
    assert data.size(-1) == 2
    # FIX: torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8;
    # the equivalent is an orthonormal fft2 on a complex view of the
    # (..., 2) real/imag representation.
    data = torch.view_as_real(
        torch.fft.fft2(torch.view_as_complex(data.contiguous()), norm='ortho')
    )
    return data
Example #26
Source File: fft.py From sigmanet with MIT License | 5 votes |
def fft2c(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size 2.
            All other dimensions are assumed to be batch dimensions.

    Returns:
        torch.Tensor: The FFT of the input.
    """
    assert data.size(-1) == 2
    data = ifftshift(data, dim=(-3, -2))
    # FIX: torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8;
    # the equivalent is an orthonormal fft2 on a complex view of the
    # (..., 2) real/imag representation.
    data = torch.view_as_real(
        torch.fft.fft2(torch.view_as_complex(data.contiguous()), norm='ortho')
    )
    data = fftshift(data, dim=(-3, -2))
    return data
Example #27
Source File: fft.py From sigmanet with MIT License | 5 votes |
def ifftshift(x, dim=None):
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors
    """
    if dim is None:
        # All axes: shift each by ceil(size / 2).
        dim = tuple(range(x.dim()))
        shift = [(size + 1) // 2 for size in x.shape]
        return roll(x, shift, dim)
    if isinstance(dim, int):
        # Single axis.
        return roll(x, (x.shape[dim] + 1) // 2, dim)
    # Sequence of axes.
    return roll(x, [(x.shape[axis] + 1) // 2 for axis in dim], dim)
Example #28
Source File: mcmri.py From deepinpy with MIT License | 5 votes |
def fft_forw(x, ndim=2):
    """Orthonormal n-dimensional FFT over the last `ndim` spatial dims.

    Args:
        x (torch.Tensor): tensor whose last dimension has size 2
            (real/imaginary parts).
        ndim (int): number of trailing spatial dimensions to transform.

    Returns:
        torch.Tensor: same-shape tensor holding the normalized FFT.
    """
    # FIX: torch.fft(x, signal_ndim=ndim, normalized=True) was removed in
    # PyTorch 1.8; fftn over the last `ndim` dims of the complex view is
    # the equivalent.
    xc = torch.view_as_complex(x.contiguous())
    kc = torch.fft.fftn(xc, dim=tuple(range(-ndim, 0)), norm='ortho')
    return torch.view_as_real(kc)
Example #29
Source File: torch_spec_operator.py From space_time_pde with MIT License | 5 votes |
def pad_irfft3(F):
    """
    padded batch inverse real fft
    :param f: tensor of shape [..., res0, res1, res2/2+1, 2]
    """
    # NOTE(review): legacy torch.ifft/torch.irfft API (removed in PyTorch 1.8).
    # NOTE(review): shape[-3] is res1, but it is used below as the full size
    # of the last spatial axis — this assumes res1 == res2 (cubic grid);
    # confirm against callers.
    res = F.shape[-3]
    # Inverse-transform axes res0 and res1 by moving each next to the
    # complex dim, applying a 1-D ifft, then restoring the layout.
    f0 = torch.ifft(F.transpose(-4,-2), signal_ndim=1).transpose(-2,-4)
    f1 = torch.ifft(f0.transpose(-3,-2), signal_ndim=1).transpose(-2,-3)
    # One-sided inverse real FFT recovers the full last axis.
    f2 = torch.irfft(f1, signal_ndim=1, signal_sizes=[res])  # [..., res0, res1, res2]
    return f2
Example #30
Source File: CBP.py From fast-MPN-COV with MIT License | 5 votes |
def forward(self, x):
    """Compact bilinear pooling of a convolutional feature map.

    Args:
        x: feature tensor of shape [batch, dim, h, w].

    Returns:
        Tensor of shape [batch, self.output_dim]: sum-pooled sketch
        product, signed-sqrt'd and L2-normalized.
    """
    # Per-chunk batch size: one image at a time.
    bsn = 1
    batchSize, dim, h, w = x.data.shape
    # Flatten spatial positions to rows: [batch*h*w, dim].
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # batchsize,h, w, dim,
    y = torch.ones(batchSize, self.output_dim, device=x.device)
    for img in range(batchSize // bsn):
        segLen = bsn * h * w
        upper = batchSize * h * w
        # Row indices of this chunk in x_flat and in y respectively.
        interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)
        interSmall = torch.arange(img * bsn, min(upper, (img + 1) * bsn), dtype=torch.long)
        batch_x = x_flat[interLarge, :]
        # Two sketch projections (self.sparseM[0/1]); each gets a zero
        # imaginary channel appended and a 1-D FFT via the legacy
        # torch.fft API (removed in PyTorch 1.8).
        sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
        sketch1 = torch.fft(torch.cat((sketch1, torch.zeros(sketch1.size(), device=x.device)), dim=2), 1)
        sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
        sketch2 = torch.fft(torch.cat((sketch2, torch.zeros(sketch2.size(), device=x.device)), dim=2), 1)
        # Elementwise complex product in the frequency domain
        # (equals circular convolution of the two sketches).
        Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1])
        Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0])
        # Inverse FFT; keep only the real part.
        tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]
        # Sum-pool over the h and w axes.
        y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)
    # Signed square root and L2 normalization of the pooled descriptor.
    y = self._signed_sqrt(y)
    y = self._l2norm(y)
    return y