如何在Tensorflow中计算矩阵乘积的对角线?

时间:2017-08-25 14:44:03

标签: python matrix tensorflow diagonal

我有两个AB形状(M, N)的矩阵,其MN非常大。

我想将它们相乘,然后取结果的对角线:

C = tf.matmul(A, B)
D = tf.diag_part(C)

不幸的是,这需要创建非常大的(M, M)矩阵,这种矩阵无法适应内存。

但我不需要这些数据。那么,是否可以一步计算这个值?

是否有类似einsum但没有求和的东西?

2 个答案:

答案 0 :(得分:3)

您需要的是:

tf.einsum('ij,ij->i', A, B)

或:

tf.reduce_sum(A * B, axis=1)

实施例

A = tf.constant([[1,2],[2,3],[3,4]])
B = tf.constant([[3,4],[1,2],[2,3]])

with tf.Session() as sess:
    print(sess.run(tf.diag_part(tf.matmul(A, B, transpose_b=True)))) 
# [11  8 18]

with tf.Session() as sess:
    print(sess.run(tf.reduce_sum(A * B, axis=1)))
#[11  8 18]

with tf.Session() as sess:
    print(sess.run(tf.einsum('ij,ij->i', A, B)))
#[11  8 18]

答案 1 :(得分:1)

您可以使用dot product AB transpose来获取相同内容:

tf.reduce_sum(tf.multiply(A, tf.transpose(B)), axis=1)

代码:

import tensorflow as tf
import numpy as np

A = tf.constant([[1,4, 3], [4, 2, 6]])
B = tf.constant([[5,4,],[8,5], [7, 3]])

E = tf.reduce_sum(tf.multiply(A, tf.transpose(B)), axis=1)

C = tf.matmul(A, B)
D = tf.diag_part(C)
sess = tf.InteractiveSession()

print(sess.run(D))
print(sess.run(E))

#Output
#[58 44]
#[58 44]