所以我试图将两个张量A
和B
乘以A = [1, n, n*2]
和B = [bs, n*2, m]
,这应该会导致C = [64, n, m]
。
在我看来,这应该像tf.matmul(A,B)
一样简单,它应该在批量维度上进行广播。
我收到错误:
两个形状中的尺寸0必须相等,但是为1和64 ' seq2seq /解码器/注意/ MATMUL' (op:' BatchMatMul')带输入 形状:[1,256,512],[64,512,120]
我也试过tf.multiply
(应该支持广播)但是有同样的错误。
一个简单的脏黑客就是:
tf.stack([tf.matmul(A, s) for s in tf.unstack(B, axis=0)], axis=0)
但这似乎相当可怕。
我看到很多使用tf.reshape
的答案,根据我以前的经验,我想避免。其他答案提示tf.einsum
但我以前从未使用过这个问题,我不能帮助,但认为必须有一个更容易解决这个问题的方法。
所以我想我的问题是以最佳方式解决肮脏黑客的最佳选择。
由于