乘以具有1-d张量的4-d张量

时间:2018-03-10 06:25:19

标签: python tensorflow matrix-multiplication

我的维度为[batch_size, num_rows, num_cols, num_values]的等级4张量和维度为[num_values]的等级1张量。我想计算第四列中的值和我的等级1 Tensor的点积,得到维度为[batch_size, num_rows, num_cols, 1]的等级4 Tensor,然后我可以tf.squeeze到具有维度的Tensor { {1}}。有谁知道我怎么做到这一点?

1 个答案:

答案 0 :(得分:1)

您可以使用tensordotreduce_sum

a = tf.constant(np.random.rand(2, 3, 5, 7))
b = tf.constant(np.random.rand(7))

tf.tensordot(a, b, [-1, -1])  # <tf.Tensor 'Tensordot_1:0' shape=(2, 3, 5) dtype=float64>

tf.reduce_sum(a * b, axis=-1)  # <tf.Tensor 'Tensordot_1:0' shape=(2, 3, 5) dtype=float64>