想象一下以下设置:
我的数据加载器产生以下形状:(100,2,302,482,3)-目标是将通道轴上的两个输入图像堆叠到(100,302,482,6)。
没有批处理尺寸(因此x的形状为(2,302,482,3)),这非常容易:
# x.shape = (2, 302, 482, 3)
stacked = tf.concat(x, axis=-1)
# stacked.shape = (302, 482, 6)
但是我想在添加批处理维度时可以执行相同的操作。
答案 0 :(得分:0)
我认为,最好的方法是在网络输入之前(tf.concat
都将以相同的方式进行。