如何在张量流中有效地交换张量轴?

时间:2019-08-26 02:58:02

标签: numpy tensorflow deep-learning

我必须使用tf.transpose交换张量轴以进行批处理矩阵乘法(如下代码所示)。

张量输入_a:形状[10000,10000]

张量input_b:形状[batch_size,10000,10]

张量输出:形状[batch_size,10000、10]

# reshape_input_b: shape [10000, batch_size, 10]
transpose_input_b = tf.transpose(input_b, [1, 0, 2])

# transpose_input_b : shape [10000, batch_size * 10]
reshape_input_b = tf.reshape(transpose_input_b , [10000, -1])

# ret: shape [10000, batch_size * 10]
ret = tf.matmul(input_a, reshape_input_b, a_is_sparse = True)

# reshape_ret: [10000, batch_size, 10]
reshape_ret = tf.reshape(ret, [10000, -1, 10])

# output : [batch_size, 10000, 10]
output = tf.transpose(reshape_ret, [1, 0, 2])

但是,它看起来非常慢。我在tf.transpose的文档页面中注意到了这一点:

  

在numpy中,转置是内存有效的恒定时间操作,因为它们只是简单地返回经过调整步幅的相同数据的新视图。

     

TensorFlow不支持跨步,因此转置会返回一个新的张量,其中的每个元素都经过排列

所以,我认为这可能是我的代码运行缓慢的原因?有什么方法可以交换张量轴,还是有效地进行批矩阵乘法?

0 个答案:

没有答案