如何在张量流中将2d张量与3d张量叠加?

时间:2018-07-02 21:13:48

标签: python tensorflow

numpy中,您可以将2d数组与3d数组相乘,如下例所示:

>>> X = np.random.randn(3,5,4) # [3,5,4]
... W = np.random.randn(5,5) # [5,5]
... out = np.matmul(W, X) # [3,5,4]

据我了解,np.matmul()W沿X的第一个维度广播。但是在tensorflow中是不允许的:

>>> _X = tf.constant(X)
... _W = tf.constant(W)
... _out = tf.matmul(_W, _X)

ValueError: Shape must be rank 2 but is rank 3 for 'MatMul_1' (op: 'MatMul') with input shapes: [5,5], [3,5,4].

那么np.matmul()上面的tensorflow有什么等效功能?在tensorflow中将2d张量与3d张量相乘的最佳实践是什么?

5 个答案:

答案 0 :(得分:3)

以下是来自tensorflow XLA broadcasting semantics

  

XLA语言尽可能严格和显式,避免了隐式和“魔术”功能。这样的功能可能会使一些计算的定义更容易些,但代价是要在用户代码中增加更多的假设,而这些假设在长期内将很难更改。

因此Tensorflow不提供内置的广播功能。

但是,它确实提供了可以像张量一样重塑张量的功能。此操作称为tf.tile

签名如下:

tf.tile(input, multiples, name=None)
  

此操作通过创建一个新的张量   复制输入倍数输出张量的第i个维度   具有input.dims(i)*多个[i]元素,并且input的值是   沿第i个维度复制了多个[i]次。

答案 1 :(得分:3)

您可以改用tensordot

tf.transpose(tf.tensordot(_W, _X, axes=[[1],[1]]),[1,0,2])

答案 2 :(得分:2)

尝试使用tf.tile来增加乘法之前矩阵的等级。 numpy的自动广播功能似乎未在tensorflow中实现。您必须手动执行。

W_T = tf.tile(tf.expand_dims(W,0),[3,1,1])

这应该可以解决问题

import numpy as np
import tensorflow as tf

X = np.random.randn(3,4,5)
W = np.random.randn(5,5)

_X = tf.constant(X)
_W = tf.constant(W)
_W_t = tf.tile(tf.expand_dims(_W,0),[3,1,1])

with tf.Session() as sess:
    print(sess.run(tf.matmul(_X,_W_t)))

答案 3 :(得分:0)

您也可以使用tf.einsum避免平铺张量:

tf.einsum("ab,ibc->iac", _W, _X)

完整示例:

import numpy as np
import tensorflow as tf

# Numpy-style matrix multiplication:
X = np.random.randn(3,5,4)
W = np.random.randn(5,5)
np_WX = np.matmul(W, X)

# TensorFlow-style multiplication:
_X = tf.constant(X)
_W = tf.constant(W)
_WX = tf.einsum("ab,ibc->iac", _W, _X)

with tf.Session() as sess:
    tf_WX = sess.run(_WX)

# Check that the results are the same:
print(np.allclose(np_WX, tf_WX))

答案 4 :(得分:0)

在这里,我将使用keras后端K.dot和tensorflow tf.transpose。 第一次交换3 D张量的内部暗淡

X=tf.transpose(X,perm=[0,-1,1]) # X shape=[3,4,5]

现在像这样繁殖

out=K.dot(X,W) # out shape=[3,4,5]

现在再次交换轴

out = tf.transpose(out,perm=[0,-1,1]) # out shape=[3,5,4]

上述解决方案可以节省内存,而无需花费W