为稀疏 - 密集乘法提供张量流稀疏矩阵。获得以下错误:TypeError:输入必须是SparseTensor

时间:2016-12-12 16:57:04

标签: tensorflow

我正在尝试使用tensorflow tf.sparse_tensor_dense_matmul(X, W1)。 X定义为tf.placeholder:

X = tf.placeholder("float", [None, size])

W是tf.Variable

在进入dict时,我正在传递张量流稀疏矩阵。但是我收到了错误:

  

TypeError:输入必须是SparseTensor。

如何让 sparse_tensor_dense_matmul 模块知道我将以稀疏张量传递?

1 个答案:

答案 0 :(得分:2)

要通过占位符传递SparseTensor,您可以使用sparse_placeholder

sparse_place = tf.sparse_placeholder(tf.float64)
mul_result = tf.sparse_tensor_dense_matmul(sparse_place, some_dense_tensor)

您可以按如下方式使用它:

dense_version = tf.sparse_tensor_to_dense(sparse_place)
sess.run(dense_version, feed_dict={sparse_place: tf.SparseTensorValue([[0,1], [2,2]], [1.5, 2.9], [3, 3])})
Out: array([[ 0. ,  1.5,  0. ],
            [ 0. ,  0. ,  0. ],
            [ 0. ,  0. ,  2.9]])

或者,您可以为值,形状和索引创建三个单独的占位符,例如:

indices = tf.placeholder(tf.int64)
shape = tf.placeholder(tf.int64)
values = tf.placeholder(tf.float64)
sparse_tensor = tf.SparseTensor(indices, shape, values)
sess.run(sparse_tensor, feed_dict={shape: [3, 3], indices: [[0,1], [2,2]], values: [1.5, 2.9]})