假设我有一个形状为[n,h,w,c]的批处理,以及要在0-9和10 Conv2D convs
范围内的n个索引列表,我想将它们应用于数据取决于列表中的索引。索引列表随批次而变化。
例如。输入x,批处理大小为4,索引为l = [1,5,1,9],我想计算[convs[l[0]](x[0]), convs[l[1]](x[1]), convs[l[2]](x[2]), convs[l[3]](x[3])]
幼稚的解决方案是根据l
计算每个组合并收集。但是,这需要10倍的内存量。有没有更好的解决方案来解决这个问题?
答案 0 :(得分:0)
一种“ hacky”解决方案是将输入的维度从[n, h, w, c]
扩展到[1, n, h, w, c]
,然后使用Conv3D
而不是内核形状为[1, x, y]
如果您分别定义了权重(也可以使用layer.weights
获得权重),则可以类似地将其堆叠在第0维上,并通过tf.nn.conv3d
使用它们。