tensorflow:将矩阵的某些行与另一列中的某些列相乘

时间:2017-08-18 05:29:12

标签: python tensorflow matrix-multiplication

假设我有一个矩阵A和一个矩阵B。我知道tf.matmul(A,B)可以计算两个矩阵的乘法。但我的任务只需要将A的某些行与某些B列相乘。

例如,我有ALs_A=[0,1,2]的行ID列表,以及BLs_B=[4,2,6]的列ID列表。我想要一个列表的结果,表示为Ls,这样:

Ls[0] = A[0,:] * B[:,4]
Ls[1] = A[1,:] * B[:,2]
Ls[2] = A[2,:] * B[:,6]

我怎样才能做到这一点?

谢谢大家的帮助!

1 个答案:

答案 0 :(得分:1)

您可以使用tf.gather执行以下操作:

import tensorflow as tf
a=tf.constant([[1,2,3],[4,5,6],[7,8,9]])
b=tf.constant([[1,0,1],[1,0,2],[3,3,-1]])

#taking rows 0,1 from a, and columns 0,2 from b
ind_a=tf.constant([0,1])
ind_b=tf.constant([0,2])

r_a=tf.gather(a,ind_a)

#tf.gather access the rows, so we use it together with tf.transpose to access the columns
r_b=tf.transpose(tf.gather(tf.transpose(b),ind_b))

# the diagonal elements of the multiplication
res=tf.diag_part(tf.matmul(r_a,r_b))
sess=tf.InteractiveSession()
print(r_a.eval())
print(r_b.eval())
print(res.eval())

打印

#r_a
[[1 2 3]
 [4 5 6]]

#r_b
[[ 1  1]
 [ 1  2]
 [ 3 -1]]

#result
[12  8]