Python torch.symeig() Examples

The following are 16 code examples of torch.symeig(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
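torch.symeig(input, eigenvectors=False) returns the eigenvalues of a real symmetric matrix in ascending order, together with the eigenvectors when eigenvectors=True. Note that the function was deprecated and later removed in favor of torch.linalg.eigh, so the snippets below assume an older PyTorch release. A minimal, self-contained usage sketch:

import torch

A = torch.randn(4, 4, dtype=torch.float64)
A = 0.5 * (A + A.t())                        # symeig expects a symmetric matrix

e, V = torch.symeig(A, eigenvectors=True)    # eigenvalues in ascending order
assert torch.allclose(V @ torch.diag(e) @ V.t(), A)

# Modern equivalent:
# e, V = torch.linalg.eigh(A)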
Example #1
def __expm__(self, matrix, symmetric):
    r"""Calculates the matrix exponential.

    Args:
        matrix (Tensor): Matrix to take the exponential of.
        symmetric (bool): Specifies whether the matrix is symmetric.

    :rtype: (:class:`Tensor`)
    """
    if symmetric:
        e, V = torch.symeig(matrix, eigenvectors=True)
        diff_mat = V @ torch.diag(e.exp()) @ V.t()
    else:
        diff_mat_np = expm(matrix.cpu().numpy())
        diff_mat = torch.Tensor(diff_mat_np).to(matrix.device)
    return diff_mat
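A quick sanity check of the symmetric branch above (a standalone sketch, not part of the source project): for symmetric A = V diag(e) Vᵀ we have exp(A) = V diag(exp(e)) Vᵀ, which should agree with scipy.linalg.expm, presumably the `expm` used in the snippet.

import torch
from scipy.linalg import expm

A = torch.randn(4, 4, dtype=torch.float64)
A = A + A.t()
e, V = torch.symeig(A, eigenvectors=True)
exp_eig = V @ torch.diag(e.exp()) @ V.t()
exp_ref = torch.from_numpy(expm(A.numpy()))
assert torch.allclose(exp_eig, exp_ref)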
Example #2
def symsqrt(a, cond=None, return_rank=False, dtype=torch.float32):
    """Symmetric square root of a positive semi-definite matrix.

    See https://github.com/pytorch/pytorch/issues/25481
    """
    s, u = torch.symeig(a, eigenvectors=True)
    cond_dict = {torch.float32: 1e3 * 1.1920929e-07,
                 torch.float64: 1e6 * 2.220446049250313e-16}

    if cond in [None, -1]:
        cond = cond_dict[dtype]

    above_cutoff = (abs(s) > cond * torch.max(abs(s)))

    psigma_diag = torch.sqrt(s[above_cutoff])
    u = u[:, above_cutoff]

    B = u @ torch.diag(psigma_diag) @ u.t()
    if return_rank:
        return B, len(psigma_diag)
    else:
        return B
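A usage sketch (illustrative, not from the source): squaring the result should recover a positive semi-definite input.

a = torch.randn(5, 5)
a = a @ a.t()                       # PSD by construction
B = symsqrt(a)
assert torch.allclose(B @ B, a, atol=1e-4)

B, rank = symsqrt(a, return_rank=True)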
Example #3
def lanczos_tridiag_to_diag(t_mat):
    """
    Given a num_init_vecs x num_batch x k x k tridiagonal matrix t_mat,
    returns a num_init_vecs x num_batch x k set of eigenvalues
    and a num_init_vecs x num_batch x k x k set of eigenvectors.

    TODO: make the eigenvalue computations done in batch mode.
    """
    orig_device = t_mat.device
    if t_mat.size(-1) < 32:
        retr = torch.symeig(t_mat.cpu(), eigenvectors=True)
    else:
        retr = torch.symeig(t_mat, eigenvectors=True)

    evals, evecs = retr

    return evals.to(orig_device), evecs.to(orig_device)
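The TODO above is resolved in recent PyTorch, where torch.linalg.eigh operates on batched inputs directly; a sketch:

t_mat = torch.randn(2, 3, 8, 8)
t_mat = t_mat + t_mat.transpose(-1, -2)     # batch of symmetric matrices
evals, evecs = torch.linalg.eigh(t_mat)     # shapes: (2, 3, 8) and (2, 3, 8, 8)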
Example #4
def test_forward():
    torch.manual_seed(42)
    N = 100
    tol = 1e-8
    dtype = torch.float64
    A = torch.randn(N, N, dtype=dtype)
    A = A + A.t()

    w, v = torch.symeig(A, eigenvectors=True)
    idx = torch.argmax(w.abs())

    v_exact = v[:, idx]
    v_exact = v_exact[0].sign() * v_exact

    x0 = torch.rand(N, dtype=dtype)
    x0 = x0 / x0.norm()
    x = FixedPoint.apply(A, x0, tol)

    assert torch.allclose(v_exact, x, rtol=tol, atol=tol)
Example #5
def active_subspace(t):
    """
    Compute the main variational directions of a tensor.

    Reference: P. Constantine et al., "Discovering an Active Subspace in a
    Single-Diode Solar Cell Model" (2017), https://arxiv.org/pdf/1406.7607.pdf

    See also P. Constantine's data-set repository:
    https://github.com/paulcon/as-data-sets/blob/master/README.md

    :param t: input tensor
    :return: (eigvals, eigvecs): an array and a matrix, encoding the eigenpairs
        in descending order
    """
    M = torch.zeros(t.dim(), t.dim())
    for i in range(t.dim()):
        for j in range(i, t.dim()):
            # The computation of M[i, j] (an inner product of partial
            # derivatives) was lost in this listing; only the symmetrization
            # step survives.
            M[j, i] = M[i, j]

    w, v = torch.symeig(M, eigenvectors=True)
    idx = list(range(t.dim() - 1, -1, -1))
    w = w[idx]
    v = v[:, idx]
    return w, v
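Since torch.symeig returns eigenvalues in ascending order, the index reversal above yields the documented descending order. torch.argsort gives the same result and reads more directly (a sketch):

M = torch.randn(4, 4)
M = M + M.t()
w, v = torch.symeig(M, eigenvectors=True)
order = torch.argsort(w, descending=True)
w, v = w[order], v[:, order]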
Example #6
def pca(data, components):
    """
    Finds the top `components` principal components of the data.
    """
    # Requires: import numpy as np; import numpy.linalg as nla
    assert components > 0 and components < data.size(1), \
        "incorrect # of PCA dimensions"
    # We switch to numpy here as torch.symeig gave strange results.
    dtype = data.dtype
    data = data.numpy()
    data -= np.mean(data, axis=0, keepdims=True)
    cov = np.cov(data.T)
    L, V = nla.eigh(cov)
    return torch.tensor(V[:, -components:], dtype=dtype)
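A usage sketch with illustrative data. Note that data.numpy() shares memory with the input tensor and the centering step writes in place, so pass a copy if the original must stay intact.

import numpy as np
import numpy.linalg as nla
import torch

data = torch.randn(200, 10)
proj = pca(data.clone(), components=3)      # 10 x 3 projection matrix
reduced = (data - data.mean(0)) @ proj      # 200 x 3 embedding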
Example #7
def _update_inv(self):
    assert self.steps > 0, 'At least one step before update inverse!'
    eps = 1e-15
    for idx, m in enumerate(self.modules):
        # m_aa, m_gg = normalize_factors(self.m_aa[m], self.m_gg[m])
        m_aa, m_gg = self.m_aa[m], self.m_gg[m]
        self.d_a[m], self.Q_a[m] = torch.symeig(m_aa / self.steps, eigenvectors=True)
        self.d_g[m], self.Q_g[m] = torch.symeig(m_gg / self.steps, eigenvectors=True)
        self.d_a[m].mul_((self.d_a[m] > eps).float())
        self.d_g[m].mul_((self.d_g[m] > eps).float())

    self._inversed = True
    self.iter += 1
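The point of storing eigenpairs rather than explicit inverses is that a symmetric positive-definite factor is inverted cheaply from them, since (Q diag(d) Qᵀ)⁻¹ = Q diag(1/d) Qᵀ. A self-contained check (not from this project):

A = torch.randn(6, 6, dtype=torch.float64)
A = A @ A.t() + 1e-3 * torch.eye(6, dtype=torch.float64)   # SPD factor
d, Q = torch.symeig(A, eigenvectors=True)
A_inv = Q @ torch.diag(1.0 / d) @ Q.t()
assert torch.allclose(A_inv, torch.inverse(A))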
Example #8
def _update_inv(self):
    assert self.steps > 0, 'At least one step before update inverse!'
    eps = 1e-10
    for idx, m in enumerate(self.modules):
        m_aa, m_gg = self.m_aa[m], self.m_gg[m]
        self.d_a[m], Q_a = torch.symeig(m_aa, eigenvectors=True)
        self.d_g[m], Q_g = torch.symeig(m_gg, eigenvectors=True)

        self.d_a[m].mul_((self.d_a[m] > eps).float())
        self.d_g[m].mul_((self.d_g[m] > eps).float())

        # == write summary ==
        name = m.__class__.__name__
        eigs = (self.d_g[m].view(-1, 1) @ self.d_a[m].view(1, -1)).view(-1).cpu().data.numpy()
        self.writer.add_histogram('eigen/%s_%d' % (name, idx), eigs, self.iter)

        if self.Q_a.get(m, None) is None:
            # print('(%d)Q_a %s is None.' % (idx, m))
            self.Q_a[m] = [Q_a]  # absorb the eigen basis
        else:
            # self.Q_a[m] = [Q_a, self.Q_a[m]]
            prev_Q_a, prev_Q_g = get_rotation_layer_weights(self.model, m)
            prev_Q_a = prev_Q_a.view(prev_Q_a.size(0), prev_Q_a.size(1)).transpose(1, 0)
            prev_Q_g = prev_Q_g.view(prev_Q_g.size(0), prev_Q_g.size(1))
            self.Q_a[m] = [Q_a, prev_Q_a]

        if self.Q_g.get(m, None) is None:
            self.Q_g[m] = [Q_g]
        else:
            self.Q_g[m] = [Q_g, prev_Q_g]  # prev_Q_g is assigned in the else-branch above

    self._inversed = True
    self.iter += 1
Example #9
def test_precond_solve(self):
    seed = 4
    torch.random.manual_seed(seed)

    tensor = torch.randn(1000, 800)
    diag = torch.abs(torch.randn(1000))

    # The construction of standard_lt was lost in this listing; pairing the two
    # factors with AddedDiagLazyTensor is an assumption based on the overridden
    # construction further down.
    standard_lt = AddedDiagLazyTensor(RootLazyTensor(tensor), DiagLazyTensor(diag))

    evals, evecs = torch.symeig(standard_lt.evaluate(), eigenvectors=True)

    # this preconditioner is a simple example of near deflation
    def nonstandard_preconditioner(self):
        top_100_evecs = evecs[:, :100]
        top_100_evals = evals[:100] + 0.2 * torch.randn(100)

        precond_lt = RootLazyTensor(top_100_evecs @ torch.diag(top_100_evals ** 0.5))
        logdet = top_100_evals.log().sum()

        def precond_closure(rhs):
            rhs2 = top_100_evecs.t() @ rhs
            # the closure's return statement was lost in this listing

        return precond_closure, precond_lt, logdet

    overrode_lt = AddedDiagLazyTensor(  # assignment target reconstructed from its use below
        RootLazyTensor(tensor), DiagLazyTensor(diag), preconditioner_override=nonstandard_preconditioner
    )

    # compute a solve - mostly to make sure that we can actually perform the solve
    rhs = torch.randn(1000, 1)
    standard_solve = standard_lt.inv_matmul(rhs)
    overrode_solve = overrode_lt.inv_matmul(rhs)

    # gut check that our preconditioner is not breaking anything
    self.assertEqual(standard_solve.shape, overrode_solve.shape)
    self.assertLess(torch.norm(standard_solve - overrode_solve) / standard_solve.norm(), 1.0)
Example #10
def forward(self, A):
    w, v = torch.symeig(A, eigenvectors=True)

    self.save_for_backward(w, v)
    return w, v
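The snippet above uses the legacy Function style (self.save_for_backward). For context, a minimal sketch of a complete version in the modern static-method style, using the standard gradient formula for a symmetric eigendecomposition (assumes distinct eigenvalues; this is not the source project's code):

class Symeig(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        w, v = torch.symeig(A, eigenvectors=True)
        ctx.save_for_backward(w, v)
        return w, v

    @staticmethod
    def backward(ctx, gw, gv):
        w, v = ctx.saved_tensors
        # F[j, i] = 1 / (w[i] - w[j]) off the diagonal, 0 on the diagonal
        F = w.unsqueeze(0) - w.unsqueeze(1)
        F = torch.where(F == 0, torch.zeros_like(F), F.reciprocal())
        gA = v @ (torch.diag(gw) + F * (v.t() @ gv)) @ v.t()
        return 0.5 * (gA + gA.t())   # symmetrize: the input is treated as symmetric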
Example #11
def compute_tests(genotypes_t, var_thresh=0.99, variant_window=200):
    """determine effective number of independent variants (M_eff)"""
    # lw_shrink and find_num_eigs are helpers defined elsewhere in the source project
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # break into windows
    windows = torch.split(genotypes_t, variant_window)

    if len(windows) > 1:
        shrunk_cov_t, shrinkage_t = lw_shrink(torch.transpose(torch.stack(windows[:-1]), 1, 2))

        n_samples, n_features = windows[0].T.shape
        # indices of diagonals
        ix = torch.LongTensor(np.array([np.arange(0, n_features**2, n_features + 1) + i * n_features**2
                                        for i in range(shrunk_cov_t.shape[0])])).to(device)
        shrunk_precision_t = torch.zeros(shrunk_cov_t.shape).to(device)
        shrunk_precision_t.view(-1)[ix] = shrunk_cov_t.view(-1)[ix].pow(-0.5)
        shrunk_cor_t = torch.matmul(torch.matmul(shrunk_precision_t, shrunk_cov_t), shrunk_precision_t)
        eigenvalues_t, _ = torch.symeig(shrunk_cor_t, eigenvectors=False)

    # last window
    shrunk_cov0_t, shrinkage0_t = lw_shrink(windows[-1].t())
    shrunk_precision0_t = torch.diag(torch.diag(shrunk_cov0_t).pow(-0.5))
    shrunk_cor0_t = torch.mm(torch.mm(shrunk_precision0_t, shrunk_cov0_t), shrunk_precision0_t)
    eigenvalues0_t, _ = torch.symeig(shrunk_cor0_t, eigenvectors=False)

    if len(windows) > 1:
        eigenvalues = list(eigenvalues_t.cpu().numpy())
        eigenvalues.append(eigenvalues0_t.cpu().numpy())
    else:
        eigenvalues = [eigenvalues0_t.cpu().numpy()]

    m_eff = 0
    for ev, m in zip(eigenvalues, [i.shape[0] for i in windows]):
        ev[ev < 0] = 0
        m_eff += find_num_eigs(ev, m, var_thresh=var_thresh)

    return m_eff
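The precision scaling above is the standard conversion of a covariance matrix to a correlation matrix, C_corr = D^{-1/2} C D^{-1/2} with D the diagonal of C; a self-contained sketch:

C = torch.randn(5, 5)
C = C @ C.t() + torch.eye(5)                # covariance-like SPD matrix
Dinv_sqrt = torch.diag(torch.diag(C).pow(-0.5))
C_corr = Dinv_sqrt @ C @ Dinv_sqrt          # unit diagonal
assert torch.allclose(torch.diag(C_corr), torch.ones(5), atol=1e-6)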
Example #12
def _compute_kfe(self, group, state):
    """Computes the covariances."""
    mod = group['mod']
    x = self.state[group['mod']]['x']
    gy = self.state[group['mod']]['gy']
    # Computation of xxt
    if group['layer_type'] == 'Conv2d':
        if not self.sua:
            x = F.conv2d(x, group['gathering_filter'],
                         groups=mod.in_channels)
        x = x.data.permute(1, 0, 2, 3).contiguous().view(x.shape[1], -1)
    else:
        x = x.data.t()
    if mod.bias is not None:
        ones = torch.ones_like(x[:1])
        x = torch.cat([x, ones], dim=0)
    xxt = torch.mm(x, x.t()) / float(x.shape[1])
    Ex, state['kfe_x'] = torch.symeig(xxt, eigenvectors=True)
    # Computation of ggt
    if group['layer_type'] == 'Conv2d':
        gy = gy.data.permute(1, 0, 2, 3)
        state['num_locations'] = gy.shape[2] * gy.shape[3]
        gy = gy.contiguous().view(gy.shape[0], -1)
    else:
        gy = gy.data.t()
        state['num_locations'] = 1
    ggt = torch.mm(gy, gy.t()) / float(gy.shape[1])
    Eg, state['kfe_gy'] = torch.symeig(ggt, eigenvectors=True)
    state['m2'] = Eg.unsqueeze(1) * Ex.unsqueeze(0) * state['num_locations']
    if group['layer_type'] == 'Conv2d' and self.sua:
        # `ws` (the conv weight shape) is defined elsewhere in the original source
        state['m2'] = state['m2'].view(Eg.size(0), Ex.size(0), 1, 1).expand(-1, -1, ws[2], ws[3])
Example #13
def __init__(
    self,
    mean: Tensor,
    cov: Tensor,
    seed: Optional[int] = None,
    inv_transform: bool = False,
) -> None:
    r"""Engine for qMC sampling from a multivariate Normal N(\mu, \Sigma).

    Args:
        mean: The mean vector.
        cov: The covariance matrix.
        seed: The seed with which to seed the random number generator of the
            underlying SobolEngine.
        inv_transform: If True, use inverse transform instead of Box-Muller.
    """
    # validate inputs
    if not cov.shape[0] == cov.shape[1]:
        raise ValueError("Covariance matrix is not square.")
    if not mean.shape[0] == cov.shape[0]:
        raise ValueError("Dimension mismatch between mean and covariance.")
    if not torch.allclose(cov, cov.transpose(-1, -2)):
        raise ValueError("Covariance matrix is not symmetric.")
    self._mean = mean
    self._normal_engine = NormalQMCEngine(
        d=mean.shape[0], seed=seed, inv_transform=inv_transform
    )
    # compute Cholesky decomp; if it fails, do the eigendecomposition
    try:
        self._corr_matrix = torch.cholesky(cov).transpose(-1, -2)
    except RuntimeError:
        eigval, eigvec = torch.symeig(cov, eigenvectors=True)
        if not torch.all(eigval >= -1e-8):
            raise ValueError("Covariance matrix not PSD.")
        eigval_root = eigval.clamp_min(0.0).sqrt()
        self._corr_matrix = (eigvec * eigval_root).transpose(-1, -2)
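This constructor belongs to BoTorch's MultivariateNormalQMCEngine; a usage sketch, assuming its standard draw API and import path:

import torch
from botorch.sampling.qmc import MultivariateNormalQMCEngine

mean = torch.zeros(3)
cov = torch.tensor([[1.0, 0.5, 0.0],
                    [0.5, 2.0, 0.3],
                    [0.0, 0.3, 1.5]])
engine = MultivariateNormalQMCEngine(mean=mean, cov=cov, seed=0)
samples = engine.draw(256)          # 256 x 3 quasi-random normal draws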
Example #14
def step(self):
    if self.weight_decay > 0:
        for p in self.model.parameters():
            # Loop body lost in this listing; K-FAC implementations usually
            # fold weight decay into the gradient here, e.g.:
            p.grad.data.add_(p.data, alpha=self.weight_decay)

    for i, m in enumerate(self.modules):
        assert len(list(m.parameters())) == 1, \
            "Can handle only one parameter at the moment"
        classname = m.__class__.__name__
        p = next(m.parameters())

        la = self.damping + self.weight_decay

        if self.steps % self.Tf == 0:
            # My asynchronous implementation exists, I will add it later.
            # Experimenting with different ways to do this in PyTorch.
            self.d_a[m], self.Q_a[m] = torch.symeig(
                self.m_aa[m], eigenvectors=True)
            self.d_g[m], self.Q_g[m] = torch.symeig(
                self.m_gg[m], eigenvectors=True)

            self.d_a[m].mul_((self.d_a[m] > 1e-6).float())
            self.d_g[m].mul_((self.d_g[m] > 1e-6).float())

        if classname == 'Conv2d':
            # Branch bodies lost in this listing; the usual K-FAC reshaping
            # flattens the conv weight gradient into a matrix:
            p_grad_mat = p.grad.data.view(p.grad.data.size(0), -1)
        else:
            p_grad_mat = p.grad.data

        v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
        v2 = v1 / (
            self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + la)
        v = self.Q_g[m] @ v2 @ self.Q_a[m].t()

        vg_sum = 0
        for p in self.model.parameters():
            vg_sum += (v * p.grad.data * self.lr * self.lr).sum()

        nu = min(1, math.sqrt(self.kl_clip / vg_sum))

        for p in self.model.parameters():
            # Loop body lost in this listing; KL clipping normally rescales
            # the gradient by nu, e.g.:
            p.grad.data.mul_(nu)

    self.optim.step()
    self.steps += 1
Example #15
def GDPPLoss(phiFake, phiReal, backward=True):
    r"""
    Implementation of the GDPP loss. Can be used with any kind of GAN
    architecture.

    Args:

        phiFake (tensor) : last feature layer of the discriminator on fake data
        phiReal (tensor) : last feature layer of the discriminator on real data
        backward (bool)  : should we perform the backward operation?

    Returns:

        The loss's value. The backward operation is performed within this
        operator.
    """
    def compute_diversity(phi):
        phi = F.normalize(phi, p=2, dim=1)
        SB = torch.mm(phi, phi.t())
        eigVals, eigVecs = torch.symeig(SB, eigenvectors=True)
        return eigVals, eigVecs

    def normalize_min_max(eigVals):
        minV, maxV = torch.min(eigVals), torch.max(eigVals)
        if abs(minV - maxV) < 1e-10:
            return eigVals
        return (eigVals - minV) / (maxV - minV)

    fakeEigVals, fakeEigVecs = compute_diversity(phiFake)
    realEigVals, realEigVecs = compute_diversity(phiReal)

    # Scaling factor to make the two losses operate in comparable ranges.
    magnitudeLoss = 0.0001 * F.mse_loss(target=realEigVals, input=fakeEigVals)
    structureLoss = -torch.sum(torch.mul(fakeEigVecs, realEigVecs), 0)
    normalizedRealEigVals = normalize_min_max(realEigVals)
    weightedStructureLoss = torch.sum(
        torch.mul(normalizedRealEigVals, structureLoss))
    gdppLoss = magnitudeLoss + weightedStructureLoss

    if backward:
        gdppLoss.backward(retain_graph=True)

    return gdppLoss.item()
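A usage sketch with random activations standing in for discriminator features (illustrative only):

phiFake = torch.randn(64, 128, requires_grad=True)
phiReal = torch.randn(64, 128)
loss_value = GDPPLoss(phiFake, phiReal, backward=True)
# phiFake.grad now holds the gradient of the GDPP loss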
Example #16
def finalize(self):
    """
    Finalize training with LU factorization or Pseudo-inverse
    """
    # Reshape average
    xTx, avg, tlen = self._fix(self.xTx, self.xTx_avg, self.tlen)

    # Reshape
    self.avg = avg.unsqueeze(0)

    # We need more observations than variables
    if self.tlen < self.input_dim:
        raise Exception(u"The number of observations ({}) is smaller than the number of input variables ({})".format(self.tlen, self.input_dim))
    # end if

    # Total variance
    total_var = torch.diag(xTx).sum()

    # Compute and sort eigenvalues
    d, v = torch.symeig(xTx, eigenvectors=True)

    # Check for negative eigenvalues
    if float(d.min()) < 0:
        # raise Exception(u"Got negative eigenvalues ({}). You may either set output_dim to be smaller".format(d))
        pass
    # end if

    # Indexes
    indexes = list(range(d.size(0) - 1, -1, -1))

    # Sort by descending order
    d = torch.take(d, Variable(torch.LongTensor(indexes)))
    v = v[:, indexes]

    # Explained covariance
    self.explained_variance = torch.sum(d) / total_var

    # Store eigenvalues
    self.d = d[:self.output_dim]

    # Store eigenvectors
    self.v = v[:, :self.output_dim]

    # Total variance
    self.total_variance = total_var

    # Stop training
    self.train(False)
# end finalize
