Tensorflow会覆盖Tensor
类,including __lt__
, __ge__
等多个运算符。
然而,__eq__
seems to be conspicuously absent的实施:
ops.Tensor._override_operator("__lt__", gen_math_ops.less)
ops.Tensor._override_operator("__le__", gen_math_ops.less_equal)
ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
为什么张量流的张量的==
行为与numpy数组的行为不同?
代码示例:
a = tf.constant([1,2])
b = tf.constant([3,4])
a == b
>>> False
a < b
>>> <tf.Tensor 'Less:0' shape=(2,) dtype=bool>
另一方面,有了numpy:
a = np.asarray([1,2])
b = np.asarray([3, 4])
a == b
>>> array([False, False], dtype=bool)
答案 0 :(得分:1)
张量 实施__eq__
,但the implementation only tests for identity。我找到this GitHub issue,这解释了为什么张量测试身份,而不是广播:
这可能是一个事实的复杂因素,张量可以用作词典中的键,我相信使用
==
来查找具有相同哈希值的匹配对象
评论者是正确的;如果__eq__
被重载到广播,则您无法使用张量作为字典中的键。定义__hash__
方法的对象(如果要将此类对象用作字典中的键,则需要),必须为两个相等的对象生成相同的哈希值;请参阅__hash__
method:
唯一需要的属性是比较相等的对象具有相同的哈希值
但广播会产生真实的&#39;具有不同哈希值的对象的张量对象。
(推测__eq__
会破坏布尔测试是错误的;布尔测试使用__bool__
,张量确实实现了。
如果您需要在张量上进行元素相等的测试,可以使用tf.equal()
和tf.not_equal()
函数。