扩展维度并在tensorflow中复制数据

时间:2018-06-12 01:34:29

标签: python tensorflow

我有一个大小为BxHxWx3的张量input和另一个大小为Bx3的张量params。这是B是批量大小。我想将params变成一个大小为BxHxWx3的张量?这样我就可以增加两个张量。关于我应该怎么做的任何建议? (在高层次上,我想要做的是将一组图像中的每个像素乘以为每个通道定义的值)

2 个答案:

答案 0 :(得分:1)

<强> 1。回答您的第一个问题

您可以使用tf.expand_dimstf.tile

的组合
input_shape = tf.shape(input)
mod_params = params.expand_dims(1) # shape is [Bx1x3]
mod_params = mod_params.expand_dims(2) # shape is [Bx1x1x3]
mod_params = tf.tile( \
                mod_params, \
                [1, input_shape[1], input_shape[2], 1] \
             ) # shape is [BxHxWx3]

<强> 2。要执行最终结果,......

......你可以执行

ret = tf.multiply(input, mod_params)

...或者,你也可以使用张量流的广播能力(在tf.transpose的帮助下)

ret = tf.multiply(
         tf.transpose(input, perm=[2,1,0,3]), \
         params \
      ) # shape: [WxHxBx3]
ret = tf.transpose(ret, perm=[2,1,0,3]) # shape: [BxHxWx3]

答案 1 :(得分:0)

不要平铺,使用broadcastingT*tf.reshape(params, [-1,1, 1, tf.shape(n_params)[1]])

T = tf.random_normal((5,2,3,3))
params = tf.random_normal((5,3))

out = T*tf.reshape(params, [-1,1, 1, tf.shape(n_params)[1]])

with tf.Session() as sess:
   print(sess.run(out).shape)
   #(5, 2, 3, 3)