使用tensorflow的tensordot混淆结果

时间:2017-10-02 09:11:17

标签: python tensorflow

在使用tensorflow' tf.tensordot时,我遇到了一些奇怪的结果。运行以下代码块

import tensorflow as tf
import numpy as np
a = np.arange(6, dtype=np.int32).reshape(3,2)
b = np.arange(1,7, dtype=np.int32).reshape(2,3)
sess = tf.Session()
print(sess.run(tf.tensordot(a, b, [[0,1],[0,1]])))
print(sess.run(tf.tensordot(a, b, [[0,1],[1,0]])))
print(sess.run(tf.tensordot(a, b, [[1,0],[0,1]])))
print(sess.run(tf.tensordot(a, b, [[1,0],[1,0]])))

产生

70
65
65
60

我无法弄清楚这里发生了什么收缩。另一件有趣的事情是,尝试使用numpy的tensordot会为几个尝试的轴返回一个错误。

1 个答案:

答案 0 :(得分:0)

您已在tensorflow中发现一个错误。

根据tf.tensordot的文档

  

a_axes[i]中所有a的{​​{1}}的{​​{1}}轴必须与b_axes[i]的{​​{1}}轴具有相同的尺寸。

例如,b应该返回了错误,因为irange(0, len(a_axes)),而tf.tensordot(a, b, [[0,1],[0,1]]))a。但这不是-那是错误。

相反,它继续进行,就像张量是兼容的。如果3x2b兼容,则2x3将是一个简单的点积。 a内部所做的是基本上将btf.tensordot(a, b, [[0,1],[0,1]]))展平并计算点积。

在您的情况下,tf.tensordota具有相同数量的元素,因此尽管它们的形状不兼容,但点积的计算还是成功的。

您可以向TF小组here提交错误。