tensorflow 1.13如何使用tf.searchsort安全?

时间:2019-09-18 06:10:38

标签: python tensorflow

使用代码

tf.searchsorted(input, input2)

我遇到了第一个错误

  

InvalidArgumentError(请参阅上面的回溯):重塑无法推断   除非所有指定的输入,否则缺少张量为空的输入大小   大小不为零

它也使我想起

的第3459行
  

tensorflow / python / ops / array_ops.py

     

searchsorted

得到

sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])

但是当张量形状包含0个维度input.shape=(0,)时,它将返回错误。参见here

我想使用tf.searchsorted之前检查张量形状,并且我知道尺寸为 None

所以我用

if not tf.equal(input.shape[0], 0):
  tf.searchsorted(input, input2)

然后我遇到了第二个错误,我知道tf.equal将返回布尔张量,不能像bool那样使用。但是我不知道如何解决我的第一个错误。

  

ValueError:试图将'x'转换为张量,但失败。错误:无法   将未知尺寸转换为张量:?

我的问题是,如果第一个错误按tf.searchsort维度触发,如何使用0安全

1 个答案:

答案 0 :(得分:1)

尽管没有明确说明,但它暗示tf.searchsorted不适用于第一个参数中的空序列。

不过,您可以使用tf.cond来表示类似“如果不是空序列,请使用searchsorted,否则返回全零”(或您希望返回的任何值):

tf.cond(tf.not_equal(tf.size(input), 0),
        lambda: tf.searchsorted(input, input2),
        lambda: tf.zeros_like(input2, dtype=tf.int32))