批量内的Tensorflow乘法广播

时间:2017-01-11 13:41:47

标签: tensorflow

我们知道tf.multiply可以像这样广播:

import tensorflow as tf
import numpy as np
a = tf.Variable(np.arange(12).reshape(3, 4))
b = tf.Variable(np.arange(4))
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(tf.multiply(a, b))

这会给我们

[[0, 1, 4, 9],
 [0, 5, 12, 21],
 [0, 9, 20, 33]]

但我的问题是,如果ab分批进行,我该怎么办?也就是说,

a = tf.Variable(np.arange(24).reshape(2, 3, 4))
b = tf.Variable(np.arange(8).reshape(2, 4))

那么如何在每批中将矢量乘以(广播)到矩阵上得到结果呢?如下面的答案:

[[[0, 1, 4, 9],
  [0, 5, 12, 21],
  [0, 9, 20, 33]],

 [[48, 65, 84, 105],
  [64, 85, 108, 133],
  [80, 105, 132, 161]]]

谢谢!

1 个答案:

答案 0 :(得分:1)

广播首先在左侧添加单身尺寸,直到匹配等级。在第一种情况下,添加批量维度。但在第二种情况下,您已经拥有批量维度,因此您需要在第二个位置手动插入单个维度:

a = tf.reshape(tf.range(24), (2, 3, 4))
b = tf.reshape(tf.range(8), (2, 4))
sess.run(tf.mul(a, tf.expand_dims(b, 1)))