设A是(nxm) - 矩阵和M an(mxm) - 矩阵。编写tr()用于跟踪矩阵,我需要计算tr(AM(A ^ T))。但是,最终的跟踪操作会丢弃大部分计算。我可以使用numpy's或pytorch的广播规则来仅计算AM(A ^ T)的必要对角线吗?
更新 这是我在PyTorch中计算对角线的解决方案:
torch.sum(torch.sum(A.t()[:,None,:]*M[:,:,None],0)*A.t(),0)
答案 0 :(得分:0)
您必须至少计算两种矩阵产品中的一种。随后您可以使用以下答案之一:What is the best way to compute the trace of a matrix product in numpy?