我正在使用Tensorflow编码模型。我的条件语句的一部分,例如:
new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: tf.constant([1, src_shape[0]]))
和src_shape
是tf.shape()
的结果。
它报告TypeError: List of Tensors when single Tensor expected
。我知道是因为tf.constant([1, src_shape[0]])
是张量列表,但是我不知道如何以合法方式实现代码。
我已尝试删除tf.constant()
,例如
new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: [1, src_shape[0]])
但它报告ValueError: Incompatible return values of true_fn and false_fn: The two structures don't have the same nested structure.
答案 0 :(得分:2)
一种方法是使用tf.stack,它将一个等级R张量列表堆叠到一个等级(R + 1)张量中。
lambda: tf.stack([1, src_shape[0]], axis=0)
另一种解决方案是使用正确的tf.reshape命令使用tf.concat。
答案 1 :(得分:1)
我已尝试tf.convert_to_tensor([1, src_shape[0]])
起作用。这是另一种解决方案。