tf.cond() 返回形状未知的张量

时间:2021-07-14 14:12:50

标签: keras conditional-statements tensorflow2.0

下面是我传递给 keras Lambda 层的函数。

tf.cond() 的输出出现问题。它返回 <unknown> 的形状。输入张量 (t) 和恒权张量分别具有 (None,6)(6,) 的形状。当我在 tf.cond() 之外添加这两个时,我会得到一个形状为 (None,6) 的张量,这正是我需要的。但是,当从 tf.cond() 内返回相同的添加操作时,我得到形状为 <unknown> 的张量。

当此操作通过 tf.cond() 时会发生什么变化。

def class_segmentation(t):

        class_segments = tf.constant([0,0,1,1,2,2])

        a = tf.math.segment_mean(t, class_segments, name=None)

        b = tf.math.argmax(a)
  
        left_weights = tf.constant([1.0,1.0,0.0,0.0,0.0,0.0])
        middle_weights = tf.constant([0.0,0.0,1.0,1.0,0.0,0.0])
        right_weights = tf.constant([0.0,0.0,0.0,0.0,1.0,1.0])
        zero_weights = tf.constant([0.0,0.0,0.0,0.0,0.0,0.0])

        c = tf.cond(tf.math.equal(b,0), lambda: tf.math.add(t, left_weights), lambda: zero_weights)
        d = tf.cond(tf.math.equal(b,1), lambda: tf.math.add(t, middle_weights ), lambda: zero_weights)
        e = tf.cond(tf.math.equal(b,2), lambda: tf.math.add(t, right_weights), lambda: zero_weights)

        f = tf.math.add_n([c,d,e])
        print("Tensor shape: ", f.shape) # returns "Unknown"
        
        return f

1 个答案:

答案 0 :(得分:0)

您的代码存在一些问题。

  1. tf.math.segment_mean() 期望 class_segments 与输入 t 的第一个维度具有相同的形状。因此 None 必须等于 6 才能运行您的代码。这很可能是导致您获得 unknown 形状的原因 - 因为张量的形状取决于 None,而 a = tf.math.segment_mean(tf.transpose(t), class_segments) 是在运行时确定的。您可以为要运行的代码应用转换(不确定这是否是您要实现的目标),例如。
true_fn
  1. tf.cond() 中,false_fntrue_fn 必须返回相同形状的张量。在您的情况下,(None, 6) 返回 false_fn,因为 broadcasting(6,) 返回形状为 b = tf.math.argmax(tf.math.segment_mean(tf.transpose(t), class_segments), 0) 的张量。
  2. tf.cond() 中的谓词必须降为 0 级。例如,如果您要申请 b 那么 (None) 的形状将是 pred 并且 tf.cond() 中的谓词 (uiop:native-namestring "~/Music/[Video] performance.mp4") ==> The pathname #P"~/Music/[Video] performance.mp4" does not have a native namestring because of the :NAME component #<SB-IMPL::PATTERN (:CHARACTER-SET . "Video") " performance">. [Condition of type SB-KERNEL:NO-NATIVE-NAMESTRING-ERROR] 将是 broadcasted 到相同的形状(这将引发错误)。< /li>

如果不知道您想要获得进一步的帮助是不可能的。