import math

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init
from torch.autograd import Variable


class L2Norm(nn.Module):
    """Rescale the input to unit L2 norm along dimension 1."""

    def __init__(self):
        super(L2Norm, self).__init__()
        # Added under the square root so an all-zero row does not divide by 0.
        self.eps = 1e-10

    def forward(self, x):
        """Return ``x`` divided by its L2 norm taken over dimension 1."""
        squared_sum = torch.sum(x * x, dim=1, keepdim=True)
        norm = torch.sqrt(torch.abs(squared_sum) + self.eps)
        return x / norm


def getPoolingKernel(kernel_size=25):
    """Return a ``kernel_size`` x ``kernel_size`` separable triangular kernel.

    The 1-D profile rises linearly from ``step / 2`` towards 1, has the value
    1 at its centre and falls symmetrically; the 2-D kernel is the outer
    product of that profile with itself.
    """
    step = 1. / float(np.floor(kernel_size / 2.))
    rising = np.arange(step / 2., 1., step)
    profile = np.hstack([rising, [1], rising[::-1]])
    return np.maximum(0, np.outer(profile, profile))


def get_bin_weight_kernel_size_and_stride(patch_size, num_spatial_bins):
    """Return ``(kernel_size, stride)`` of the spatial pooling convolution.

    The stride is ``2 * floor(patch_size / 2) / (num_spatial_bins + 1)``
    rounded to an integer, and the kernel size is ``2 * stride - 1``.
    """
    bin_weight_stride = int(
        round(2.0 * math.floor(patch_size / 2) / float(num_spatial_bins + 1)))
    bin_weight_kernel_size = int(2 * bin_weight_stride - 1)
    return bin_weight_kernel_size, bin_weight_stride


class SIFTNet(nn.Module):
    """SIFT-style patch descriptor built from fixed convolution kernels.

    ``forward`` expects a tensor of shape ``(batch, 1, patch_size,
    patch_size)`` and returns one flattened, L2-normalised descriptor per
    patch.

    NOTE: the gradient and pooling filters are stored as ordinary
    ``nn.Conv2d`` weights (``requires_grad`` is left at its default), so an
    optimizer built from ``parameters()`` will see them.
    """

    def CircularGaussKernel(self, kernlen=21):
        """Return a ``kernlen`` x ``kernlen`` Gaussian window cut to a disc.

        Entries whose squared distance from the centre pixel is at least
        ``(kernlen // 2) ** 2`` are zero; the others hold
        ``exp(-d2 / sigma2)`` with ``sigma2 = 0.9 * (kernlen // 2) ** 2``.
        """
        # Floor division keeps the centre on a pixel. With true division an
        # odd ``kernlen`` centres the window on a half-pixel offset, which
        # makes the kernel asymmetric and keeps its peak below 1.
        half_size = kernlen // 2
        r2 = float(half_size * half_size)
        sigma2 = 0.9 * r2
        kernel = np.zeros((kernlen, kernlen))
        for y in range(kernlen):
            dy2 = (y - half_size) * (y - half_size)
            for x in range(kernlen):
                disq = dy2 + (x - half_size) * (x - half_size)
                if disq < r2:
                    kernel[y, x] = math.exp(-disq / sigma2)
        return kernel

    @staticmethod
    def _fixed_conv(kernel, stride=1):
        """Wrap a 2-D ``kernel`` in a bias-free 1->1 channel convolution."""
        weights = np.array(kernel, dtype=np.float32)
        height, width = weights.shape
        conv = nn.Conv2d(1, 1, kernel_size=(height, width),
                         stride=stride, bias=False)
        conv.weight.data = torch.from_numpy(
            weights.reshape(1, 1, height, width))
        # Kept inside nn.Sequential so state_dict keys stay "<name>.0.weight".
        return nn.Sequential(conv)

    def __init__(self, patch_size=65, num_ang_bins=8, num_spatial_bins=4,
                 clipval=0.2):
        """Build the fixed filters.

        Args:
            patch_size: side length of the square input patches.
            num_ang_bins: number of orientation histogram bins.
            num_spatial_bins: used to derive the pooling kernel size/stride.
            clipval: upper clamp applied to the normalised descriptor.
        """
        super(SIFTNet, self).__init__()
        self.bin_weight_kernel_size, self.bin_weight_stride = (
            get_bin_weight_kernel_size_and_stride(patch_size,
                                                  num_spatial_bins))
        # Plain attribute (not a buffer/parameter); forward() moves it to the
        # input's device on demand.
        self.gk = torch.from_numpy(
            self.CircularGaussKernel(kernlen=patch_size).astype(np.float32))
        self.num_ang_bins = num_ang_bins
        self.num_spatial_bins = num_spatial_bins
        self.clipval = clipval
        # Horizontal and vertical [-1, 0, 1] difference filters.
        self.gx = self._fixed_conv([[-1, 0, 1]])
        self.gy = self._fixed_conv([[-1], [0], [1]])
        # Strided triangular pooling over each orientation map.
        self.pk = self._fixed_conv(
            getPoolingKernel(kernel_size=self.bin_weight_kernel_size),
            stride=self.bin_weight_stride)

    def forward(self, x):
        """Compute descriptors for ``x`` of shape ``(B, 1, P, P)``.

        ``P`` must equal the ``patch_size`` given at construction because the
        Gaussian window is expanded to the gradient-magnitude tensor.

        Returns:
            Tensor of shape ``(B, num_ang_bins * n * n)`` where ``n`` is the
            spatial output size of the pooling convolution.
        """
        # Replicate-pad so the difference filters keep the spatial size.
        gx = self.gx(F.pad(x, (1, 1, 0, 0), 'replicate'))
        gy = self.gy(F.pad(x, (0, 0, 1, 1), 'replicate'))
        mag = torch.sqrt(gx ** 2 + gy ** 2 + 1e-10)
        ori = torch.atan2(gy, gx + 1e-8)

        # Follow the input's exact device (``.cuda()`` would pick the current
        # default GPU, which may differ from the one holding ``x``).
        if self.gk.device != x.device:
            self.gk = self.gk.to(x.device)
        mag = mag * self.gk.expand_as(mag)

        # Continuous bin index; each pixel votes into two neighbouring bins
        # with linear weights.
        o_big = ((ori + 2.0 * math.pi) / (2.0 * math.pi)
                 * float(self.num_ang_bins))
        bo0_big = torch.floor(o_big)
        wo1_big = o_big - bo0_big
        bo0_big = bo0_big % self.num_ang_bins
        bo1_big = (bo0_big + 1) % self.num_ang_bins
        wo0_big = (1.0 - wo1_big) * mag
        wo1_big = wo1_big * mag

        ang_bins = torch.cat(
            [
                self.pk((bo0_big == i).float() * wo0_big
                        + (bo1_big == i).float() * wo1_big)
                for i in range(self.num_ang_bins)
            ],
            1,
        )
        ang_bins = ang_bins.view(ang_bins.size(0), -1)

        # Normalise, clip large entries, then renormalise.
        normalize = L2Norm()
        ang_bins = normalize(ang_bins)
        ang_bins = torch.clamp(ang_bins, 0., float(self.clipval))
        return normalize(ang_bins)