使用tf.cond压缩第一维(如果等于1)

时间:2019-08-08 14:13:44

标签: python tensorflow

我有一个张量a,如果它等于1,我想压缩它的第一维。

我尝试过

import tensorflow as tf
a = tf.zeros((2, 3))
tf.cond(tf.equal(a.shape[0], 1), lambda: tf.squeeze(a, axis=0), lambda: a)

,但是它不起作用,因为tf.cond在评估条件之前先执行true_fnfalse_fn,如果a的第一维不是{{ 1}}:

1

2 个答案:

答案 0 :(得分:2)

这应该有效:

dims = tf.cond(tf.equal(a.shape[0], 1), lambda: tf.shape(a)[1:], lambda: tf.shape(a))
reshaped = tf.reshape(a, dims)

我得到的不是{挤压},而是a中的tf.cond形状,如果第一个轴为1,我将得到没有第一个轴或完整形状的形状除此以外。然后,我将a重塑为所获得的形状。

答案 1 :(得分:1)

您可以使用

tf.cond(tf.equal(a.shape[0], 1), lambda: a[0], lambda: a)

我们没有压缩,而是简单地索引到第一个维度并获取那里的唯一条目。在某些情况下,这仍然可能会崩溃,但应该可以解决您的特定问题。