使用tf.cond()

时间:2019-04-23 06:36:10

标签: python tensorflow

我正在使用Tensorflow编码模型。我的条件语句的一部分,例如:

new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: tf.constant([1, src_shape[0]]))

src_shapetf.shape()的结果。

它报告TypeError: List of Tensors when single Tensor expected。我知道是因为tf.constant([1, src_shape[0]])是张量列表,但是我不知道如何以合法方式实现代码。

我已尝试删除tf.constant(),例如

new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: [1, src_shape[0]])

但它报告ValueError: Incompatible return values of true_fn and false_fn: The two structures don't have the same nested structure.

2 个答案:

答案 0 :(得分:2)

一种方法是使用tf.stack,它将一个等级R张量列表堆叠到一个等级(R + 1)张量中。

lambda: tf.stack([1, src_shape[0]], axis=0)

另一种解决方案是使用正确的tf.reshape命令使用tf.concat。

答案 1 :(得分:1)

我已尝试tf.convert_to_tensor([1, src_shape[0]])起作用。这是另一种解决方案。