为什么tf.layers.Flatten不会改变形状?

时间:2019-04-03 14:19:58

标签: tensorflow

我试图了解tf.layers.Flatten的工作原理。 运行会话后,形状仍然保持(5,4)。为什么不改变?

import numpy as np
import tensorflow as tf

X = tf.constant(np.array([[1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15], [4, 8, 12, 16], [17, 18, 19, 20]]))
flatten_layer = tf.layers.Flatten()
Y = flatten_layer(X)

sess = tf.Session()
result = sess.run(Y)

print(result.shape)

1 个答案:

答案 0 :(得分:0)

因为这是一个图层类。它希望其中一个尺寸(默认情况下是第一个尺寸)为批量大小。因此它不适用于形状(5,4)。

X = tf.constant(np.random.rand(1,5,4))
flatten_layer = tf.layers.Flatten()
Y = tf.layers.flatten(X)

sess = tf.Session()
result = sess.run(Y)
print(X.shape)  #(1,5,4)
print(result.shape) #(1,20)

对于张量,您可以使用tf.reshape