SparseTensor * Vector

时间:2017-08-29 12:29:54

标签: tensorflow sparse-matrix

当A是tf.SparseTensor而b是tf.Variable时,如何在tensorflow中实现以下内容?

A = np.arange(5**2).reshape((5,5))
b = np.array([1.0, 2.0, 0.0, 0.0, 1.0])
C = A * b 

如果我尝试相同的表示法,我会得到InvalidArgumentError:提供的索引超出了w.r.t.密集的一面,有广播的形状。

1 个答案:

答案 0 :(得分:1)

*也适用于 SparseTensor ,您的问题似乎与 SparseTensor 本身有关,您可能提供的指标不在你给它的形状范围,考虑这个例子:

A_t = tf.SparseTensor(indices=[[0,6],[4,4]], values=[3.2,5.1], dense_shape=(5,5))

请注意,列索引6大于指定的形状,该列应具有最大5列,这会产生与您所显示的相同的错误:

b = np.array([1.0, 2.0, 0.0, 0.0, 1.0])

B_t = tf.Variable(b, dtype=tf.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(A_t * B_t))
  

InvalidArgumentError(参见上面的回溯):提供的索引是   越界w.r.t.密集的一面,有广播的形状

这是一个有效的例子:

A_t = tf.SparseTensor(indices=[[0,3],[4,4]], values=[3.2,5.1], dense_shape=(5,5)) 

b = np.array([1.0, 2.0, 0.0, 0.0, 1.0])
B_t = tf.Variable(b, dtype=tf.float32)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(A_t * B_t))
# SparseTensorValue(indices=array([[0, 3],
#        [4, 4]], dtype=int64), values=array([ 0.       ,  5.0999999], dtype=float32), dense_shape=array([5, 5], dtype=int64))