使用Tensorflow的批量矩阵乘法

时间:2016-09-07 16:43:05

标签: numpy tensorflow

Tensor A is [M X 2N X N]
Tensor B is [M X N]

我需要将两个张量相乘得到:

Tensor C [M X 2N X N]. 

以下是一个例子:

M= 2, N = 2

A: [[[1,2]
     [1,2]
     [1,2]
     [1,2] ]

    [[2,2]
     [2,2]
     [2,2]
     [2,2] ]]

 B = [[3,3]
      [2,2]]

 C: [[[3,6]
      [3,6]
      [3,6]
      [3,6] ]

      [[4,4]
       [4,4]
       [4,4]
       [4,4] ]]

不确定如何实现这一目标。除了解决方案,有人可以解释广播在这种情况下的确切运作方式。

1 个答案:

答案 0 :(得分:0)

看起来你想要元素乘法,其中张量B在该维度中重复2N次。我使用tf.concat手动播放广播(2,[tf.reshape(b,[m,1,n])for _ in xrange(2 * n)]。