我有一个方法(如下所示)从tensorflow SparseTensorValue获取批处理。然而,这种方法相当慢(对于一批32号的批次来说是10-20秒),这是有问题的,因为它被称为数千次。
def get_batch(index, tensors, batch_size, nItems):
xs, ys = tensors
begin = (index * batch_size)
end = min((index+1)*batch_size, nItems)
y_b = ys[begin:end]
(inds, vals, dsize) = xs
nInds = [[ind[0] - begin, ind[1]] for ind in inds if begin <= ind[0] < end]
nInds = np.array(nInds)
nVals = vals[:nInds.shape[0]]
nDsize = (end - begin, dsize[1])
x_b = tf.SparseTensorValue(nInds, nVals, nDsize)
return (x_b, y_b)
有没有办法让这种方法更有效率?
答案 0 :(得分:0)
我建议您使用tf.data
编写输入管道,然后如果有的话,可以将此重新分配到另一个核心,而不是阻止主线程。