是否有一个Tensorflow函数,例如带有默认值的where?

时间:2019-07-01 14:55:01

标签: python tensorflow

我有这两个张量:

s = tf.constant([[1], [2], [3], [4], [5], [6], [7])
r = tf.constant([[1, 1], [2, 5], [1, 2], [8, 7], [7, 5], [6, 6], [8, 6])

考虑到两个张量的第一维必须相等,比方说,n(在这种情况下为7)。我想知道在Tensorflow中是否有一个函数,给定这两个张量,返回一个维Tensor of Dimensions(n,2),这样每行的第一个元素是该行的索引,第二个元素是,如果s的第i个元素出现在r [i]中,则s [i]的索引在r [i]中,否则为默认值。我想举个例子会更容易:

>>> fancy_function(r, s, default_value=-1).eval()
[[0, 0], [1, 0], [2, -1], [3, -1], [4, 1], [5, 0], [6, -1]]

实际上,如果允许随机选择,它可能会返回另一件事:

>>> fancy_function(r, s, default_value=-1, random_choice=True).eval()
[[0, 1], [1, 0], [2, -1], [3, -1], [4, 1], [5, 1], [6, -1]]

元素0和5中的第二个元素可以更改,因为r [0] [0]和r [0] [1]都等于1,并且r [5] [0]和r [5] [1 ]等于6。

我认为它与where函数相似。

事实上,在哪里,我会得到这样的东西:

>>> tf.where(tf.equal(s, r)).eval()
[[0, 0], [0, 1], [1, 0], [4, 1], [5, 0], [5, 1]]

但是,一方面,它会重复索引为0和5的元素(我想以不确定的方式选择它们),并且不包含默认值。

0 个答案:

没有答案