import torch
from torch import nn
from torch.autograd import Variable

# Preprocessing scale constants (P = 10). preProc1/preProc2 split values at
# the threshold exp(-P): magnitudes above it map to (log|x| / P, sign(x)),
# everything else maps to (-1, exp(P) * x).
# NOTE(review): this looks like the gradient preprocessing of Andrychowicz
# et al. 2016 ("Learning to learn by gradient descent by gradient descent",
# Appendix A) -- confirm against the intended reference.
P = Variable(torch.FloatTensor(1).fill_(10))
expP = Variable(torch.exp(P.data))
negExpP = Variable(torch.exp(-P.data))


def _constants_like(x):
    """Return ``(P, expP, negExpP)`` cast to the tensor type of ``x``.

    The module-level constants are read at call time but never rebound, so
    calling the preprocessors does not change global state (the previous
    version reassigned the globals via ``global`` on every call).
    """
    return P.type_as(x), expP.type_as(x), negExpP.type_as(x)


def preProc1(x):
    """First preprocessing feature: scaled log-magnitude.

    Element-wise: ``log(|x|) / P`` where ``|x| > exp(-P)``, and ``-1``
    elsewhere (including NaN entries).

    Args:
        x: tensor of any shape.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    p, _, neg_exp_p = _constants_like(x)
    # Default value for small-magnitude entries.
    z = torch.full_like(x, -1)
    large = torch.abs(x) > neg_exp_p
    # Only take the log of entries above the threshold; log(0) would be -inf.
    if large.any():
        z[large] = torch.log(torch.abs(x[large])) / p
    return z


def preProc2(x):
    """Second preprocessing feature: sign, or a rescaled small value.

    Element-wise: ``sign(x)`` where ``|x| > exp(-P)``, and ``exp(P) * x``
    where ``|x| <= exp(-P)``. Entries matching neither comparison (NaN)
    stay 0.

    Args:
        x: tensor of any shape.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    _, exp_p, neg_exp_p = _constants_like(x)
    z = torch.zeros_like(x)
    abs_x = torch.abs(x)
    large = abs_x > neg_exp_p
    # Not simply ~large: NaN fails both comparisons and must stay 0.
    small = abs_x <= neg_exp_p
    if large.any():
        z[large] = torch.sign(x[large])
    if small.any():
        z[small] = x[small] * exp_p
    return z


def _two_channel(x):
    """Return ``x``'s two features stacked as a fresh ``(x.numel(), 1, 2)`` tensor.

    Channel 0 is ``preProc1(x)`` and channel 1 is ``preProc2(x)``. The
    result is built with ``torch.cat`` rather than by writing into an
    ``expand()``-ed view. ``expand`` gives the new last dimension stride 0,
    so both channels would share memory and the second write would
    overwrite the first.
    """
    first = preProc1(x).reshape(-1, 1, 1)
    second = preProc2(x).reshape(-1, 1, 1)
    return torch.cat([first, second], dim=2)


def preprocess(grad, loss):
    """Apply the two-feature preprocessing to a gradient and a loss.

    Args:
        grad: gradient tensor. The original code required an ``(N, 1, 1)``
            layout; any shape is now accepted and flattened to ``N = grad.numel()``.
        loss: loss tensor, expected to hold a single element.

    Returns:
        ``(preGrad, preLoss)``. ``preGrad`` has shape ``(grad.numel(), 1, 2)``,
        which is ``(N, 1, 2)`` for an ``(N, 1, 1)`` gradient. ``preLoss`` has
        shape ``(loss.numel(), 1, 2)``, which is ``(1, 1, 2)`` for a
        single-element loss. Channel 0 holds ``preProc1`` and channel 1
        holds ``preProc2``. Gradients flow back through both features.
    """
    preGrad = _two_channel(grad)
    preLoss = _two_channel(loss)
    return preGrad, preLoss