所以我想要的东西的伪代码是:
splitted_outputs = [tf.split(output, rate, axis=0) for output in outputs]
其中输出是形状为(512,?,128)的张量,而splitted_outputs是具有3维的张量或张量列表的列表。所以我可以迭代这种张量张量流。
我尝试使用tf.map_fn
:
splitted_outputs = tf.map_fn(
lambda output: tf.split(output, rate, axis=0),
outputs,
dtype=list
)
但这是不可能的,因为list
在dtype
上不合法。
答案 0 :(得分:1)
您可以在outputs
上使用tf.unstack
来获取“子目录”列表,然后在其中的每一个上使用tf.split
:
splitted_outputs = [tf.split(output, rate, axis=0) for output in tf.unstack(outputs, axis=0)]
请注意,tf.unstack
仅可以在已知给定axis
的大小的情况下使用,否则,您需要提供一个num
参数。