正确地在matmul中广播批量维度

时间:2017-10-09 19:16:39

标签: python tensorflow vectorization matrix-multiplication

所以我试图将两个张量AB乘以A = [1, n, n*2]B = [bs, n*2, m],这应该会导致C = [64, n, m]

在我看来,这应该像tf.matmul(A,B)一样简单,它应该在批量维度上进行广播。

我收到错误:

  

两个形状中的尺寸0必须相等,但是为1和64   ' seq2seq /解码器/注意/ MATMUL' (op:' BatchMatMul')带输入   形状:[1,256,512],[64,512,120]

我也试过tf.multiply(应该支持广播)但是有同样的错误。

一个简单的脏黑客就是:

tf.stack([tf.matmul(A, s) for s in tf.unstack(B, axis=0)], axis=0)

但这似乎相当可怕。 我看到很多使用tf.reshape的答案,根据我以前的经验,我想避免。其他答案提示tf.einsum但我以前从未使用过这个问题,我不能帮助,但认为必须有一个更容易解决这个问题的方法。

所以我想我的问题是以最佳方式解决肮脏黑客的最佳选择。

由于

0 个答案:

没有答案