'''
@file: CBP.py
@author: Chunqiao Xu
@author: Jiangtao Xie
@author: Peihua Li
'''
import torch
import torch.nn as nn
class CBP(nn.Module):
     """Compact Bilinear Pooling
        implementation of Compact Bilinear Pooling (CBP)
        https://arxiv.org/pdf/1511.06062.pdf

     Args:
         thresh: small positive number for computation stability
         projDim: projected dimension
         input_dim: the #channel of input feature
     """
     def __init__(self, thresh=1e-8, projDim=8192, input_dim=512):
         super(CBP, self).__init__()
         self.thresh = thresh
         self.projDim = projDim
         self.input_dim = input_dim
         self.output_dim = projDim
         torch.manual_seed(1)
         self.h_ = [
                 torch.randint(0, self.output_dim, (self.input_dim,),dtype=torch.long),
                 torch.randint(0, self.output_dim, (self.input_dim,),dtype=torch.long)
         ]
         self.weights_ = [
             (2 * torch.randint(0, 2, (self.input_dim,)) - 1).float(),
             (2 * torch.randint(0, 2, (self.input_dim,)) - 1).float()
         ]

         indices1 = torch.cat((torch.arange(input_dim, dtype=torch.long).reshape(1, -1),
                               self.h_[0].reshape(1, -1)), dim=0)
         indices2 = torch.cat((torch.arange(input_dim, dtype=torch.long).reshape(1, -1),
                               self.h_[1].reshape(1, -1)), dim=0)

         self.sparseM = [
             torch.sparse.FloatTensor(indices1, self.weights_[0], torch.Size([self.input_dim, self.output_dim])).to_dense(),
             torch.sparse.FloatTensor(indices2, self.weights_[1], torch.Size([self.input_dim, self.output_dim])).to_dense(),
         ]
     def _signed_sqrt(self, x):
         x = torch.mul(x.sign(), torch.sqrt(x.abs()+self.thresh))
         return x

     def _l2norm(self, x):
         x = nn.functional.normalize(x)
         return x

     def forward(self, x):
         bsn = 1
         batchSize, dim, h, w = x.data.shape
         x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # batchsize,h, w, dim,
         y = torch.ones(batchSize, self.output_dim, device=x.device)

         for img in range(batchSize // bsn):
             segLen = bsn * h * w
             upper = batchSize * h * w
             interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)
             interSmall = torch.arange(img * bsn, min(upper, (img + 1) * bsn), dtype=torch.long)
             batch_x = x_flat[interLarge, :]

             sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
             sketch1 = torch.fft(torch.cat((sketch1, torch.zeros(sketch1.size(), device=x.device)), dim=2), 1)

             sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
             sketch2 = torch.fft(torch.cat((sketch2, torch.zeros(sketch2.size(), device=x.device)), dim=2), 1)

             Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1])
             Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0])

             tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]

             y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)

         y = self._signed_sqrt(y)
         y = self._l2norm(y)
         return y