使用SparseTensor,使用tf.sparse_matmul进行矩阵乘法失败

时间:2017-07-14 03:22:05

标签: python tensorflow

为什么这不起作用:

pl_input = tf.sparse_placeholder('float32',shape=[None,30])
W = tf.Variable(tf.random_normal(shape=[30,1]), dtype='float32')
layer1a = tf.sparse_matmul(pl_input, weights, a_is_sparse=True, b_is_sparse=False)

错误消息是

TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor. Contents: SparseTensor(indices=Tensor("Placeholder_11:0", shape=(?, ?), dtype=int64), values=Tensor("Placeholder_10:0", shape=(?,), dtype=float32), dense_shape=Tensor("Placeholder_9:0", shape=(?,), dtype=int64)). Consider casting elements to a supported type.

我希望创建一个我从中检索批次的SparseTensorValue,然后将批处理提供给pl_input。

1 个答案:

答案 0 :(得分:3)

TL; DR

使用tf.sparse_tensor_dense_matmul代替tf.sparse_matmul;使用tf.nn.embedding_lookup_sparse查看documentation替代方案。

关于稀疏矩阵和SparseTensors

问题并非针对sparse_placeholder,而是由于张量流的术语混淆。

你有稀疏的矩阵。然后你有SparseTensor。两者都是相关但不同的概念。

  • SparseTensor是一个对其值进行索引的结构,可以有效地表示稀疏矩阵或张量。
  • 稀疏矩阵是主要填充0的矩阵。在tensorflow的文档中,通常引用SparseTensor,而是引用主要由Tensor填充的普通旧0

因此,重要的是要查看函数的预期类型,以便弄清楚。

例如,在the documentation of tf.matmul中,操作数必须是普通Tensor而不是SparseTensor s,与xxx_is_sparse标志的值无关,这解释了错误。当这些标志为True时,tf.sparse_matmul实际上期望的是(密集的)Tensor。换句话说,这些标志用于some optimization purposes而不是输入类型约束。 (顺便说一下,这些优化似乎只对rather larger matrices有用。