我们有一个Nxdxd内核张量,我们正在尝试按Nxdx256张量缩放每个对应矩阵中的行(这应该导致Nxdxdx256张量)。然后,我们在第3维上进行最大约简,得到一个Nxdx256矩阵。
或者,我们实质上是在尝试批矩阵乘法,但是要进行最大缩减而不是总和缩减。
截至目前,通过我们的实施,该模型大约需要2天才能完成实验,而更早的时间最多需要3个小时(对模型的唯一更改是以前我们只需要运行批处理矩阵乘法)。张量流中对此有更有效的实现吗?
# kernel is the Nxdxd tensor and X is the Nxdx256 tensor
kern_normed = tf.einsum('abij,aik->abjk', tf.matrix_diag(kernel), X)
return tf.reduce_max(kern_normed, axis=2)