tf.select

时间:2017-01-06 12:27:51

标签: python tensorflow

已编辑 w.r.t. @ quirk'

我正在网上阅读一些tensorflow代码并看到了这些陈述:

threshold = tf.select(input > RLSA_THRESHOLD, positive, negative)

来源:https://github.com/Raverss/tensorflow-RLSA-NMS/blob/master/source.py#L31

positive张量仅为1negative也与0的大小相同,输入为热图(/ tensor)大小相同(所有类型tf.float32)。

如果没有tf.cast(input > RLSA_THRESHOLD, tf.float32)表达式的具体原因,我可以假设作者会使用tf.select(...),代码片段似乎相当先进。特别是因为这样就不需要变量positivenegative,并且可以节省内存,因为它们只是存储01的昂贵冗余方式。 / p>

上述tf.select(...)表达式是否等同于tf.cast(input > RLSA_THRESHOLD, tf.float32)?如果没有,为什么不呢?

注意:我经常使用Keras,如果我在这里触及一些非常微不足道的事情,我很抱歉。

2 个答案:

答案 0 :(得分:3)

嗯,RTD(阅读文档)!

tf.select根据positive张量中元素的 boolness negativecondition张量中选择元素。

  

tf.select(condition, t, e, name=None)
  根据条件从t或e中选择元素。
  t和e张量都必须具有相同的形状,输出也会有这种形状。

(来自官方文档。)

所以在你的情况下:

threshold = tf.select(input > RLSA_THRESHOLD, positive, negative)

input > RLSA_THRESHOLD将是bool或逻辑值的张量(符号01),这将有助于从positive中选择一个值向量或negative向量。

例如,假设您的RLSA_THRESHOLD为0.5,而您的input向量是实际连续值的4维向量,范围从0到1.您的positive和{{ 1}}向量基本上分别是negative[1, 1, 1, 1][0, 0, 0, 0]input

[0.8, 0.2, 0.5, 0.6]将为threshold

注意: [1, 0, 0, 1]positive可以是任何类型的张量,只要维度与negative张量一致即可。如果conditionpositive分别为negative[2, 4, 6, 8],那么[1, 3, 5, 7]就会threshold

  

如果[2, 3, 5, 8]没有具体原因,我可以假设作者会使用input > RLSA_THRESHOLD,代码片段似乎相当先进。

这是有充分理由的。 tf.select只会返回一个逻辑(布尔)值的张量。逻辑值与数值混合良好。您不能将它们用于任何实际的数值计算。如果input > RLSA_THRESHOLD和/或positive张量具有实际价值,那么您可能需要negative张量也具有实际值,以防您计划进一步使用它们。

  

threshold是否等同于tf.select?如果没有,为什么不呢?

不,他们不是。一个是函数,另一个是张量。

我将怀疑你,并假设你想问:

  

input > RLSA_THRESHOLD是否等同于threshold?如果没有,为什么不呢?

不,他们不是。如上所述,input > RLSA_THRESHOLD是数据类型为input > RLSA_THRESHOLD的逻辑张量。另一方面,bool是一个与thresholdpositive具有相同数据类型的张量。

注意:您始终可以使用casting中提供的任何方法将逻辑张量转换为数字(或任何其他支持的数据类型)张量。

答案 1 :(得分:2)

你能理解它的最好方法是亲自尝试:

In [86]: s = tf.InteractiveSession()

In [87]: inputs = tf.random_uniform([10], 0., 1.)

In [88]: positives = tf.ones([10])

In [89]: negatives = tf.zeros([10])    

In [90]: s.run([inputs, tf.select(inputs > .5, positives, negatives)])
Out[90]: 
[array([ 0.13187623,  0.77344072,  0.29853749,  0.29245567,  0.53489852,
         0.34861541,  0.15090156,  0.40595055,  0.34910154,  0.24349082], dtype=float32),
 array([ 0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.], dtype=float32)]

对于每个值>张量0.5中的inputs您将在同一索引处获得1.,否则值为0.

inputs > .5的结果是一个布尔值的张量(True表示满足条件的值,否则为False

In [92]: s.run(inputs > .5)
Out[92]: array([ True, False,  True,  True,  True,  True,  True,  True, False,  True], dtype=bool)