tensoflow稀疏张量高效配料

时间:2018-05-07 20:46:40

标签: tensorflow

我有一个方法(如下所示)从tensorflow SparseTensorValue获取批处理。然而,这种方法相当慢(对于一批32号的批次来说是10-20秒),这是有问题的,因为它被称为数千次。

def get_batch(index, tensors, batch_size, nItems):
    xs, ys = tensors
    begin = (index * batch_size)
    end = min((index+1)*batch_size, nItems)
    y_b = ys[begin:end]

    (inds, vals, dsize) = xs
    nInds = [[ind[0] - begin, ind[1]] for ind in inds if begin <= ind[0] < end]
    nInds = np.array(nInds)
    nVals = vals[:nInds.shape[0]]
    nDsize = (end - begin, dsize[1])
    x_b = tf.SparseTensorValue(nInds, nVals, nDsize)
    return (x_b, y_b)

有没有办法让这种方法更有效率?

1 个答案:

答案 0 :(得分:0)

我建议您使用tf.data编写输入管道,然后如果有的话,可以将此重新分配到另一个核心,而不是阻止主线程。