(tensorflow)批量乘以特征图及其特征均值:[n,h,w,c] * [n,c]

时间:2018-08-17 17:34:36

标签: tensorflow

我如何将形状为[n,h,w,c]的特征图张量与其特征均值[n,c]相乘,而不用分批叠加[n,c]广播?

我正在寻找最快的方法。

不幸的是,tf.multiply不起作用。我有两种方法:(1)平铺[n,c]张量,(2)使用tf.einsum。但是我非常担心使用它们,因为如果我循环使用它们,它们可能会变得很慢。

-----------我刚刚尝试了以下---------

x = ... # [n,h,w,c]
mean = ... # [n,c]
out = x* tf.expand_dims(tf.expand_dims(mean,1),1)

它奏效了。这是正确的解决方案吗?

0 个答案:

没有答案