'''
PyTorch has its own implementation of the backward function for the symmetric
eigensolver:
https://github.com/pytorch/pytorch/blob/291746f11047361100102577ce7d1cfa1833be50/tools/autograd/templates/Functions.cpp#L1660
However, it assumes a triangular adjoint. We reimplement it to return a
symmetric adjoint.
'''
import numpy as np
import torch


def _symeig(A):
    """Return ``(w, v)``: ascending eigenvalues and eigenvectors of symmetric ``A``.

    Uses ``torch.linalg.eigh`` (reading the upper triangle, matching the old
    ``torch.symeig`` default), and falls back to ``torch.symeig`` only on
    PyTorch builds that predate ``torch.linalg.eigh``.
    """
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "eigh"):
        return linalg.eigh(A, UPLO="U")
    return torch.symeig(A, eigenvectors=True)


class EigenSolver(torch.autograd.Function):
    """Symmetric eigendecomposition whose backward returns a symmetric gradient.

    Usage: ``w, v = EigenSolver.apply(A)`` where ``A`` is a symmetric matrix
    (or a batch of them, leading dimensions are broadcast through).
    """

    @staticmethod
    def forward(ctx, A):
        """Decompose ``A`` as ``v @ diag(w) @ v^T``.

        Returns:
            w: eigenvalues in ascending order, shape ``(..., N)``.
            v: eigenvectors as columns, shape ``(..., N, N)``.
        """
        w, v = _symeig(A)
        ctx.save_for_backward(w, v)
        return w, v

    @staticmethod
    def backward(ctx, dw, dv):
        """Return the symmetric adjoint of ``A`` given adjoints of ``w`` and ``v``.

        Computes ``v @ (diag(dw) + F * (v^T dv - dv^T v) / 2) @ v^T`` with
        ``F[i, j] = 1 / (w_j - w_i)`` off the diagonal and ``0`` on it.
        """
        w, v = ctx.saved_tensors

        # F[..., i, j] = w_j - w_i
        F = w.unsqueeze(-2) - w.unsqueeze(-1)
        # +inf on the diagonal makes 1/F exactly zero there.
        F.diagonal(dim1=-2, dim2=-1).fill_(np.inf)
        # Safe inverse: nudge (near-)degenerate eigenvalue gaps away from zero
        # so the reciprocal stays finite.
        msk = torch.abs(F) < 1e-20
        F[msk] += 1e-20
        F = 1.0 / F

        vt = v.transpose(-2, -1)
        vdv = vt @ dv
        # F is antisymmetric and (vdv - vdv^T) is antisymmetric, so their
        # elementwise product, and hence the returned gradient, is symmetric.
        inner = torch.diag_embed(dw) + F * (vdv - vdv.transpose(-2, -1)) / 2
        return v @ inner @ vt