""" relu.py: Estimator and test for the rectified linear unit """ from __future__ import division from __future__ import print_function import numpy as np # Import other subpackages in vampyre import vampyre.common as common # Import individual classes and methods from the current subpackage from vampyre.estim.base import BaseEst from vampyre.estim.interval import gauss_integral class ReLUEst(BaseEst): """ Estimatar for a rectified linear unit :math:`z_1 = f(z_0) = \\max(0,z_0)` :param shape: shape of :math:`z_0` and :math:`z_1` :param var_axes: List :code:`[var_axes[0],var_axes[1]]` of the axes on which the input and output variances are averaged :param map_est: Flag indicating if estimation is MAP or MMSE. :param name: Estimator name. """ def __init__(self,shape,var_axes=[(0,),(0,)], name=None, map_est=False): self.map_est = map_est # Initial variances self.zvar0_init= np.Inf self.zvar1_init= np.Inf nvars = 2 dtype = np.float64 BaseEst.__init__(self,shape=[shape,shape], var_axes=var_axes, dtype=dtype, name=name,\ type_name='ReLUEst', nvars=nvars, cost_avail=True) def est_init(self,return_cost=False,ind_out=None, avg_var_cost=True): """ Initial estimator. See the base class :class:`vampyre.estim.base.Estim` for a complete description. :param Boolean return_cost: Flag indicating if :code:`cost` is to be returned :returns: :code:`zmean, zvar, [cost]` which are the prior mean and variance """ # Check parameters if ind_out is None: ind_out = [0,1] if not avg_var_cost: raise ValueError("disabling variance averaging not supported for ReLUEST") zmean = [] zvar = [] if 0 in ind_out: zmean0 = np.zeros(self.shape[0]) zvar0_shape = common.utils.get_var_shape(self.shape[0], self.var_axes[0]) zvar0 = np.tile(self.zvar0_init, zvar0_shape) zmean.append(zmean0) zvar.append(zvar0) if 1 in ind_out: zmean1 = np.zeros(self.shape[1]) zvar1_shape = common.utils.get_var_shape(self.shape[1], self.var_axes[1]) zvar1 = np.tile(self.zvar1_init, zvar1_shape) zmean.append(zmean1) zvar.append(zvar1) cost = 0 if return_cost: return zmean, zvar, cost else: return zmean, zvar def est(self,r,rvar,return_cost=False,ind_out=None, avg_var_cost=True): """ Estimation function The proximal estimation function as described in the base class :class:`vampyre.estim.base.Estim` :param r: Proximal mean :param rvar: Proximal variance :param boolean return_cost: Flag indicating if :code:`cost` is to be returned :returns: :code:`zhat, zhatvar, [cost]` which are the posterior mean, variance and optional cost. """ # Check parameters if ind_out is None: ind_out = [0,1] if not avg_var_cost: raise ValueError("disabling variance averaging not supported for ReLUEST") if self.map_est: return self.est_map(r,rvar,return_cost,ind_out) else: return self.est_mmse(r,rvar,return_cost,ind_out) def est_map(self,r,rvar,return_cost,ind_out): """ MAP Estimation In this case, we wish to minimize cost = (z0-r0)^2/(2*rvar0) + (z1-r1)^2/(2*rvar1) where z1 = max(0,z0) """ # Unpack the terms r0, r1 = r rvar0, rvar1 = rvar # Clip variances rvar1 = np.minimum(1e8*rvar0, rvar1) # Reshape the variances rvar0 = common.repeat_axes(rvar0,self.shape[0],self.var_axes[0]) rvar1 = common.repeat_axes(rvar1,self.shape[1],self.var_axes[1]) # Positive case: z0 >= 0 and hence z1=z0 z0p = np.maximum(0, (rvar0*r1 + rvar1*r0)/(rvar0 + rvar1)) z1p = z0p zvar0p = rvar0*rvar1/(rvar0+rvar1) zvar1p = zvar0p costp = 0.5*((z0p-r0)**2/rvar0 + (z1p-r1)**2/rvar1) # Negative case: z0 <= 0 and hence z1 = 0 z0n = np.minimum(0, r0) z1n = 0 zvar0n = rvar0 zvar1n = 0 costn = 0.5*((z0n-r0)**2/rvar0 + (z1n-r1)**2/rvar1) # Find lower cost and select the correct choice for each element Ip = (costp < costn) zhat0 = z0p*Ip + z0n*(1-Ip) zhat1 = z1p*Ip + z1n*(1-Ip) zhatvar0 = zvar0p*Ip + zvar0n*(1-Ip) zhatvar1 = zvar1p*Ip + zvar1n*(1-Ip) cost = np.sum(costp*Ip + costn*(1-Ip)) # Average the variance over the specified axes zhatvar0 = np.mean(zhatvar0,axis=self.var_axes[0]) zhatvar1 = np.mean(zhatvar1,axis=self.var_axes[1]) zhatvar = [zhatvar0,zhatvar1] # Pack the items zhat = [] zhatvar = [] if 0 in ind_out: zhat.append(zhat0) zhatvar.append(zhatvar0) if 1 in ind_out: zhat.append(zhat1) zhatvar.append(zhatvar1) if not return_cost: return zhat, zhatvar else: return zhat, zhatvar, cost def est_mmse(self,r,rvar,return_cost,ind_out): """ In the MMSE estimation case, we wish to estimate z0 and z1 with priors zi = N(ri,rvari) and z1=f(z0) Substituting in z1 = f(z0), we have the density of z0: p(z0) \propto qn(z0)1_{z0 < 0} + qp(z0)1_{z0 > 0} where qp(z0) = exp[-(z0-r0)^2/(2*rvar0) - (z0-r1)^2/(2*rvar1)] qn(z0) = exp[-(z0-r0)^2/(2*rvar0) - r1^2/(2*rvar1)] First, we complete the squares and write: qp(z0) = exp(Amax)*Cp*exp(-(z0-rp)^2/(2*zvarp))/sqrt(2*pi) qn(z0) = exp(Amax)*Cn*exp(-(z0-rn)^2/(2*zvarn))/sqrt(2*pi) """ # Unpack the terms r0, r1 = r rvar0, rvar1 = rvar # Reshape the variances rvar0 = common.repeat_axes(rvar0,self.shape[0],self.var_axes[0]) rvar1 = common.repeat_axes(rvar1,self.shape[1],self.var_axes[1]) if np.any(rvar1 == np.Inf): # Infinite variance case. zvarp = rvar0 zvarn = rvar0 rp = r0 rn = r0 Cp = 1 Cn = 1 Amax = 0 else: # Compute the MAP estimate zhat_map, zvar_map = self.est_map(r,rvar,return_cost=False,ind_out=[0,1]) zhat0_map, zhat1_map = zhat_map zvar0_map, zvar1_map = zvar_map # Compute the conditional Gaussian terms for z > 0 and z < 0 zvarp = rvar0*rvar1/(rvar0+rvar1) zvarn = rvar0 rp = (rvar1*r0 + rvar0*r1)/(rvar0+rvar1) rn = r0 # Compute scaling constants for each region Ap = 0.5*((rp**2)/zvarp - (r0**2)/rvar0 - (r1**2)/rvar1) An = 0.5*(-(r1**2)/rvar1) Amax = np.maximum(Ap,An) Ap = Ap - Amax An = An - Amax Cp = np.exp(Ap) Cn = np.exp(An) # Compute moments for each region zp = Cp*gauss_integral(0, np.Inf, rp, zvarp) zn = Cn*gauss_integral(-np.Inf, 0, rn, zvarn) # Find poorly conditioned points Ibad = (zp[0] + zn[0] < 1e-6) zpsum = zp[0] + zn[0] + Ibad # Compute mean zhat0 = (zp[1] + zn[1])/zpsum zhat1 = zp[1]/zpsum # Compute the variance zhatvar0 = (zp[2] + zn[2])/zpsum - zhat0**2 zhatvar1 = zp[2]/zpsum - zhat1**2 # Replace bad points with MAP estimate if 1: zhat0 = zhat0*(1-Ibad) + zhat0_map*Ibad zhat1 = zhat1*(1-Ibad) + zhat1_map*Ibad zhatvar0 = zhatvar0*(1-Ibad) + zvar0_map*Ibad zhatvar1 = zhatvar1*(1-Ibad) + zvar1_map*Ibad # Average the variance over the specified axes zhatvar0 = np.mean(zhatvar0,axis=self.var_axes[0]) zhatvar1 = np.mean(zhatvar1,axis=self.var_axes[1]) # Pack the items zhat = [] zhatvar = [] if 0 in ind_out: zhat.append(zhat0) zhatvar.append(zhatvar0) if 1 in ind_out: zhat.append(zhat1) zhatvar.append(zhatvar1) if not return_cost: return zhat, zhatvar """ Compute the cost = -\log \int p(z_0) = -Amax - log(zp[0] + zn[0]) """ nz = np.prod(self.shape[0]) cost = -nz*np.mean(Amax - np.log(zpsum)) return zhat, zhatvar, cost