3D * 2D矩阵乘法而不更改批处理尺寸

时间:2018-12-09 15:42:24

标签: python tensorflow

由于一些tensorrt问题,我试图以其他方式实现tensordot。

在tf.tensordot op中,在3d * 2d matmul过程中进行了批量大小修改。

M = tf.random_normal((batch_size, n, m))  # (3,6,9)
N = tf.random_normal((m, p)) # (9,9)

MT = tf.reshape(M, [batch_size*n, m]) # (18,9)
MTN = tf.matmul(M_T, N) # (18,9)

MN = tf.reshape(MTN, [batch_size, n, p]) # (3,6,9)

但是我想要3d * 2d matmul,而不更改批量大小尺寸。 有办法吗?

0 个答案:

没有答案