使用Conv2d批量处理Tensorflow开关盒

时间:2020-02-10 20:00:20

标签: python tensorflow keras tf.keras

假设我有一个形状为[n,h,w,c]的批处理,以及要在0-9和10 Conv2D convs范围内的n个索引列表,我想将它们应用于数据取决于列表中的索引。索引列表随批次而变化。

例如。输入x,批处理大小为4,索引为l = [1,5,1,9],我想计算[convs[l[0]](x[0]), convs[l[1]](x[1]), convs[l[2]](x[2]), convs[l[3]](x[3])]

幼稚的解决方案是根据l计算每个组合并收集。但是,这需要10倍的内存量。有没有更好的解决方案来解决这个问题?

1 个答案:

答案 0 :(得分:0)

一种“ hacky”解决方案是将输入的维度从[n, h, w, c]扩展到[1, n, h, w, c],然后使用Conv3D而不是内核形状为[1, x, y]

如果您分别定义了权重(也可以使用layer.weights获得权重),则可以类似地将其堆叠在第0维上,并通过tf.nn.conv3d使用它们。