在Lambda层中处理批处理的最佳方法是什么?

时间:2019-08-07 13:41:53

标签: python tensorflow keras

我用Keras创建了一个神经网络,并添加了Lambda层来执行一些计算,但是在推理方面表现不佳。

我能够使用一批一个输入成功地进行推断,并添加了一个循环来处理多个输入。一切正常,但性能有些差。我认为使用更大的批次将使事情变得更快。我的问题是我是否正确处理了批处理(是否真的需要使用另一个循环?),因为我还没有找到任何更深入的keras或tensorflow文档来处理该主题。 下面的代码结构类似于我在Lambda层中使用的结构。

def GenericFunc(x, batch=10, channels=64):
    y, group = [], []
    for i in range(batch):
        for j in range(channels):
            y.append(backend.sum(x[0, :, :, j]))
        group.append(tf.convert_to_tensor(y, dtype=np.float32))
        y = []
    yy = backend.stack(group, axis=0)
    tensor_stack = backend.reshape(yy, [batch,channels])
    return tensor_stack

任何建议都会受到欢迎!

1 个答案:

答案 0 :(得分:0)

请勿使用循环。张量用于张量操作。

def GenericFunc(x):
    y = backend.sum(x, axis=1)
    y = backend.sum(y, axis=1)
    return y

可能也与

一起使用
def GenericFunc(x):
    return backend.sum(x, axis=[1,2])