为什么Tensorflow不会覆盖__eq__?

时间:2017-10-17 07:45:46

标签: python numpy tensorflow

Tensorflow会覆盖Tensor类,including __lt__, __ge__等多个运算符。

然而,__eq__ seems to be conspicuously absent的实施:

ops.Tensor._override_operator("__lt__", gen_math_ops.less)
ops.Tensor._override_operator("__le__", gen_math_ops.less_equal)
ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)

为什么张量流的张量的==行为与numpy数组的行为不同?

代码示例:

a = tf.constant([1,2])
b = tf.constant([3,4])
a == b
>>> False
a < b
>>> <tf.Tensor 'Less:0' shape=(2,) dtype=bool>
另一方面,有了numpy:

a = np.asarray([1,2])
b = np.asarray([3, 4])
a == b
>>> array([False, False], dtype=bool)

1 个答案:

答案 0 :(得分:1)

张量 实施__eq__,但the implementation only tests for identity。我找到this GitHub issue,这解释了为什么张量测试身份,而不是广播:

  

这可能是一个事实的复杂因素,张量可以用作词典中的键,我相信使用==来查找具有相同哈希值的匹配对象

评论者是正确的;如果__eq__被重载到广播,则您无法使用张量作为字典中的键。定义__hash__方法的对象(如果要将此类对象用作字典中的键,则需要),必须为两个相等的对象生成相同的哈希值;请参阅__hash__ method

  

唯一需要的属性是比较相等的对象具有相同的哈希值

但广播会产生真实的&#39;具有不同哈希值的对象的张量对象。

(推测__eq__会破坏布尔测试是错误的;布尔测试使用__bool__,张量确实实现了。

如果您需要在张量上进行元素相等的测试,可以使用tf.equal()tf.not_equal()函数。