I have a matrix A of shape (dim1, dim2) and a batch of matrices of shape (batch_size, dim2, dim3).
How can I multiply A with every matrix in the batch? Tiling A along the batch dimension seems to use too much memory.
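For concreteness, here is a minimal sketch of the two obvious approaches the question alludes to, using small placeholder sizes (the dimensions are assumptions for illustration): an explicit Python loop over the batch, and tiling A before a batched matmul, which duplicates A in memory.

import numpy as np

batch_size, dim1, dim2, dim3 = 5, 7, 2, 3   # placeholder sizes, assumed for illustration
A = np.random.rand(dim1, dim2)
batch = np.random.rand(batch_size, dim2, dim3)

# Explicit loop: one (dim1, dim3) product per batch element
looped = np.stack([A @ batch[i] for i in range(batch_size)])

# Tiling: correct, but materializes batch_size copies of A
tiled = np.matmul(np.tile(A, (batch_size, 1, 1)), batch)

print(looped.shape, np.allclose(looped, tiled))   # (5, 7, 3) True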
Answer 0 (score: 0)
einsum, available in both numpy and tensorflow, should do what you want:
import numpy as np
batch_size = 5
dim1 = 7
dim2 = 2
dim3 = 3
A = np.random.rand(dim1, dim2)              # single matrix, shape (dim1, dim2)
B = np.random.rand(batch_size, dim2, dim3)  # batch of matrices
C = np.einsum('kl,ilm->ikm', A, B)  # contract over dim2, keep the batch axis
print(C.shape)
# Out: (5, 7, 3)
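As a cross-check (not part of the original answer), numpy's matmul broadcasts a 2-D A against the leading batch axis of B, so the same result can be obtained without einsum and without tiling:

D = np.matmul(A, B)          # equivalently A @ B; A is broadcast over the batch axis
print(np.allclose(C, D))     # True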
import tensorflow as tf
a = tf.constant(A) # reusing numpy arrays
b = tf.constant(B)
op = tf.einsum('kl,ilm->ikm', a, b)
with tf.Session() as sess:
    print(sess.run(op) - C)  # prints a zero array
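The Session-based snippet above assumes TensorFlow 1.x. On TensorFlow 2.x, where tf.Session has been removed, the same einsum call runs eagerly; a minimal sketch, assuming TF 2 is installed:

# TensorFlow 2.x, eager execution: no Session needed
result = tf.einsum('kl,ilm->ikm', a, b)
print(result.numpy() - C)    # zero array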