TensorFlow中的双线性张量积

时间:2015-12-05 20:30:29

标签: tensorflow

我正在努力重新实施this paper,关键操作是双线性张量积。我几乎不知道这意味着什么,但是纸张有一个很好的小图形,我明白了。

enter image description here

关键操作是 e_1 * W * e_2 ,我想知道如何在tensorflow中实现它,因为其余应该很容易。

基本上,给定3D张量W,将其切割成矩阵,对于第j个切片(矩阵),将其两侧乘以 e_1 e_2 ,产生标量,这是结果向量中的第j个条目(此操作的输出)。

所以我想要执行 e_1 ,d维向量, W ,dxdxk张量和 e_2 的产品,另一个d维向量。这个产品能否像现在一样在TensorFlow中简明扼要地表达,还是我必须以某种方式定义自己的操作?

早期编辑

为什么这些张量的乘法不起作用,是否有某种方法可以更明确地定义它以使其有效?

>>> import tensorflow as tf
>>> tf.InteractiveSession()
>>> a = tf.ones([3, 3, 3])
>>> a.eval()
array([[[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],

       [[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],

       [[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]]], dtype=float32)
>>> b = tf.ones([3, 1, 1])
>>> b.eval()
array([[[ 1.]],

       [[ 1.]],

       [[ 1.]]], dtype=float32)
>>> 

错误消息是

ValueError: Shapes TensorShape([Dimension(3), Dimension(3), Dimension(3)]) and TensorShape([Dimension(None), Dimension(None)]) must have the same rank

CURRENTLY

事实证明,将两个3D张量相乘不会对tf.matmul起作用,所以tf.batch_matmul会这样做。 tf.batch_matmul也将执行3D张量和矩阵。然后我尝试了3D和矢量:

ValueError: Dimensions Dimension(3) and Dimension(1) are not compatible

2 个答案:

答案 0 :(得分:12)

您可以通过简单的重塑来完成此操作。对于两个矩阵乘法中的第一个,你有k * d,长度d向量与dot product。

这应该是关闭的:

temp = tf.matmul(E1,tf.reshape(Wddk,[d,d*k]))
result = tf.matmul(E2,tf.reshape(temp,[d,k]))

答案 1 :(得分:0)

你可以在W和e2之间执行秩3张量和矢量乘法,产生一个二维数组,然后将结果乘以e1。以下函数利用张量积和张量收缩来定义该乘积(例如W * e3)

import sympy as sp

def tensor3_vector_product(T, v):
    """Implements a product of a rank 3 tensor (3D array) with a 
       vector using tensor product and tensor contraction.

    Parameters
    ----------
    T: sp.Array of dimensions n x m x k

    v: sp.Array of dimensions k x 1

    Returns
    -------
    A: sp.Array of dimensions n x m

    """
    assert(T.rank() == 3)
    # reshape v to ensure a 1D vector so that contraction do 
    # not contain x 1 dimension
    v.reshape(v.shape[0], )
    p = sp.tensorproduct(T, v)
    return sp.tensorcontraction(p, (2, 3))

您可以使用ref中提供的示例验证此乘法。上面的函数收缩了第二和第三轴,在你的情况下我认为你应该收缩(1,2)因为W被定义为d x d x k而不是k x d x d,就像我的情况一样。