使用稀疏矩阵创建张量流数据集

时间:2018-08-15 21:10:04

标签: python tensorflow tensorflow-datasets

我的数据集是一个稀疏矩阵,我正在尝试使用tf.dataset对其进行馈送。我的以下代码适用于小型矩阵,但适用于大型矩阵(大量列),但无法运行ValueError: GraphDef cannot be larger than 2GB.,我知道这是因为tf.SparseTensor创建了许多tf.constant操作。将稀疏矩阵输入tf.dataset而不将其转换为SparseTensor的正确方法是什么?

data_tensor = tf.SparseTensor(
    indices=[[0,0], [0,4], [1,1], [1, 50]],
    values=np.ones(4, dtype=np.float32),
    dense_shape=(2, 100)
)
dataset = tf.data.Dataset.from_tensor_slices(data_tensor)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
input_ph = tf.sparse_tensor_to_dense(iterator.get_next())
sum_op = tf.reduce_sum(input_ph)
sess.run(sum_op)

0 个答案:

没有答案