Tensorflow:撤消全局平均池

时间:2018-11-07 13:48:06

标签: python tensorflow

在Tensorflow中,我在网络末尾进行以下全局平均池化:

x_ = tf.reduce_mean(x, axis=[1,2])

我的张量x的形状为(n, h, w, c),其中n是输入的数量,wh对应于宽度和高度尺寸,并且c是通道/过滤器的数量。

在调用x之后,从大小为(n, h, w, c)的张量tf.reduce_mean()开始,得到的张量大小为(n, c)

如何撤销该过程?我该如何进行分池操作?

编辑

这是一个无法正常运行的示例:

import tensorflow as tf
import numpy as np

n, c = 1, 2 
h, w = 2, 2

x = tf.ones([n, h, w, c])
y = tf.reduce_mean(x, axis=[1,2], keepdims=True)
z = tf.reshape(y, [n, 1, 1, c])
u = tf.tile(z, [n, h, w, c])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(x)
    print("x", sess.run(x))
    print("\n")
    print(y)
    print("y", sess.run(y))
    print("\n")
    print(z)
    print("z", sess.run(z))
    print("\n")
    print(u)
    print("u", sess.run(u))

输出为:

Tensor("ones:0", shape=(1, 2, 2, 2), dtype=float32)
x [[[[1. 1.]
   [1. 1.]]

  [[1. 1.]
   [1. 1.]]]]


Tensor("Mean:0", shape=(1, 1, 1, 2), dtype=float32)
y [[[[1. 1.]]]]


Tensor("Reshape:0", shape=(1, 1, 1, 2), dtype=float32)
z [[[[1. 1.]]]]


Tensor("Tile:0", shape=(1, 2, 2, 4), dtype=float32)
u [[[[1. 1. 1. 1.]
   [1. 1. 1. 1.]]

  [[1. 1. 1. 1.]
   [1. 1. 1. 1.]]]]

1 个答案:

答案 0 :(得分:-1)

您可以使用tf.reshapetf.tile进行分拆操作。

x = tf.random_uniform([n, c])
y = tf.reshape(x, [n, 1, 1, c])
z = tf.tile(y, [1, h, w, 1])

在调用x(n, c)之后,从大小为tf.reshape的张量tf.tile开始,结果张量z的大小为(n, h, w, c)。 / p>