我是TensorFlow的新手,我正在格式化一些数据以输入回归神经网络。我的数据由输入占位符x
的3d张量给出。我想沿着第三维分割x
,为此我注意到{注意n_timesteps
对应于第三维x
的长度):
# Split the previous 3d tensor to get a list of 'n_timesteps' 2d tensors of
# shape (batch_size, features_dimension)
x = tf.split (x, n_timesteps, axis = 2)
虽然,正如我尝试使用numpy
:
x = np.split (x, n_timesteps, axis = 2)
如果x
是3d ndarray
,那么np.split
将返回尺寸为3的n_timesteps
数组列表,这样第三维就是单身。使用numpy
我知道我可以使用np.squeeze
和列表解析轻松解决此问题,以删除单例维度:
x = [np.squeeze(a, axis=2) for a in np.split(x, n_timesteps, axis=2)]
但我怎么能在TF上做同样的事情?
答案 0 :(得分:3)
您可能正在寻找tf.unstack
操作:
x = tf.unstack(x, axis=2)
答案 1 :(得分:0)
尝试使用Tensorflow的挤压功能(tf.squeeze)和Tensorflow的扫描功能(tf.scan)而不是列表理解。
tf.scan(lambda a, x_i: tf.squeeze(x_i, [2]), x, initializer=tf.constant(0, shape=[n_dim0, n_dim1]))