Tensor A is [M X 2N X N]
Tensor B is [M X N]
我需要将两个张量相乘得到:
Tensor C [M X 2N X N].
以下是一个例子:
M= 2, N = 2
A: [[[1,2]
[1,2]
[1,2]
[1,2] ]
[[2,2]
[2,2]
[2,2]
[2,2] ]]
B = [[3,3]
[2,2]]
C: [[[3,6]
[3,6]
[3,6]
[3,6] ]
[[4,4]
[4,4]
[4,4]
[4,4] ]]
不确定如何实现这一目标。除了解决方案,有人可以解释广播在这种情况下的确切运作方式。
答案 0 :(得分:0)
看起来你想要元素乘法,其中张量B在该维度中重复2N次。我使用tf.concat手动播放广播(2,[tf.reshape(b,[m,1,n])for _ in xrange(2 * n)]。