在使用tensorflow' tf.tensordot
时,我遇到了一些奇怪的结果。运行以下代码块
import tensorflow as tf
import numpy as np
a = np.arange(6, dtype=np.int32).reshape(3,2)
b = np.arange(1,7, dtype=np.int32).reshape(2,3)
sess = tf.Session()
print(sess.run(tf.tensordot(a, b, [[0,1],[0,1]])))
print(sess.run(tf.tensordot(a, b, [[0,1],[1,0]])))
print(sess.run(tf.tensordot(a, b, [[1,0],[0,1]])))
print(sess.run(tf.tensordot(a, b, [[1,0],[1,0]])))
产生
70
65
65
60
我无法弄清楚这里发生了什么收缩。另一件有趣的事情是,尝试使用numpy的tensordot会为几个尝试的轴返回一个错误。
答案 0 :(得分:0)
您已在tensorflow中发现一个错误。
根据tf.tensordot
的文档
a_axes[i]
中所有a
的{{1}}的{{1}}轴必须与b_axes[i]
的{{1}}轴具有相同的尺寸。>
例如,b
应该返回了错误,因为i
是range(0, len(a_axes))
,而tf.tensordot(a, b, [[0,1],[0,1]]))
是a
。但这不是-那是错误。
相反,它继续进行,就像张量是兼容的。如果3x2
和b
兼容,则2x3
将是一个简单的点积。 a
内部所做的是基本上将b
和tf.tensordot(a, b, [[0,1],[0,1]]))
展平并计算点积。
在您的情况下,tf.tensordot
和a
具有相同数量的元素,因此尽管它们的形状不兼容,但点积的计算还是成功的。
您可以向TF小组here提交错误。