tf.stack无法堆叠sparseTensors

时间:2019-10-04 15:28:22

标签: python tensorflow tensorflow-datasets

无法堆叠2个SparseTensors。相同的代码可用于to_dense,但我认为数据集代码可直接用于SparseTensors?

我正在尝试取2个sparseTensors(每个代表一个样本)并将它们堆叠-一批2个。 看来我需要调用tf.sparse.to_dense-才能取回数据集。

以下代码可以正常工作:

import tensorflow as tf
tf.enable_eager_execution()

sparseTensor1 = tf.SparseTensor(indices=[[0, 0, 0], [1, 1, 0]], values=[9, 21], dense_shape=[2, 2, 2])
sparseTensor2 = tf.SparseTensor(indices=[[0, 0, 0], [1, 1, 0], [1,1,1]], values=[4, 33, 99], dense_shape=[2, 2, 2])

"""
This works: but is inefficient:
"""
denseTensor1 = tf.sparse.to_dense(sparseTensor1) 
denseTensor2 = tf.sparse.to_dense(sparseTensor2)

stackedDenseTensors = tf.stack([denseTensor1, denseTensor2])
damonsDatasetFromDense = tf.data.Dataset.from_tensors(stackedDenseTensors)


iterator1 = damonsDatasetFromDense.make_one_shot_iterator()
for next_batch in iterator1:
    tf.print(next_batch)

""" 
however if I try without the tf.sparse.to_dense code I cannot stack:
"""
sparseDataSet1 = tf.data.Dataset.from_tensors(sparseTensor1)
sparseDataSet2 = tf.data.Dataset.from_tensors(sparseTensor2)
stackedSparseTensors = tf.stack([sparseTensor1, sparseTensor2])

同样,我认为Dataset API可以直接处理稀疏数据。

相反,我从堆栈中获得了以下内容: ValueError:尝试将类型()不支持的值()转换为张量。“”“

0 个答案:

没有答案