使用tf.reshape替换tf.squeeze

时间:2017-12-05 16:50:23

标签: tensorflow reshape

我需要使用tensorflow训练移动网络。不支持tf.squeeze图层。我可以用tf.reshape替换它吗?

是操作:

tf.squeeze(net, [1, 2], name='squeeze')

与:

相同
tf.reshape(net, [50,1000], name='reshape')

其中net的形状为[50,1,1,1000]。

2 个答案:

答案 0 :(得分:1)

为什么说不支持tf.squeeze?为了从张量中移除1维轴,tf.squeeze是正确的操作。但您也可以使用tf.reshape完成所需的工作,但我建议您使用tf.squeeze

答案 1 :(得分:0)

tf 2.0中,您可以轻松地检查这些操作是否相同。唯一的区别是您可以使用dim == 1删除所有轴而不指定它们。因此,在最后一行中,您可以使用tf.squeeze(x_resh)代替tf.squeeze(x_resh, [1, 2])

size = [2, 3]
tf.random.set_seed(42)
x = tf.random.normal(size)
x
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258,  0.3194337],
       [-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>

x_resh = tf.reshape(x, [2, 1, 1, 3])
x_resh
<tf.Tensor: shape=(2, 1, 1, 3), dtype=float32, numpy=
array([[[[ 0.3274685, -0.8426258,  0.3194337]]],
       [[[-1.4075519, -2.3880599, -1.0392479]]]], dtype=float32)>

tf.reshape(x_resh, [2, 3])
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258,  0.3194337],
       [-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>

tf.squeeze(x_resh, [1, 2])
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258,  0.3194337],
       [-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>