import torch import numpy as np def mpjpe(predicted, target): """ Mean per-joint position error (i.e. mean Euclidean distance), often referred to as "Protocol #1" in many papers. """ assert predicted.shape == target.shape return torch.mean(torch.norm(predicted - target, dim=len(target.shape)-1)) def weighted_mpjpe(predicted, target, w): """ Weighted mean per-joint position error (i.e. mean Euclidean distance) """ assert predicted.shape == target.shape assert w.shape[0] == predicted.shape[0] return torch.mean(w * torch.norm(predicted - target, dim=len(target.shape)-1)) def p_mpjpe_torch(predicted, target, with_sRt=False,full_torch=False,with_aligned=False): """ Pose error: MPJPE after rigid alignment (scale, rotation, and translation), often referred to as "Protocol #2" in many papers. """ assert predicted.shape == target.shape muX = torch.mean(target, dim=1, keepdim=True) muY = torch.mean(predicted, dim=1, keepdim=True) #print(predicted, target) X0 = target - muX Y0 = predicted - muY X0[X0**2<1e-6]=1e-3 normX = torch.sqrt(torch.sum(X0**2, dim=(1, 2), keepdim=True)) normY = torch.sqrt(torch.sum(Y0**2, dim=(1, 2), keepdim=True)) normX[normX<1e-3]=1e-3 X0 /= normX Y0 /= normY H = torch.matmul(X0.transpose(1,2), Y0) if full_torch: U, s, V = batch_svd(H) else: U, s, Vt = np.linalg.svd(H.cpu().numpy()) V = torch.from_numpy(Vt.transpose(0, 2, 1)).cuda() U = torch.from_numpy(U).cuda() s = torch.from_numpy(s).cuda() #U, s, V = U.unsqueeze(0), s.unsqueeze(0), V.unsqueeze(0) #V = Vt.transpose(2, 1) R = torch.matmul(V, U.transpose(2, 1)) # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1 sign_detR = torch.sign(torch.unsqueeze(torch.det(R[0]), 0)) V[:, :, -1] *= sign_detR.unsqueeze(0) s[:, -1] *= sign_detR.flatten() R = torch.matmul(V, U.transpose(2, 1)) # Rotation tr = torch.unsqueeze(torch.sum(s, dim=1, keepdim=True), 2) a = tr * normX / normY # Scale t = muX - a*torch.matmul(muY, R) # Translation if (a!=a).sum()>0: print('NaN Error!!') print('UsV:',U,s,V) print('aRt:',a,R,t) a[a!=a]=1. R[R!=R]=0. t[t!=t]=0. # Perform rigid transformation on the input predicted_aligned = a*torch.matmul(predicted, R) + t if with_sRt: return torch.sqrt(((predicted_aligned - target)**2).sum(-1)).mean(),(a,R,t)#torch.mean(torch.norm(predicted_aligned - target, dim=len(target.shape)-1)) if with_aligned: return torch.sqrt(((predicted_aligned - target)**2).sum(-1)).mean(),predicted_aligned # Return MPJPE return torch.sqrt(((predicted_aligned - target)**2).sum(-1)).mean()#torch.mean(torch.norm(predicted_aligned - target, dim=len(target.shape)-1))#,(a,R,t),predicted_aligned def batch_svd(H): num = H.shape[0] U_batch, s_batch, V_batch = [],[],[] for i in range(num): U, s, V = H[i].svd(some=False) U_batch.append(U.unsqueeze(0)) s_batch.append(s.unsqueeze(0)) V_batch.append(V.unsqueeze(0)) return torch.cat(U_batch,0),torch.cat(s_batch,0),torch.cat(V_batch,0) def p_mpjpe(predicted, target, with_sRt=False,full_torch=False,with_aligned=False,each_separate=False): """ Pose error: MPJPE after rigid alignment (scale, rotation, and translation), often referred to as "Protocol #2" in many papers. """ assert predicted.shape == target.shape muX = np.mean(target, axis=1, keepdims=True) muY = np.mean(predicted, axis=1, keepdims=True) X0 = target - muX Y0 = predicted - muY ''' if (X0**2<1e-10).sum()>0 or (X0**2>1e10).sum()>0: print('Error !') print(X0[X0**2<1e-10],X0[X0**2>1e10]) print(predicted[X0**2<1e-10],predicted[X0**2>1e10]) return 1.,(np.ones(3),np.ones((3,3)),np.ones(3)) ''' normX = np.sqrt(np.sum(X0**2, axis=(1, 2), keepdims=True)) normY = np.sqrt(np.sum(Y0**2, axis=(1, 2), keepdims=True)) X0 /= (normX+1e-6) Y0 /= (normY+1e-6) H = np.matmul(X0.transpose(0, 2, 1), Y0).astype(np.float16).astype(np.float64) U, s, Vt = np.linalg.svd(H) V = Vt.transpose(0, 2, 1) R = np.matmul(V, U.transpose(0, 2, 1)) # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1 sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1)) V[:, :, -1] *= sign_detR s[:, -1] *= sign_detR.flatten() R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2) a = tr * normX / normY # Scale t = muX - a*np.matmul(muY, R) # Translation # Perform rigid transformation on the input predicted_aligned = a*np.matmul(predicted, R) + t if with_sRt and not with_aligned: return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1)),(a,R,t) if with_aligned: return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1)),(a,R,t),predicted_aligned if each_separate: return np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1) # Return MPJPE return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1)) def n_mpjpe(predicted, target): """ Normalized MPJPE (scale only), adapted from: https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py """ assert predicted.shape == target.shape norm_predicted = torch.mean(torch.sum(predicted**2, dim=3, keepdim=True), dim=2, keepdim=True) norm_target = torch.mean(torch.sum(target*predicted, dim=3, keepdim=True), dim=2, keepdim=True) scale = norm_target / norm_predicted return mpjpe(scale * predicted, target) def mean_velocity_error(predicted, target): """ Mean per-joint velocity error (i.e. mean Euclidean distance of the 1st derivative) """ assert predicted.shape == target.shape velocity_predicted = np.diff(predicted, axis=0) velocity_target = np.diff(target, axis=0) return np.mean(np.linalg.norm(velocity_predicted - velocity_target, axis=len(target.shape)-1)) def test(): for i in range(100): r1 = np.random.rand(3,14,3) r2 = np.random.rand(3,14,3) pmpjpe = p_mpjpe(r1, r2,with_sRt=False) pmpjpe_torch = p_mpjpe_torch(torch.from_numpy(r1), torch.from_numpy(r2),with_sRt=False,full_torch=True) print('pmpjpe: {}; {:.6f}; {:.6f}; {:.6f}'.format(pmpjpe==pmpjpe_torch.numpy(),pmpjpe,pmpjpe_torch.numpy(), pmpjpe-pmpjpe_torch.numpy())) ''' pmpjpe,(s,R,t),(H,U, s, Vt) = p_mpjpe(r1, r2,with_sRt=True) pmpjpe_torch,(s_torch,R_torch,t_torch),(H_torch,U_torch, s_torch, Vt_torch) = p_mpjpe_torch(torch.from_numpy(r1), torch.from_numpy(r2),with_sRt=True,full_torch=True) print('s:',s==s_torch.numpy(),s,s_torch.numpy()) print('R:',R==R_torch.numpy(),R,R_torch.numpy()) print('t:',t==t_torch.numpy(),t,t_torch.numpy()) print(H) print(H_torch) print(U) print(U_torch) print(Vt) print(Vt_torch) print(s) print(s_torch) ''' if __name__ == '__main__': test()