如何使用Tensorflow将张量的每一列乘以另一列的所有列?

时间:2018-06-19 13:13:49

标签: tensorflow linear-algebra matrix-multiplication

ab定义为张量:

a = tf.constant([[1, 4],
                 [2, 5],
                 [3, 6]], tf.float32)

b = tf.constant([[10, 40],
                 [20, 50],
                 [30, 60]], tf.float32)

我正在寻找一种方法,将a的每一列乘以b的所有列,得出如下结果:

[[10,  40,  40, 160],
 [40, 100, 100, 250],
 [90, 180, 180, 360]]

我需要可以在具有任意列数(> 2)的张量上执行的操作。

我已经开发了一种可以在循环中使用的解决方案。您可以检出here

感谢您的关注。

2 个答案:

答案 0 :(得分:3)

我想念什么吗?为什么不只是

import tensorflow as tf

a = tf.constant([[1, 4],
                 [2, 5],
                 [3, 6]], tf.float32)

b = tf.constant([[10, 40],
                 [20, 50],
                 [30, 60]], tf.float32)

h_b, w_a = a.shape.as_list()[:2]
w_b = a.shape.as_list()[1]

c = tf.einsum('ij,ik->ikj', a, b)
c = tf.reshape(c,[h_b, w_a * w_b])

with tf.Session() as sess:
    print(sess.run(c))

编辑:添加foo.shape.as_list()

答案 1 :(得分:0)

您可以尝试以下方法:

import tensorflow as tf

a = tf.constant([[1, 4],
                 [2, 5],
                 [3, 6]], tf.float32)

b = tf.constant([[10, 40],
                 [20, 50],
                 [30, 60]], tf.float32)

a_t = tf.transpose(a)
b_t = tf.transpose(b)

c = tf.transpose(tf.stack([a_t[0] * b_t[0],
                           a_t[0] * b_t[1],
                           a_t[1] * b_t[0],
                           a_t[1] * b_t[1]]))

with tf.Session() as sess:
    print(sess.run(c))

但是,对于较大的矩阵,必须调整索引。