Tensorflow将3D批量张量与2D权重相乘

时间:2018-02-28 02:22:43

标签: tensorflow matrix-multiplication broadcast

我有两个张力,形状如下所示,

batch.shape = [?, 5, 4]
weight.shape = [3, 5]

通过将权重乘以批处理中的每个元素,我想得到

result.shape = [?, 3, 4]

实现这一目标的最有效方法是什么?

2 个答案:

答案 0 :(得分:1)

试试这个:

newbatch = tf.transpose(batch,[1,0,2])
newbatch = tf.reshape(newbatch,[5,-1])
result = tf.matmul(weight,newbatch)
result = tf.reshape(result,[3,-1,4])
result = tf.transpose(result, [1,0,2])

或更紧凑:

newbatch = tf.reshape(tf.transpose(batch,[1,0,2]),[5,-1])
result = tf.transpose(tf.reshape(tf.matmul(weight,newbatch),[3,-1,4]), [1,0,2])

答案 1 :(得分:0)

尝试一下:

tf.einsum("ijk,aj-> iak",batch,weight)

任意维度Refer this for more information的张量之间的广义收缩