使用numpy / pytorch广播计算矩阵乘积的轨迹

时间:2018-04-24 01:33:54

标签: python pytorch numpy-broadcasting

设A是(nxm) - 矩阵和M an(mxm) - 矩阵。编写tr()用于跟踪矩阵,我需要计算tr(AM(A ^ T))。但是,最终的跟踪操作会丢弃大部分计算。我可以使用numpy's或pytorch的广播规则来仅计算AM(A ^ T)的必要对角线吗?

更新 这是我在PyTorch中计算对角线的解决方案:

torch.sum(torch.sum(A.t()[:,None,:]*M[:,:,None],0)*A.t(),0)

1 个答案:

答案 0 :(得分:0)

您必须至少计算两种矩阵产品中的一种。随后您可以使用以下答案之一:What is the best way to compute the trace of a matrix product in numpy?