Tensorflow多标量乘法

时间:2018-02-24 14:46:06

标签: python tensorflow

我有一个带有[batch_size,x,y]的3d张量和一个向量[batch_size]

我希望标量乘以第i个矩阵[x,y]和给定向量的第i个条目。

Tensorflow中是否有内置功能,还是必须使用tf.while_loop

2 个答案:

答案 0 :(得分:2)

你可以通过广播来做到这一点。你需要先重塑矢量。

a = tf.constant([[[1,1],[2,2]],[[3,3],[4,4]]])
b = tf.constant([2,3])
c = tf.reshape(b, [-1,1,1])
d = a * c

>>> sess.run(d)
  array([[[ 2,  2],
    [ 4,  4]],

   [[ 9,  9],
    [12, 12]]], dtype=int32)

答案 1 :(得分:0)

如果有内置函数我不会,但你也不需要使用while循环。你可以做基本的数组操作。 e.g:

a=tf.random_uniform([3,5,8])
b=tf.random_uniform([3])
c=tf.expand_dims(tf.expand_dims(b, -1),1)
c=tf.tile(c,[1,5,8])
d=tf.multiply(a,c)
sess=tf.Session()
sess.run([a,b,c,d])

它应该有用。