我必须使用tf.transpose
交换张量轴以进行批处理矩阵乘法(如下代码所示)。
张量输入_a:形状[10000,10000]
张量input_b:形状[batch_size,10000,10]
张量输出:形状[batch_size,10000、10]
# reshape_input_b: shape [10000, batch_size, 10]
transpose_input_b = tf.transpose(input_b, [1, 0, 2])
# transpose_input_b : shape [10000, batch_size * 10]
reshape_input_b = tf.reshape(transpose_input_b , [10000, -1])
# ret: shape [10000, batch_size * 10]
ret = tf.matmul(input_a, reshape_input_b, a_is_sparse = True)
# reshape_ret: [10000, batch_size, 10]
reshape_ret = tf.reshape(ret, [10000, -1, 10])
# output : [batch_size, 10000, 10]
output = tf.transpose(reshape_ret, [1, 0, 2])
但是,它看起来非常慢。我在tf.transpose
的文档页面中注意到了这一点:
在numpy中,转置是内存有效的恒定时间操作,因为它们只是简单地返回经过调整步幅的相同数据的新视图。
TensorFlow不支持跨步,因此转置会返回一个新的张量,其中的每个元素都经过排列。
所以,我认为这可能是我的代码运行缓慢的原因?有什么方法可以交换张量轴,还是有效地进行批矩阵乘法?