我正在尝试理解这段代码(来自here),该代码使用两个张量之间的矩阵乘法实现点积注意。具体来说,来自Keras后端的batch_dot()函数用于两个张量都可变的第一张量之间。在这种情况下,batch_dot()的执行方式似乎有所不同,与指定第一个维度时相反。
MWE:
固定的一维尺寸,按预期工作
q = K.ones(shape=(36,8,24))
k = K.ones(shape=(36,8,24))
print(K.batch_dot(q,k,axes=[1,1]))
返回
Tensor("MatMul_8:0", shape=(?, 36, 24, 24), dtype=float32)
和
print(K.batch_dot(q,k,axes=[2,2]))
返回
Tensor("MatMul_9:0", shape=(?, 36, 8, 8), dtype=float32)
但是,将q和k定义如下:
q = Input(shape=(36,8,24))
k = Input(shape=(36,8,24))
print(q)
print(k)
(可变第一维度)
Tensor("input_24:0", shape=(?, 36, 8, 24), dtype=float32)
Tensor("input_25:0", shape=(?, 36, 8, 24), dtype=float32)
batch_dot()操作的输出尺寸是意外的:
K.batch_dot(q,k,axes=[1,1])
<tf.Tensor 'MatMul_11:0' shape=(?, 36, 24, 24) dtype=float32>
K.batch_dot(q,k,axes=[2,2])
<tf.Tensor 'MatMul_12:0' shape=(?, 36, 24, 24) dtype=float32>
根据documentation,axes
的参数指定了在操作过程中删除的尺寸,但是我无法将此定义连接到上面的输出。
?
的参数计算第一个维度(值为axes
)?
答案 0 :(得分:0)
是否将第一个维度(带有值?)计算为参数 的轴数?
是的,它算在内。
事实是,在上面的示例中,Input
层中的第一维是批处理大小,而在K.ones()
中则不是。结果,输入的轴[3,3]等于K.ones()
中的轴[2,2]。在您的代码中,以下两个batch_dot
相等:
q = tf.keras.layers.Input(shape=(36, 8, 24)) # shape =(?, 36,8,24)
k = tf.keras.layers.Input(shape=(36, 8, 24)) # shape =(?, 36,8,24)
print(tf.keras.backend.batch_dot(q, k, axes=[3, 3]))
q = tf.keras.backend.ones(shape=(36, 8, 24))
k = tf.keras.backend.ones(shape=(36, 8, 24))
print(tf.keras.backend.batch_dot(q, k, axes=[2, 2]))
请注意,在K.ones()
中,如果形状是符号形状,则无法返回变量,而将返回动态形状的张量。这是什么意思?请参阅以下示例,以更好地理解:
a = tf.keras.layers.Input(shape=(30,))
c = tf.keras.backend.ones(shape=tf.shape(a))
print(c) # shape=(?, 30)
d = tf.keras.backend.ones(shape=(30, 40))
print(d) # shape=(30,40)
batch_dot()操作的输出尺寸是意外的
K.batch_dot(q,k,axes=[1,1])
<tf.Tensor 'MatMul_11:0' shape=(?, 36, 24, 24) dtype=float32>
K.batch_dot(q,k,axes=[2,2])
<tf.Tensor 'MatMul_12:0' shape=(?, 36, 24, 24) dtype=float32>
为什么在地球上这发生在轴不同的情况下?
要回答此问题,我们应该了解batch_dot
的基本实现。如果输入张量的秩不是2,则我们的batch_dot
表现为tf.matmul
运算,即输入张量之一被共轭转置。结果,当我们的输入张量的秩为3并且我们将轴设置为0或1时,它们计算的结果相同,但是当将轴设置为2时它计算出的结果则不同:
a = np.array([[[1, 2, 3],
[3, 2, 1]]]) # rank 3
b = np.array([[[1, 3, 3],
[2, 2, 0]]]) # rank 3
a = tf.constant(a, dtype=tf.float32)
b = tf.constant(b, dtype=tf.float32)
c = tf.matmul(a, b, adjoint_a=True, adjoint_b=False) # when axes is [0,0] or [1,1]
d = tf.matmul(a, b, adjoint_a=False, adjoint_b=True) # when axes is [2,2]
print(c.shape) # shape=(1,3,3)
print(d.shape) # shape=(1,2,2)
在您的示例中发生了同样的事情:
a = np.array([[[1, 2, 3],
[3, 2, 1]]])
b = np.array([[[1, 3, 3],
[2, 2, 0]]])
q = tf.placeholder(dtype=tf.float32, shape=(None, 2, 3))
k = tf.placeholder(dtype=tf.float32, shape=(None, 2, 3))
res1 = tf.keras.backend.batch_dot(q, k, axes=0)
res2 = tf.keras.backend.batch_dot(q, k, axes=1)
res3 = tf.keras.backend.batch_dot(q, k, axes=2)
with tf.Session() as sess:
feed_dic = {q: a, k: b}
print(sess.run(res1, feed_dict=feed_dic))
print(20 * '-')
print(sess.run(res2, feed_dict=feed_dic))
print(20 * '-')
print(sess.run(res3, feed_dict=feed_dic))
答案 1 :(得分:0)
如果您查看https://github.com/tensorflow/tensorflow/blob/a6d8ffae097d0132989ae4688d224121ec6d8f35/tensorflow/python/keras/backend.py#L1437上的源代码,一切将会很清楚
我们可以直接进入line1507
if ndim(x) == 2 and ndim(y) == 2:
if axes[0] == axes[1]:
out = math_ops.reduce_sum(math_ops.multiply(x, y), axes[0])
else:
out = math_ops.reduce_sum(
math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
else:
adj_x = None if axes[0] == ndim(x) - 1 else True
adj_y = True if axes[1] == ndim(y) - 1 else None
out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
如所示,它仅检查adj_x
和adj_y
,而不会将axes
参数传递给math_ops.matmul
方法。这就是当axes
为[1,1]
和[2,2]
时获得相同结果的原因。
我们可以使用以下代码进行验证:
q = K.ones(shape=range(1, 10))
k = K.ones(shape=range(1, 10))
for i in range(10): print(i, K.batch_dot(q,k,axes=[i,i]))
它将打印
0 Tensor("MatMul_7:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
1 Tensor("MatMul_8:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
2 Tensor("MatMul_9:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
3 Tensor("MatMul_10:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
4 Tensor("MatMul_11:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
5 Tensor("MatMul_12:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
6 Tensor("MatMul_13:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
7 Tensor("MatMul_14:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
8 Tensor("MatMul_15:0", shape=(1, 2, 3, 4, 5, 6, 7, 8, 8), dtype=float32)
9 Tensor("MatMul_16:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
除了i
为8时,其他所有结果都返回相同的结果。