TensorFlow - 分裂和挤压

时间:2017-02-28 19:44:51

标签: python numpy tensorflow

我是TensorFlow的新手,我正在格式化一些数据以输入回归神经网络。我的数据由输入占位符x的3d张量给出。我想沿着第三维分割x,为此我注意到{注意n_timesteps对应于第三维x的长度):

# Split the previous 3d tensor to get a list of 'n_timesteps' 2d tensors of
# shape (batch_size, features_dimension)
x = tf.split (x, n_timesteps, axis = 2)

虽然,正如我尝试使用numpy

x = np.split (x, n_timesteps, axis = 2)

如果x是3d ndarray,那么np.split将返回尺寸为3的n_timesteps数组列表,这样第三维就是单身。使用numpy我知道我可以使用np.squeeze和列表解析轻松解决此问题,以删除单例维度:

x = [np.squeeze(a, axis=2) for a in np.split(x, n_timesteps, axis=2)]

但我怎么能在TF上做同样的事情?

2 个答案:

答案 0 :(得分:3)

您可能正在寻找tf.unstack操作:

x = tf.unstack(x, axis=2)

答案 1 :(得分:0)

尝试使用Tensorflow的挤压功能(tf.squeeze)和Tensorflow的扫描功能(tf.scan)而不是列表理解。

tf.scan(lambda a, x_i: tf.squeeze(x_i, [2]), x, initializer=tf.constant(0, shape=[n_dim0, n_dim1]))