使用Tensorflow后端了解Keras中的batch_dot()

时间:2019-01-06 00:58:46

标签: tensorflow keras matrix-multiplication

我正在尝试理解这段代码(来自here),该代码使用两个张量之间的矩阵乘法实现点积注意。具体来说,来自Keras后端的batch_dot()函数用于两个张量都可变的第一张量之间。在这种情况下,batch_dot()的执行方式似乎有所不同,与指定第一个维度时相反。

MWE:

固定的一维尺寸,按预期工作

q = K.ones(shape=(36,8,24))
k = K.ones(shape=(36,8,24))
print(K.batch_dot(q,k,axes=[1,1]))

返回

Tensor("MatMul_8:0", shape=(?, 36, 24, 24), dtype=float32)

print(K.batch_dot(q,k,axes=[2,2]))

返回

Tensor("MatMul_9:0", shape=(?, 36, 8, 8), dtype=float32)

但是,将q和k定义如下:

q = Input(shape=(36,8,24))
k = Input(shape=(36,8,24))
print(q)
print(k)

(可变第一维度)

Tensor("input_24:0", shape=(?, 36, 8, 24), dtype=float32)
Tensor("input_25:0", shape=(?, 36, 8, 24), dtype=float32)

batch_dot()操作的输出尺寸是意外的:

K.batch_dot(q,k,axes=[1,1])
<tf.Tensor 'MatMul_11:0' shape=(?, 36, 24, 24) dtype=float32>
K.batch_dot(q,k,axes=[2,2])
<tf.Tensor 'MatMul_12:0' shape=(?, 36, 24, 24) dtype=float32>

根据documentationaxes的参数指定了在操作过程中删除的尺寸,但是我无法将此定义连接到上面的输出。

是否为?的参数计算第一个维度(值为axes)?

2 个答案:

答案 0 :(得分:0)

  

是否将第一个维度(带有值?)计算为参数   的轴数?

是的,它算在内。

事实是,在上面的示例中,Input层中的第一维是批处理大小,而在K.ones()中则不是。结果,输入的轴[3,3]等于K.ones()中的轴[2,2]。在您的代码中,以下两个batch_dot相等:

q = tf.keras.layers.Input(shape=(36, 8, 24)) # shape =(?, 36,8,24)
k = tf.keras.layers.Input(shape=(36, 8, 24)) # shape =(?, 36,8,24)
print(tf.keras.backend.batch_dot(q, k, axes=[3, 3]))

q = tf.keras.backend.ones(shape=(36, 8, 24))
k = tf.keras.backend.ones(shape=(36, 8, 24))
print(tf.keras.backend.batch_dot(q, k, axes=[2, 2]))

请注意,在K.ones()中,如果形状是符号形状,则无法返回变量,而将返回动态形状的张量。这是什么意思?请参阅以下示例,以更好地理解:

a = tf.keras.layers.Input(shape=(30,))
c = tf.keras.backend.ones(shape=tf.shape(a))
print(c) # shape=(?, 30)
d = tf.keras.backend.ones(shape=(30, 40))
print(d) # shape=(30,40)

  

batch_dot()操作的输出尺寸是意外的

K.batch_dot(q,k,axes=[1,1])
<tf.Tensor 'MatMul_11:0' shape=(?, 36, 24, 24) dtype=float32>
K.batch_dot(q,k,axes=[2,2])
<tf.Tensor 'MatMul_12:0' shape=(?, 36, 24, 24) dtype=float32>

为什么在地球上这发生在轴不同的情况下?

要回答此问题,我们应该了解batch_dot的基本实现。如果输入张量的秩不是2,则我们的batch_dot表现为tf.matmul运算,即输入张量之一被共轭转置。结果,当我们的输入张量的秩为3并且我们将轴设置为0或1时,它们计算的结果相同,但是当将轴设置为2时它计算出的结果则不同:

a = np.array([[[1, 2, 3],
               [3, 2, 1]]])  # rank 3

b = np.array([[[1, 3, 3],
               [2, 2, 0]]])  # rank 3

a = tf.constant(a, dtype=tf.float32)
b = tf.constant(b, dtype=tf.float32)

c = tf.matmul(a, b, adjoint_a=True, adjoint_b=False)  # when axes is [0,0] or [1,1]
d = tf.matmul(a, b, adjoint_a=False, adjoint_b=True)  # when axes is [2,2]
print(c.shape)  # shape=(1,3,3)
print(d.shape)  # shape=(1,2,2)

在您的示例中发生了同样的事情:

a = np.array([[[1, 2, 3],
               [3, 2, 1]]])

b = np.array([[[1, 3, 3],
               [2, 2, 0]]])

q = tf.placeholder(dtype=tf.float32, shape=(None, 2, 3))  
k = tf.placeholder(dtype=tf.float32, shape=(None, 2, 3))  
res1 = tf.keras.backend.batch_dot(q, k, axes=0)
res2 = tf.keras.backend.batch_dot(q, k, axes=1)
res3 = tf.keras.backend.batch_dot(q, k, axes=2)
with tf.Session() as sess:
    feed_dic = {q: a, k: b}
    print(sess.run(res1, feed_dict=feed_dic))
    print(20 * '-')
    print(sess.run(res2, feed_dict=feed_dic))
    print(20 * '-')
    print(sess.run(res3, feed_dict=feed_dic))

答案 1 :(得分:0)

如果您查看https://github.com/tensorflow/tensorflow/blob/a6d8ffae097d0132989ae4688d224121ec6d8f35/tensorflow/python/keras/backend.py#L1437上的源代码,一切将会很清楚

我们可以直接进入line1507

if ndim(x) == 2 and ndim(y) == 2:
    if axes[0] == axes[1]:
      out = math_ops.reduce_sum(math_ops.multiply(x, y), axes[0])
    else:
      out = math_ops.reduce_sum(
          math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
else: 
    adj_x = None if axes[0] == ndim(x) - 1 else True
    adj_y = True if axes[1] == ndim(y) - 1 else None
    out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)

如所示,它仅检查adj_xadj_y,而不会将axes参数传递给math_ops.matmul方法。这就是当axes[1,1][2,2]时获得相同结果的原因。

我们可以使用以下代码进行验证:

q = K.ones(shape=range(1, 10))
k = K.ones(shape=range(1, 10))
for i in range(10): print(i, K.batch_dot(q,k,axes=[i,i]))

它将打印

0 Tensor("MatMul_7:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
1 Tensor("MatMul_8:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
2 Tensor("MatMul_9:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
3 Tensor("MatMul_10:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
4 Tensor("MatMul_11:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
5 Tensor("MatMul_12:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
6 Tensor("MatMul_13:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
7 Tensor("MatMul_14:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
8 Tensor("MatMul_15:0", shape=(1, 2, 3, 4, 5, 6, 7, 8, 8), dtype=float32)
9 Tensor("MatMul_16:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)

除了i为8时,其他所有结果都返回相同的结果。