如何结合tf.map_fn和tf.split

时间:2019-03-06 13:45:48

标签: python tensorflow

所以我想要的东西的伪代码是:

splitted_outputs = [tf.split(output, rate, axis=0) for output in outputs]

其中输出是形状为(512,?,128)的张量,而splitted_outputs是具有3维的张量或张量列表的列表。所以我可以迭代这种张量张量流。

我尝试使用tf.map_fn

splitted_outputs = tf.map_fn(
    lambda output: tf.split(output, rate, axis=0),
    outputs,
    dtype=list
)

但这是不可能的,因为listdtype上不合法。

1 个答案:

答案 0 :(得分:1)

您可以在outputs上使用tf.unstack来获取“子目录”列表,然后在其中的每一个上使用tf.split

splitted_outputs = [tf.split(output, rate, axis=0) for output in tf.unstack(outputs, axis=0)]

请注意,tf.unstack仅可以在已知给定axis的大小的情况下使用,否则,您需要提供一个num参数。