Tensorflow:块对角矩阵的运算

时间:2019-01-09 19:47:02

标签: performance tensorflow memory-efficient

我正在寻找一种在Tensorflow中实现块对角矩阵的方法。具体来说,我有块对角矩阵A,每个块都有N个大小为S x S的块。此外,我有一个长度为N * S的向量v。我想计算一个点v。在Tensorflow中有有效的方法吗?

此外,我更喜欢支持v的批处理维(例如,其实际维为batch_size x(N * S))并且实现内存高效的实现,仅将A的块对角部分保留在内存中。

感谢您的帮助!

1 个答案:

答案 0 :(得分:0)

您可以简单地将张量转换为sparse tensor,因为块对角矩阵只是其中的特例。然后,以有效的方式完成操作。如果您已经有了张量的密集表示,则可以使用sparse_tensor = tf.contrib.layers.dense_to_sparse(dense_tensor)进行转换。否则,您可以使用tf.SparseTensor(...)函数来构造它。要获取索引,您可以使用tf.strided_slice,有关更多信息,请参见this post