我有一个张量a
,如果它等于1,我想压缩它的第一维。
我尝试过
import tensorflow as tf
a = tf.zeros((2, 3))
tf.cond(tf.equal(a.shape[0], 1), lambda: tf.squeeze(a, axis=0), lambda: a)
,但是它不起作用,因为tf.cond
在评估条件之前先执行true_fn
和false_fn
,如果a
的第一维不是{{ 1}}:
1
答案 0 :(得分:2)
这应该有效:
dims = tf.cond(tf.equal(a.shape[0], 1), lambda: tf.shape(a)[1:], lambda: tf.shape(a))
reshaped = tf.reshape(a, dims)
我得到的不是{挤压},而是a
中的tf.cond
形状,如果第一个轴为1
,我将得到没有第一个轴或完整形状的形状除此以外。然后,我将a
重塑为所获得的形状。
答案 1 :(得分:1)
您可以使用
tf.cond(tf.equal(a.shape[0], 1), lambda: a[0], lambda: a)
我们没有压缩,而是简单地索引到第一个维度并获取那里的唯一条目。在某些情况下,这仍然可能会崩溃,但应该可以解决您的特定问题。