我正在点云上实现与DeepSets架构紧密相关的内容:
https://arxiv.org/abs/1703.06114
这意味着我正在使用一组输入(坐标),完全连接的层分别处理每个输入,然后对它们进行平均池化(然后进行进一步处理)。
每个样本 i 的输入是形状为[L_i, 3]
的张量,其中L_i
是点数,最后一个维度是3
,因为每个点具有x,y,z坐标。至关重要的是,L_i
取决于示例。因此,每个实例的点数不同。当我将所有内容放入批处理中时,当前所有i的输入形状均为[B, L, 3]
,其中L
大于L_i
。各个样本都填充有0。问题是网络不会忽略0,而是将其处理并送入平均池中。相反,我希望平均池仅考虑实际点(而不填充0)。我确实有另一个存储长度[L_1, L_2, L_3, L_4...]
的数组,但是我不确定如何使用它。
我的问题是:您如何以最优雅的方式处理一批不同的输入大小?
这是模型的定义方式:
encoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 128))
x = self.encoder(x)
x = x.max(dim=1)[0]
decoder = ...