在Tensorflow 1.4中将矩阵与一批矩阵相乘

时间:2018-03-07 22:32:26

标签: tensorflow

我有一个形状为(dim1, dim2)的矩阵A,以及一批形状为(batch_size, dim2, dim3)的矩阵。
如何将矩阵与批处理中的每个矩阵相乘?平铺矩阵A似乎消耗了太多内存。

1 个答案:

答案 0 :(得分:0)

einsumnumpy中提供的

tensorflow应该可以满足您的需求:

import numpy as np
batch_size = 5
dim1 = 7
dim2 = 2
dim3 = 3

A = np.random.rand(batch_size, dim1, dim2)
B = np.random.rand(batch_size, dim2, dim3)   
C = np.einsum('kl,ilm->ikm',A,B)
print(C.shape)
Out[9]: (5, 7, 3)

import tensorflow as tf
a = tf.constant(A) # reusing numpy arrays
b = tf.constant(B)
op = tf.einsum('kl,ilm->ikm',a,b)
with tf.Session() as sess:       
    print(sess.run(op) - C) # prints a zero array