(已编辑 w.r.t. @ quirk'
我正在网上阅读一些tensorflow代码并看到了这些陈述:
threshold = tf.select(input > RLSA_THRESHOLD, positive, negative)
来源:https://github.com/Raverss/tensorflow-RLSA-NMS/blob/master/source.py#L31
positive
张量仅为1
,negative
也与0
的大小相同,输入为热图(/ tensor)大小相同(所有类型tf.float32
)。
如果没有tf.cast(input > RLSA_THRESHOLD, tf.float32)
表达式的具体原因,我可以假设作者会使用tf.select(...)
,代码片段似乎相当先进。特别是因为这样就不需要变量positive
和negative
,并且可以节省内存,因为它们只是存储0
和1
的昂贵冗余方式。 / p>
上述tf.select(...)
表达式是否等同于tf.cast(input > RLSA_THRESHOLD, tf.float32)
?如果没有,为什么不呢?
注意:我经常使用Keras,如果我在这里触及一些非常微不足道的事情,我很抱歉。
答案 0 :(得分:3)
嗯,RTD(阅读文档)!
tf.select根据positive
张量中元素的 boolness 从negative
或condition
张量中选择元素。
tf.select(condition, t, e, name=None)
根据条件从t或e中选择元素。
t和e张量都必须具有相同的形状,输出也会有这种形状。
(来自官方文档。)
所以在你的情况下:
threshold = tf.select(input > RLSA_THRESHOLD, positive, negative)
input > RLSA_THRESHOLD
将是bool
或逻辑值的张量(符号0
或1
),这将有助于从positive
中选择一个值向量或negative
向量。
例如,假设您的RLSA_THRESHOLD
为0.5,而您的input
向量是实际连续值的4维向量,范围从0到1.您的positive
和{{ 1}}向量基本上分别是negative
和[1, 1, 1, 1]
。 [0, 0, 0, 0]
为input
。
[0.8, 0.2, 0.5, 0.6]
将为threshold
。
注意: [1, 0, 0, 1]
和positive
可以是任何类型的张量,只要维度与negative
张量一致即可。如果condition
和positive
分别为negative
和[2, 4, 6, 8]
,那么[1, 3, 5, 7]
就会threshold
。
如果
[2, 3, 5, 8]
没有具体原因,我可以假设作者会使用input > RLSA_THRESHOLD
,代码片段似乎相当先进。
这是有充分理由的。 tf.select
只会返回一个逻辑(布尔)值的张量。逻辑值不与数值混合良好。您不能将它们用于任何实际的数值计算。如果input > RLSA_THRESHOLD
和/或positive
张量具有实际价值,那么您可能需要negative
张量也具有实际值,以防您计划进一步使用它们。
threshold
是否等同于tf.select
?如果没有,为什么不呢?
不,他们不是。一个是函数,另一个是张量。
我将怀疑你,并假设你想问:
input > RLSA_THRESHOLD
是否等同于threshold
?如果没有,为什么不呢?
不,他们不是。如上所述,input > RLSA_THRESHOLD
是数据类型为input > RLSA_THRESHOLD
的逻辑张量。另一方面,bool
是一个与threshold
和positive
具有相同数据类型的张量。
注意:您始终可以使用casting中提供的任何tensorflow方法将逻辑张量转换为数字(或任何其他支持的数据类型)张量。
答案 1 :(得分:2)
你能理解它的最好方法是亲自尝试:
In [86]: s = tf.InteractiveSession()
In [87]: inputs = tf.random_uniform([10], 0., 1.)
In [88]: positives = tf.ones([10])
In [89]: negatives = tf.zeros([10])
In [90]: s.run([inputs, tf.select(inputs > .5, positives, negatives)])
Out[90]:
[array([ 0.13187623, 0.77344072, 0.29853749, 0.29245567, 0.53489852,
0.34861541, 0.15090156, 0.40595055, 0.34910154, 0.24349082], dtype=float32),
array([ 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.], dtype=float32)]
对于每个值>张量0.5
中的inputs
您将在同一索引处获得1.
,否则值为0.
。
inputs > .5
的结果是一个布尔值的张量(True
表示满足条件的值,否则为False
。
In [92]: s.run(inputs > .5)
Out[92]: array([ True, False, True, True, True, True, True, True, False, True], dtype=bool)