import types import torch import torch.nn as nn from torch.autograd import Function def CountSketchFn_forward(h, s, output_size, x, force_cpu_scatter_add=False): x_size = tuple(x.size()) s_view = (1,) * (len(x_size)-1) + (x_size[-1],) out_size = x_size[:-1] + (output_size,) # Broadcast s and compute x * s s = s.view(s_view) xs = x * s # Broadcast h then compute h: # out[h_i] += x_i * s_i h = h.view(s_view).expand(x_size) if force_cpu_scatter_add: out = x.new(*out_size).zero_().cpu() return out.scatter_add_(-1, h.cpu(), xs.cpu()).cuda() else: out = x.new(*out_size).zero_() return out.scatter_add_(-1, h, xs) def CountSketchFn_backward(h, s, x_size, grad_output): s_view = (1,) * (len(x_size)-1) + (x_size[-1],) s = s.view(s_view) h = h.view(s_view).expand(x_size) grad_x = grad_output.gather(-1, h) grad_x = grad_x * s return grad_x class CountSketchFn(Function): @staticmethod def forward(ctx, h, s, output_size, x, force_cpu_scatter_add=False): x_size = tuple(x.size()) ctx.save_for_backward(h,s) ctx.x_size = tuple(x.size()) return CountSketchFn_forward(h, s, output_size, x, force_cpu_scatter_add) @staticmethod def backward(ctx, grad_output): h,s = ctx.saved_variables grad_x = CountSketchFn_backward(h,s,ctx.x_size,grad_output) return None, None, None, grad_x class CountSketch(nn.Module): r"""Compute the count sketch over an input signal. .. math:: out_j = \sum_{i : j = h_i} s_i x_i Args: input_size (int): Number of channels in the input array output_size (int): Number of channels in the output sketch h (array, optional): Optional array of size input_size of indices in the range [0,output_size] s (array, optional): Optional array of size input_size of -1 and 1. .. note:: If h and s are None, they will be automatically be generated using LongTensor.random_. Shape: - Input: (...,input_size) - Output: (...,output_size) References: Yang Gao et al. "Compact Bilinear Pooling" in Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (2016). Akira Fukui et al. "Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding", arXiv:1606.01847 (2016). """ def __init__(self, input_size, output_size, h = None, s = None): super(CountSketch, self).__init__() self.input_size = input_size self.output_size = output_size if h is None: h = torch.LongTensor(input_size).random_(0, output_size) if s is None: s = 2 * torch.Tensor(input_size).random_(0,2) - 1 # The Variable h being a list of indices, # If the type of this module is changed (e.g. float to double), # the variable h should remain a LongTensor # therefore we force float() and double() to be no-ops on the variable h. def identity(self): return self h.float = types.MethodType(identity,h) h.double = types.MethodType(identity,h) self.register_buffer('h',h) self.register_buffer('s',s) def forward(self, x): x_size = list(x.size()) assert(x_size[-1] == self.input_size) return CountSketchFn.apply(self.h, self.s, self.output_size, x) def ComplexMultiply_forward(X_re, X_im, Y_re, Y_im): Z_re = torch.addcmul(X_re*Y_re, -1, X_im, Y_im) Z_im = torch.addcmul(X_re*Y_im, 1, X_im, Y_re) return Z_re,Z_im def ComplexMultiply_backward(X_re, X_im, Y_re, Y_im, grad_Z_re, grad_Z_im): grad_X_re = torch.addcmul(grad_Z_re * Y_re, 1, grad_Z_im, Y_im) grad_X_im = torch.addcmul(grad_Z_im * Y_re, -1, grad_Z_re, Y_im) grad_Y_re = torch.addcmul(grad_Z_re * X_re, 1, grad_Z_im, X_im) grad_Y_im = torch.addcmul(grad_Z_im * X_re, -1, grad_Z_re, X_im) return grad_X_re,grad_X_im,grad_Y_re,grad_Y_im class ComplexMultiply(torch.autograd.Function): @staticmethod def forward(ctx, X_re, X_im, Y_re, Y_im): ctx.save_for_backward(X_re,X_im,Y_re,Y_im) return ComplexMultiply_forward(X_re, X_im, Y_re, Y_im) @staticmethod def backward(ctx,grad_Z_re, grad_Z_im): X_re,X_im,Y_re,Y_im = ctx.saved_tensors return ComplexMultiply_backward(X_re,X_im,Y_re,Y_im, grad_Z_re, grad_Z_im) class CompactBilinearPoolingFn(Function): @staticmethod def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False): ctx.save_for_backward(h1,s1,h2,s2,x,y) ctx.x_size = tuple(x.size()) ctx.y_size = tuple(y.size()) ctx.force_cpu_scatter_add = force_cpu_scatter_add ctx.output_size = output_size # Compute the count sketch of each input px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add) fx = torch.rfft(px,1) re_fx = fx.select(-1, 0) im_fx = fx.select(-1, 1) del px py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add) fy = torch.rfft(py,1) re_fy = fy.select(-1,0) im_fy = fy.select(-1,1) del py # Convolution of the two sketch using an FFT. # Compute the FFT of each sketch # Complex multiplication re_prod, im_prod = ComplexMultiply_forward(re_fx,im_fx,re_fy,im_fy) # Back to real domain # The imaginary part should be zero's re = torch.irfft(torch.stack((re_prod, im_prod), re_prod.dim()), 1, signal_sizes=(output_size,)) return re @staticmethod def backward(ctx,grad_output): h1,s1,h2,s2,x,y = ctx.saved_tensors # Recompute part of the forward pass to get the input to the complex product # Compute the count sketch of each input px = CountSketchFn_forward(h1, s1, ctx.output_size, x, ctx.force_cpu_scatter_add) py = CountSketchFn_forward(h2, s2, ctx.output_size, y, ctx.force_cpu_scatter_add) # Then convert the output to Fourier domain grad_output = grad_output.contiguous() grad_prod = torch.rfft(grad_output, 1) grad_re_prod = grad_prod.select(-1, 0) grad_im_prod = grad_prod.select(-1, 1) # Compute the gradient of x first then y # Gradient of x # Recompute fy fy = torch.rfft(py,1) re_fy = fy.select(-1,0) im_fy = fy.select(-1,1) del py # Compute the gradient of fx, then back to temporal space grad_re_fx = torch.addcmul(grad_re_prod * re_fy, 1, grad_im_prod, im_fy) grad_im_fx = torch.addcmul(grad_im_prod * re_fy, -1, grad_re_prod, im_fy) grad_fx = torch.irfft(torch.stack((grad_re_fx,grad_im_fx), grad_re_fx.dim()), 1, signal_sizes=(ctx.output_size,)) # Finally compute the gradient of x grad_x = CountSketchFn_backward(h1, s1, ctx.x_size, grad_fx) del re_fy,im_fy,grad_re_fx,grad_im_fx,grad_fx # Gradient of y # Recompute fx fx = torch.rfft(px,1) re_fx = fx.select(-1, 0) im_fx = fx.select(-1, 1) del px # Compute the gradient of fy, then back to temporal space grad_re_fy = torch.addcmul(grad_re_prod * re_fx, 1, grad_im_prod, im_fx) grad_im_fy = torch.addcmul(grad_im_prod * re_fx, -1, grad_re_prod, im_fx) grad_fy = torch.irfft(torch.stack((grad_re_fy,grad_im_fy), grad_re_fy.dim()), 1, signal_sizes=(ctx.output_size,)) # Finally compute the gradient of y grad_y = CountSketchFn_backward(h2, s2, ctx.y_size, grad_fy) del re_fx,im_fx,grad_re_fy,grad_im_fy,grad_fy return None, None, None, None, None, grad_x, grad_y, None class CompactBilinearPooling(nn.Module): r"""Compute the compact bilinear pooling between two input array x and y .. math:: out = \Psi (x,h_1,s_1) \ast \Psi (y,h_2,s_2) Args: input_size1 (int): Number of channels in the first input array input_size2 (int): Number of channels in the second input array output_size (int): Number of channels in the output array h1 (array, optional): Optional array of size input_size of indices in the range [0,output_size] s1 (array, optional): Optional array of size input_size of -1 and 1. h2 (array, optional): Optional array of size input_size of indices in the range [0,output_size] s2 (array, optional): Optional array of size input_size of -1 and 1. force_cpu_scatter_add (boolean, optional): Force the scatter_add operation to run on CPU for testing purposes .. note:: If h1, s1, s2, h2 are None, they will be automatically be generated using LongTensor.random_. Shape: - Input 1: (...,input_size1) - Input 2: (...,input_size2) - Output: (...,output_size) References: Yang Gao et al. "Compact Bilinear Pooling" in Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (2016). Akira Fukui et al. "Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding", arXiv:1606.01847 (2016). """ def __init__(self, input1_size, input2_size, output_size, h1 = None, s1 = None, h2 = None, s2 = None, force_cpu_scatter_add=False): super(CompactBilinearPooling, self).__init__() self.add_module('sketch1', CountSketch(input1_size, output_size, h1, s1)) self.add_module('sketch2', CountSketch(input2_size, output_size, h2, s2)) self.output_size = output_size self.force_cpu_scatter_add = force_cpu_scatter_add def forward(self, x, y = None): if y is None: y = x return CompactBilinearPoolingFn.apply(self.sketch1.h, self.sketch1.s, self.sketch2.h, self.sketch2.s, self.output_size, x, y, self.force_cpu_scatter_add)