Python torch.fft() Examples

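The following are 30 code examples of torch.fft(). Note that the function forms torch.fft(), torch.ifft(), torch.rfft() and torch.irfft() were deprecated in PyTorch 1.7 and removed in PyTorch 1.8. From 1.8 on, torch.fft is a module (torch.fft.fft, torch.fft.ifft, torch.fft.rfft, torch.fft.irfft, …) that operates on native complex tensors. Code written against the old function API must be ported to that module to run on current PyTorch. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.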
Example #1
Source File: torch_spec_operator.py    From space_time_pde with MIT License 8 votes vote down vote up
def pad_rfft3(f, onesided=True):
    """
    Batched 3-D real FFT with the n // 2 bin zeroed along every axis.

    The transform runs over the last three dims of `f`. After each 1-D pass,
    the bin at index n // 2 of that axis is set to zero (the Nyquist bin for
    even n).

    :param f: real tensor of shape [..., res0, res1, res2]
    :param onesided: if True, keep only the res2 // 2 + 1 non-negative
        frequencies along the last axis; if False, keep the full spectrum
    :return: tensor of shape [..., res0, res1, res2 // 2 + 1, 2] (or
        [..., res0, res1, res2, 2] when onesided is False); the trailing
        dimension holds (real, imag)
    """
    n0, n1, n2 = f.shape[-3:]
    h0, h1, h2 = n0 // 2, n1 // 2, n2 // 2

    # torch.rfft / torch.fft(x, signal_ndim) were removed in PyTorch 1.8.
    # The torch.fft module works on complex tensors with an explicit dim,
    # so the transpose dance is unnecessary. The default "backward" norm
    # matches the old unnormalized forward transform.
    if onesided:
        F = torch.fft.rfft(f, dim=-1)
    else:
        F = torch.fft.fft(f, dim=-1)
    F[..., h2] = 0

    F = torch.fft.fft(F, dim=-2)
    F[..., h1, :] = 0

    F = torch.fft.fft(F, dim=-3)
    F[..., h0, :, :] = 0
    # Convert back to the (..., 2) real/imag layout callers expect.
    return torch.view_as_real(F)
Example #2
Source File: transforms.py    From fastMRI with MIT License 6 votes vote down vote up
def fft2(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
            -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
            assumed to be batch dimensions.

    Returns:
        torch.Tensor: The orthonormally scaled (1/sqrt(H*W)) FFT of the input, in the same
            (..., 2) real/imaginary layout as `data`.
    """
    assert data.size(-1) == 2
    data = ifftshift(data, dim=(-3, -2))
    # torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8. View the trailing
    # (real, imag) pair as a complex tensor, transform its last two dims (the spatial
    # dims -3/-2 of the real layout), and view the result back. norm="ortho" matches
    # normalized=True.
    spectrum = torch.fft.fft2(torch.view_as_complex(data.contiguous()), dim=(-2, -1), norm="ortho")
    data = torch.view_as_real(spectrum)
    data = fftshift(data, dim=(-3, -2))
    return data
Example #3
Source File: butterfly_old.py    From learning-circuits with Apache License 2.0 6 votes vote down vote up
def test_butterfly_fft():
    """Check that hand-set complex Butterfly factors M0, M1 reproduce the 4-point DFT.

    Asserts M0 @ M1 @ P == DFT (P swaps indices 1 and 2), that M0 @ M1 with its
    columns reordered by the bit-reversal permutation equals the DFT, and that
    M0 @ M1 == DFT @ P^T.
    """
    # DFT matrix for n = 4
    size = 4
    # torch.fft(x, 1) was removed in PyTorch 1.8: take the row-wise 1-D FFT of the
    # identity through a complex view and keep the (..., 2) real/imag layout.
    identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
    DFT = torch.view_as_real(torch.fft.fft(identity))
    P = real_to_complex(torch.tensor([[1., 0., 0., 0.],
                                      [0., 0., 1., 0.],
                                      [0., 1., 0., 0.],
                                      [0., 0., 0., 1.]]))
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True))
    assert torch.allclose(complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    br_perm = torch.tensor(bitreversal_permutation(size))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D)
Example #4
Source File: transforms.py    From fastMRI with MIT License 6 votes vote down vote up
def fft2(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
            -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
            assumed to be batch dimensions.

    Returns:
        torch.Tensor: The orthonormally scaled (1/sqrt(H*W)) FFT of the input, in the same
            (..., 2) real/imaginary layout as `data`.
    """
    assert data.size(-1) == 2
    data = ifftshift(data, dim=(-3, -2))
    # torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8. View the trailing
    # (real, imag) pair as a complex tensor, transform its last two dims (the spatial
    # dims -3/-2 of the real layout), and view the result back. norm="ortho" matches
    # normalized=True.
    spectrum = torch.fft.fft2(torch.view_as_complex(data.contiguous()), dim=(-2, -1), norm="ortho")
    data = torch.view_as_real(spectrum)
    data = fftshift(data, dim=(-3, -2))
    return data
Example #5
Source File: torch_spec_operator.py    From space_time_pde with MIT License 6 votes vote down vote up
def rfftfreqs(res, dtype=torch.float32, exact=True):
    """
    Helper function to return frequency tensors

    Every axis except the last uses np.fft.fftfreq(n, d=1/n), i.e. integer
    frequencies in FFT order. The last axis uses np.fft.rfftfreq(n, d=1/n),
    the non-negative half. The per-axis grids are broadcast with meshgrid
    ("ij" indexing) and stacked.

    :param res: n_dims int tuple of number of frequency modes
    :param dtype: dtype of the returned tensor
    :param exact: if True the last axis has res[-1] // 2 + 1 entries (full
        rfft layout); if False its final entry is dropped
    :return: frequency tensor of shape [n_dims, res[0], ..., res[-2], m], where
        m is the length of the last axis (e.g. [3, res, res, res/2+1] for a
        cubic 3-D grid with exact=True)
    """
    *leading, last = res
    freqs = [torch.tensor(np.fft.fftfreq(r_, d=1 / r_), dtype=dtype) for r_ in leading]
    last_freq = np.fft.rfftfreq(last, d=1 / last)
    if not exact:
        last_freq = last_freq[:-1]
    freqs.append(torch.tensor(last_freq, dtype=dtype))
    # Pass indexing explicitly. Without it, torch.meshgrid emits a UserWarning
    # on PyTorch >= 1.10 announcing that the argument will become required.
    # "ij" is the current default, so the output is unchanged.
    omega = torch.meshgrid(*freqs, indexing="ij")
    return torch.stack(omega, dim=0)
Example #6
Source File: torch_spec_operator.py    From space_time_pde with MIT License 6 votes vote down vote up
def pad_fft2(f):
    """
    Batched 2-D complex FFT of a real signal with the n // 2 bin zeroed per axis.

    The real input is transformed with a full (two-sided) complex FFT over its
    last two dims. After each 1-D pass, the bin at index n // 2 of that axis
    is set to zero.

    :param f: real tensor of shape [..., res0, res1]
    :return: tensor of shape [..., res0, res1, 2]; the trailing dimension
        holds (real, imag)
    """
    n0, n1 = f.shape[-2:]
    h0, h1 = n0 // 2, n1 // 2

    # torch.fft(x, signal_ndim=1) was removed in PyTorch 1.8. torch.fft.fft
    # accepts the real input directly (no stacked zero imaginary part needed)
    # and takes an explicit dim, so no transposes are needed.
    F = torch.fft.fft(f, dim=-1)  # [..., res0, res1] complex
    F[..., h1] = 0

    F = torch.fft.fft(F, dim=-2)
    F[..., h0, :] = 0
    # Convert back to the (..., 2) real/imag layout callers expect.
    return torch.view_as_real(F)
Example #7
Source File: learning_fft_old.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build the model, optimizer, DFT target and bit-reversal permutation.

        Reads 'size', 'seed', 'lr' and 'n_steps_per_epoch' from `config`.
        """
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=True)
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        # Target: row-wise 1-D FFT of the identity (the DFT matrix, assuming
        # real_to_complex adds a zero imaginary part) in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
        self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #8
Source File: utils_deblur.py    From KAIR with MIT License 5 votes vote down vote up
def get_uperleft_denominator(img, kernel):
    """Return the FFT-domain terms conj(OTF) * FFT2(img) and |OTF|^2.

    The kernel is converted to an OTF of the image's spatial size with
    psf2otf. Both outputs get a trailing axis so they broadcast over the
    image's channels (presumably used as numerator/denominator of an
    FFT-domain deconvolution — confirm with callers).

    img: HxWxC
    kernel: hxw
    denominator: HxWx1
    upperleft: HxWxC
    """
    otf = psf2otf(kernel, img.shape[:2])
    img_spectrum = np.fft.fft2(img, axes=[0, 1])
    upperleft = np.expand_dims(np.conj(otf), axis=2) * img_spectrum
    power = np.abs(otf) ** 2
    denominator = np.expand_dims(power, axis=2)
    return upperleft, denominator
Example #9
Source File: learning_fft_old.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build the model, optimizer, DFT target and bit-reversal permutation.

        Reads 'size', 'seed', 'semantic_loss_weight', 'lr' and
        'n_steps_per_epoch' from `config`.
        """
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=False)
        self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        # Target: row-wise 1-D FFT of the identity (the DFT matrix, assuming
        # real_to_complex adds a zero imaginary part) in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
        self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #10
Source File: learning_fft_old.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build a sparsemax-ordered model, optimizer, DFT target and bit-reversal permutation.

        Reads 'size', 'seed', 'lr' and 'n_steps_per_epoch' from `config`.
        """
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=False, softmax_fn='sparsemax')
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        # Target: row-wise 1-D FFT of the identity (the DFT matrix, assuming
        # real_to_complex adds a zero imaginary part) in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
        self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #11
Source File: learning_fft_old.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build the model, optimizer, DFT target and bit-reversal permutation.

        Reads 'seed', 'size', 'fixed_order', 'softmax_fn', 'lr' and
        'n_steps_per_epoch' from `config`. 'semantic_loss_weight' is read only
        when the order is learned with softmax.
        """
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'])
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        # Target: row-wise 1-D FFT of the identity (the DFT matrix, assuming
        # real_to_complex adds a zero imaginary part) in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
        self.br_perm = torch.tensor(bitreversal_permutation(size))
        # br_perm = bitreversal_permutation(size)
        # br_reverse = torch.tensor(list(br_perm[::-1]))
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//2][::-1])), torch.tensor(list(br_perm[size//2:][::-1]))))
        # Same as [6, 2, 4, 0, 7, 3, 5, 1], which is [0, 1]^4 * [0, 2, 1, 3]^2 * [6, 4, 2, 0, 7, 5, 3, 1]
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//4][::-1])), torch.tensor(list(br_perm[size//4:size//2][::-1])), torch.tensor(list(br_perm[size//2:3*size//4][::-1])), torch.tensor(list(br_perm[3*size//4:][::-1]))))
        # self.br_perm = br_reverse
        # self.br_perm = torch.tensor([0, 7, 4, 3, 2, 5, 6, 1])  # Doesn't work
        # self.br_perm = torch.tensor([7, 3, 0, 4, 2, 6, 5, 1])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 7, 3])  # This works, [0, 1]^4 * [2, 0, 3, 1]^2 * [0, 2, 4, 6, 1, 3, 5, 7] or [1, 0]^4 * [0, 2, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([4, 0, 2, 6, 5, 1, 3, 7])  # Doesn't work, [0, 1]^4 * [2, 0, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([1, 5, 3, 7, 0, 4, 2, 6])  # This works, [0, 1]^4 * [4, 6, 5, 7, 0, 4, 2, 6]
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 1, 5, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([0, 4, 6, 2, 1, 5, 7, 3])  # Doesn't work
        # self.br_perm = torch.tensor([4, 1, 6, 2, 5, 0, 7, 3])  # This works, since it's just swapping 0 and 1
        # self.br_perm = torch.tensor([5, 1, 6, 2, 4, 0, 7, 3])  # This works, since it's swapping 4 and 5
Example #12
Source File: learning_fft_old.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build a permutation-then-diagonal model, optimizer, DFT target and identity input.

        Reads 'seed', 'size', 'lr' and 'n_steps_per_epoch' from `config`.
        """
        torch.manual_seed(config['seed'])
        self.model = nn.Sequential(
            BlockPermProduct(size=config['size'], complex=True, share_logit=False),
            Block2x2DiagProduct(size=config['size'], complex=True)
        )
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        # The original call `torch.fft(real_to_complex(torch.eye(size)))` omitted the
        # required signal_ndim argument, so it raised TypeError on every PyTorch version.
        # Use a 1-D DFT like the sibling setups. torch.fft(x, 1) was removed in
        # PyTorch 1.8, so go through a complex view and back to the (..., 2) layout.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
        # self.target_matrix = size * torch.ifft(real_to_complex(torch.eye(size)))
        self.input = real_to_complex(torch.eye(size))
Example #13
Source File: learning_fft_old.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build a diagonal-then-permutation model, optimizer, DFT target and identity input.

        Reads 'seed', 'size', 'lr' and 'n_steps_per_epoch' from `config`.
        """
        torch.manual_seed(config['seed'])
        self.model = nn.Sequential(
            Block2x2DiagProduct(size=config['size'], complex=True, decreasing_size=False),
            BlockPermProduct(size=config['size'], complex=True, share_logit=False, increasing_size=True),
        )
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        # Target: row-wise 1-D FFT of the identity in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
        self.input = real_to_complex(torch.eye(size))
Example #14
Source File: learning_fft_old.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build a learned-permutation model, optimizer and DFT target.

        Reads 'seed', 'size', 'fixed_order', 'softmax_fn', 'lr' and
        'n_steps_per_epoch' from `config`. 'semantic_loss_weight' is read only
        when the order is learned with softmax.
        """
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'],
                                      learn_perm=True)
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        # Target: row-wise 1-D FFT of the identity in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
Example #15
Source File: butterfly_old.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def test_block2x2diagproduct():
    """Check a hand-set Block2x2DiagProduct factorization of the 4-point DFT.

    Applying the model to the identity with its columns reordered as
    [0, 2, 1, 3] must equal the row-wise FFT of the identity.
    """
    # Factorization of the DFT matrix
    size = 4
    model = Block2x2DiagProduct(size, complex=True)
    model.factors[1].ABCD = nn.Parameter(torch.tensor([[[[1.0, 0.0]], [[1.0, 0.0]]], [[[1.0, 0.0]], [[-1.0, 0.0]]]]))
    model.factors[0].ABCD = nn.Parameter(torch.tensor([[[[1.0, 0.0],
                                                         [1.0, 0.0]],
                                                        [[1.0, 0.0],
                                                         [0.0, -1.0]]],
                                                       [[[1.0, 0.0],
                                                         [1.0, 0.0]],
                                                        [[-1.0, 0.0],
                                                         [0.0, 1.0]]]]))
    # Identity in (..., 2) real/imag layout; named to avoid shadowing the `input` builtin.
    eye_c = torch.stack((torch.eye(size), torch.zeros(size, size)), dim=-1)
    # torch.fft(x, 1) was removed in PyTorch 1.8; compute the same row-wise FFT
    # through a complex view and convert back to the (..., 2) layout.
    expected = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye_c)))
    assert torch.allclose(model(eye_c[:, [0, 2, 1, 3]]), expected)
Example #16
Source File: learning_fft.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build the model, optimizer, DFT target and bit-reversal permutation.

        Reads 'size', 'seed', 'semantic_loss_weight', 'lr' and
        'n_steps_per_epoch' from `config`.
        """
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=False)
        self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        # Target: row-wise 1-D FFT of the identity (the DFT matrix, assuming
        # real_to_complex adds a zero imaginary part) in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
        self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #17
Source File: learning_fft.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build a sparsemax-ordered model, optimizer, DFT target and bit-reversal permutation.

        Reads 'size', 'seed', 'lr' and 'n_steps_per_epoch' from `config`.
        """
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=False, softmax_fn='sparsemax')
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        # Target: row-wise 1-D FFT of the identity (the DFT matrix, assuming
        # real_to_complex adds a zero imaginary part) in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
        self.br_perm = torch.tensor(bitreversal_permutation(size))
Example #18
Source File: learning_fft.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build the model, optimizer, DFT target and bit-reversal permutation.

        Reads 'seed', 'size', 'fixed_order', 'softmax_fn', 'lr' and
        'n_steps_per_epoch' from `config`. 'semantic_loss_weight' is read only
        when the order is learned with softmax.
        """
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'])
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        # Target: row-wise 1-D FFT of the identity (the DFT matrix, assuming
        # real_to_complex adds a zero imaginary part) in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
        self.br_perm = torch.tensor(bitreversal_permutation(size))
        # br_perm = bitreversal_permutation(size)
        # br_reverse = torch.tensor(list(br_perm[::-1]))
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//2][::-1])), torch.tensor(list(br_perm[size//2:][::-1]))))
        # Same as [6, 2, 4, 0, 7, 3, 5, 1], which is [0, 1]^4 * [0, 2, 1, 3]^2 * [6, 4, 2, 0, 7, 5, 3, 1]
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//4][::-1])), torch.tensor(list(br_perm[size//4:size//2][::-1])), torch.tensor(list(br_perm[size//2:3*size//4][::-1])), torch.tensor(list(br_perm[3*size//4:][::-1]))))
        # self.br_perm = br_reverse
        # self.br_perm = torch.tensor([0, 7, 4, 3, 2, 5, 6, 1])  # Doesn't work
        # self.br_perm = torch.tensor([7, 3, 0, 4, 2, 6, 5, 1])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 7, 3])  # This works, [0, 1]^4 * [2, 0, 3, 1]^2 * [0, 2, 4, 6, 1, 3, 5, 7] or [1, 0]^4 * [0, 2, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([4, 0, 2, 6, 5, 1, 3, 7])  # Doesn't work, [0, 1]^4 * [2, 0, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([1, 5, 3, 7, 0, 4, 2, 6])  # This works, [0, 1]^4 * [4, 6, 5, 7, 0, 4, 2, 6]
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 1, 5, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([0, 4, 6, 2, 1, 5, 7, 3])  # Doesn't work
        # self.br_perm = torch.tensor([4, 1, 6, 2, 5, 0, 7, 3])  # This works, since it's just swapping 0 and 1
        # self.br_perm = torch.tensor([5, 1, 6, 2, 4, 0, 7, 3])  # This works, since it's swapping 4 and 5
Example #19
Source File: learning_fft.py    From learning-circuits with Apache License 2.0 5 votes vote down vote up
def _setup(self, config):
        """Build a learned-permutation model, optimizer and DFT target.

        Reads 'seed', 'size', 'fixed_order', 'softmax_fn', 'lr' and
        'n_steps_per_epoch' from `config`. 'semantic_loss_weight' is read only
        when the order is learned with softmax.
        """
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'],
                                      learn_perm=True)
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        # Target: row-wise 1-D FFT of the identity in (..., 2) layout.
        # torch.fft(x, 1) was removed in PyTorch 1.8, so go through a complex view.
        identity = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
        self.target_matrix = torch.view_as_real(torch.fft.fft(identity))
Example #20
Source File: CBP.py    From DBCNN-PyTorch with MIT License 5 votes vote down vote up
def forward(self, x):
         """Pool a feature map with two sketch projections and FFT-based circular convolution.

         Every spatial feature vector of an image is projected with
         self.sparseM[0] and self.sparseM[1]. The two projections are
         circularly convolved (product of FFTs, then inverse FFT, real part)
         and summed over all h*w positions. The pooled vectors then pass
         through self._signed_sqrt and self._l2norm. (Presumably compact
         bilinear pooling, per the file name CBP.py.)

         :param x: tensor of shape [batch, dim, h, w]
         :return: tensor of shape [batch, self.output_dim]
         """
         bsn = 1  # images per chunk; chunking bounds the size of the FFT intermediates
         batch_size, dim, h, w = x.shape
         x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # [batch * h * w, dim]
         y = torch.ones(batch_size, self.output_dim, device=x.device)

         # Loop-invariant: move the projection matrices to the input device once.
         proj1 = self.sparseM[0].to(x.device)
         proj2 = self.sparseM[1].to(x.device)
         positions = h * w

         # Stepping by bsn also covers a final partial chunk, which
         # range(batch_size // bsn) skipped when bsn does not divide the batch.
         # Each chunk's rows are bounded by batch_size; the original bounded
         # them by the h*w-scaled flat size.
         for start in range(0, batch_size, bsn):
             stop = min(batch_size, start + bsn)
             batch_x = x_flat[start * positions:stop * positions]

             # torch.fft/torch.ifft were removed in PyTorch 1.8. The torch.fft
             # module takes real input directly, so no zero imaginary channel
             # is concatenated and the complex product is a plain multiply.
             sketch1 = torch.fft.fft(batch_x.mm(proj1), dim=1)
             sketch2 = torch.fft.fft(batch_x.mm(proj2), dim=1)
             tmp_y = torch.fft.ifft(sketch1 * sketch2, dim=1).real

             y[start:stop] = tmp_y.reshape(stop - start, h, w, self.output_dim).sum(dim=(1, 2))

         y = self._signed_sqrt(y)
         y = self._l2norm(y)
         return y
Example #21
Source File: torch_backend.py    From kymatio with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def fft(input, inverse=False):
    """Interface with torch FFT routines for 3D signals.

    Computes the FFT (or inverse FFT) over the last three complex dimensions
    of `input`.

    Example
    -------
    x = torch.randn(128, 32, 32, 32, 2)

    x_fft = fft(x)
    x_ifft = fft(x, inverse=True)

    Parameters
    ----------
    input : tensor
        Complex input for the FFT, stored as a real tensor whose last
        dimension has size 2 (real, imaginary).
    inverse : bool
        True for computing the inverse FFT (scaled by 1/N, as torch.ifft was).

    Raises
    ------
    TypeError
        In the event that input does not have a final dimension 2 i.e. not
        complex.

    Returns
    -------
    output : tensor
        Result of FFT or IFFT, in the same (..., 2) real layout as `input`.
    """
    if not _is_complex(input):
        raise TypeError('The input should be complex (e.g. last dimension is 2)')
    # torch.fft/torch.ifft(input, 3) were removed in PyTorch 1.8. The torch.fft
    # module works on native complex tensors, so convert in and out of the
    # (..., 2) layout to keep this function's contract unchanged.
    signal = torch.view_as_complex(input.contiguous())
    dims = (-3, -2, -1)
    if inverse:
        out = torch.fft.ifftn(signal, dim=dims)
    else:
        out = torch.fft.fftn(signal, dim=dims)
    return torch.view_as_real(out)
Example #22
Source File: utils.py    From sigmanet with MIT License 5 votes vote down vote up
def torch_fft2(x, normalized=True):
    """ fft on last 2 dim

    Round-trips through torch: numpy_to_torch -> 2-D FFT -> torch_to_complex_numpy.

    :param x: array accepted by numpy_to_torch (presumably complex numpy; the
        converted tensor must have a trailing (real, imag) dim of size 2)
    :param normalized: True for orthonormal 1/sqrt(N) scaling, False for unnormalized
    :return: result of torch_to_complex_numpy on the transformed tensor
    """
    xt = numpy_to_torch(x)
    # torch.fft(xt, 2, normalized) was removed in PyTorch 1.8; use the
    # torch.fft module on a complex view and return to the (..., 2) layout.
    norm = "ortho" if normalized else "backward"
    kt = torch.view_as_real(torch.fft.fft2(torch.view_as_complex(xt.contiguous()), norm=norm))
    return torch_to_complex_numpy(kt)
Example #23
Source File: utils.py    From sigmanet with MIT License 5 votes vote down vote up
def torch_fft2c(x, normalized=True):
    """ fft2 on last 2 dim

    Centered 2-D FFT: ifftshift, FFT via torch, then fftshift, all over the
    last two axes.

    :param x: array accepted by numpy_to_torch (presumably complex numpy)
    :param normalized: True (default) for orthonormal 1/sqrt(N) scaling,
        False for unnormalized
    :return: centered spectrum as returned by torch_to_complex_numpy
    """
    x = np.fft.ifftshift(x, axes=(-2, -1))
    xt = numpy_to_torch(x)
    # `normalized` was previously ignored (always True); honour it now.
    # torch.fft(xt, 2, ...) was removed in PyTorch 1.8, so use the torch.fft
    # module on a complex view and return to the (..., 2) layout.
    norm = "ortho" if normalized else "backward"
    kt = torch.view_as_real(torch.fft.fft2(torch.view_as_complex(xt.contiguous()), norm=norm))
    k = torch_to_complex_numpy(kt)
    return np.fft.fftshift(k, axes=(-2, -1))
Example #24
Source File: utils.py    From sigmanet with MIT License 5 votes vote down vote up
def torch_ifft2c(x, normalized=True):
    """ ifft2 on last 2 dim

    Centered 2-D inverse FFT: ifftshift, inverse FFT via torch, then
    fftshift, all over the last two axes.

    :param x: array accepted by numpy_to_torch (presumably complex numpy)
    :param normalized: True (default) for orthonormal 1/sqrt(N) scaling,
        False for the conventional 1/N inverse scaling
    :return: centered image as returned by torch_to_complex_numpy
    """
    x = np.fft.ifftshift(x, axes=(-2, -1))
    xt = numpy_to_torch(x)
    # `normalized` was previously ignored (always True); honour it now.
    # torch.ifft(xt, 2, ...) was removed in PyTorch 1.8, so use the torch.fft
    # module on a complex view and return to the (..., 2) layout.
    norm = "ortho" if normalized else "backward"
    kt = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(xt.contiguous()), norm=norm))
    k = torch_to_complex_numpy(kt)
    return np.fft.fftshift(k, axes=(-2, -1))
Example #25
Source File: fft.py    From sigmanet with MIT License 5 votes vote down vote up
def fft2(data):
    """Orthonormal 2-D FFT over the last two complex dimensions.

    :param data: real tensor holding complex data; dimension -1 has size 2
        (real, imag) and dimensions -3/-2 are transformed
    :return: tensor with the same shape and (..., 2) layout as `data`
    """
    assert data.size(-1) == 2
    # torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8;
    # norm="ortho" reproduces normalized=True's 1/sqrt(N) scaling.
    spectrum = torch.fft.fft2(torch.view_as_complex(data.contiguous()), dim=(-2, -1), norm="ortho")
    data = torch.view_as_real(spectrum)
    return data
Example #26
Source File: fft.py    From sigmanet with MIT License 5 votes vote down vote up
def fft2c(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.
    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
            -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
            assumed to be batch dimensions.
    Returns:
        torch.Tensor: The orthonormally scaled (1/sqrt(H*W)) FFT of the input, in the same
            (..., 2) real/imaginary layout as `data`.
    """
    assert data.size(-1) == 2
    data = ifftshift(data, dim=(-3, -2))
    # torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8. Transform
    # a complex view over its last two dims (the spatial dims -3/-2 of the real
    # layout) and view the result back; norm="ortho" matches normalized=True.
    spectrum = torch.fft.fft2(torch.view_as_complex(data.contiguous()), dim=(-2, -1), norm="ortho")
    data = torch.view_as_real(spectrum)
    data = fftshift(data, dim=(-3, -2))
    return data
Example #27
Source File: fft.py    From sigmanet with MIT License 5 votes vote down vote up
def ifftshift(x, dim=None):
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors

    Each selected dimension of size n is rolled by ceil(n / 2) via `roll`.

    :param x: tensor to shift
    :param dim: None to shift every dimension, a single int, or an iterable
        of ints
    :return: whatever `roll(x, shift, dim)` returns
    """
    if isinstance(dim, int):
        return roll(x, (x.shape[dim] + 1) // 2, dim)
    if dim is None:
        axes = tuple(range(x.dim()))
        offsets = [(size + 1) // 2 for size in x.shape]
        return roll(x, offsets, axes)
    offsets = [(x.shape[axis] + 1) // 2 for axis in dim]
    return roll(x, offsets, dim)
Example #28
Source File: mcmri.py    From deepinpy with MIT License 5 votes vote down vote up
def fft_forw(x, ndim=2):
    """Orthonormal forward FFT over the last `ndim` complex dimensions.

    :param x: real tensor holding complex data; its last dimension has size
        2 (real, imag)
    :param ndim: number of trailing complex dimensions to transform
    :return: tensor with the same shape and (..., 2) layout as `x`
    """
    # torch.fft(x, signal_ndim=ndim, normalized=True) was removed in PyTorch 1.8;
    # norm="ortho" reproduces normalized=True's 1/sqrt(N) scaling.
    signal = torch.view_as_complex(x.contiguous())
    dims = tuple(range(-ndim, 0))
    return torch.view_as_real(torch.fft.fftn(signal, dim=dims, norm="ortho"))
Example #29
Source File: torch_spec_operator.py    From space_time_pde with MIT License 5 votes vote down vote up
def pad_irfft3(F, signal_size=None):
    """
    Batched inverse 3-D real FFT.

    Applies inverse complex FFTs over the res0 and res1 dims, then an inverse
    real FFT over the last (half-spectrum) dim.

    :param F: tensor of shape [..., res0, res1, res2/2+1, 2]; the trailing
        dimension holds (real, imag)
    :param signal_size: length of the reconstructed last axis. Defaults to
        res1 (F.shape[-3]), which is only correct when res1 == res2 (e.g. on
        cubic grids); pass it explicitly for other shapes.
    :return: real tensor of shape [..., res0, res1, signal_size]
    """
    res = F.shape[-3] if signal_size is None else signal_size
    # torch.ifft/torch.irfft were removed in PyTorch 1.8. The torch.fft module
    # works on native complex tensors with an explicit dim, so no transposes
    # are needed; the default "backward" norm matches the old 1/n scaling.
    spec = torch.view_as_complex(F.contiguous())
    spec = torch.fft.ifft(spec, dim=-3)
    spec = torch.fft.ifft(spec, dim=-2)
    return torch.fft.irfft(spec, n=res, dim=-1)  # [..., res0, res1, res2]
Example #30
Source File: CBP.py    From fast-MPN-COV with MIT License 5 votes vote down vote up
def forward(self, x):
         """Pool a feature map with two sketch projections and FFT-based circular convolution.

         Every spatial feature vector of an image is projected with
         self.sparseM[0] and self.sparseM[1]. The two projections are
         circularly convolved (product of FFTs, then inverse FFT, real part)
         and summed over all h*w positions. The pooled vectors then pass
         through self._signed_sqrt and self._l2norm. (Presumably compact
         bilinear pooling, per the file name CBP.py.)

         :param x: tensor of shape [batch, dim, h, w]
         :return: tensor of shape [batch, self.output_dim]
         """
         bsn = 1  # images per chunk; chunking bounds the size of the FFT intermediates
         batch_size, dim, h, w = x.shape
         x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # [batch * h * w, dim]
         y = torch.ones(batch_size, self.output_dim, device=x.device)

         # Loop-invariant: move the projection matrices to the input device once.
         proj1 = self.sparseM[0].to(x.device)
         proj2 = self.sparseM[1].to(x.device)
         positions = h * w

         # Stepping by bsn also covers a final partial chunk, which
         # range(batch_size // bsn) skipped when bsn does not divide the batch.
         # Each chunk's rows are bounded by batch_size; the original bounded
         # them by the h*w-scaled flat size.
         for start in range(0, batch_size, bsn):
             stop = min(batch_size, start + bsn)
             batch_x = x_flat[start * positions:stop * positions]

             # torch.fft/torch.ifft were removed in PyTorch 1.8. The torch.fft
             # module takes real input directly, so no zero imaginary channel
             # is concatenated and the complex product is a plain multiply.
             sketch1 = torch.fft.fft(batch_x.mm(proj1), dim=1)
             sketch2 = torch.fft.fft(batch_x.mm(proj2), dim=1)
             tmp_y = torch.fft.ifft(sketch1 * sketch2, dim=1).real

             y[start:stop] = tmp_y.reshape(stop - start, h, w, self.output_dim).sum(dim=(1, 2))

         y = self._signed_sqrt(y)
         y = self._l2norm(y)
         return y