二维和三维张量的乘法

时间:2019-05-06 01:29:08

标签: python tensorflow

有一个二维张量a[m,n]和一个三维张量b[k,n,h]。我应该使用什么API将二维张量乘以三维张量以获得三维张量c[k,m,h]

实际上我可以通过:

import tensorflow as tf
import tensorly as tl
x = tf.constant([[[1,2],[3,7],[8,9]],
                 [[4,5],[6,10],[11,12]]],tf.float32)
a = tf.constant([[-0.70711,0.57735],
                 [0.0000,0.57735],
                 [0.70711,0.57735]])
reshape_A = tf.reshape(x, [2,6])

re = tf.reshape(tf.matmul(a, reshape_A), [3, 3, 2])

with tf.Session() as sess:
    print(sess.run(re))
    re = re.eval()

但是有没有更简单的方法?

1 个答案:

答案 0 :(得分:0)

您可以使用tensorly.tenalg.contract。 例如:

import tensorly as tl
import numpy as np
tl.set_backend('tensorflow')

k = 2; m = 3; n = 5; h = 4

A = tl.tensor(np.random.random((m, n)))
B = tl.tensor(np.random.random((k, n, h)))

res = tl.tenalg.contract(A, 1, B, 1)