Pytorch:如何在一批中处理不同的输入大小?

时间:2020-02-21 16:34:10

标签: deep-learning set pytorch

我正在点云上实现与DeepSets架构紧密相关的内容:

https://arxiv.org/abs/1703.06114

这意味着我正在使用一组输入(坐标),完全连接的层分别处理每个输入,然后对它们进行平均池化(然后进行进一步处理)。

每个样本 i 的输入是形状为[L_i, 3]的张量,其中L_i是点数,最后一个维度是3,因为每个点具有x,y,z坐标。至关重要的是,L_i取决于示例。因此,每个实例的点数不同。当我将所有内容放入批处理中时,当前所有i的输入形状均为[B, L, 3],其中L大于L_i。各个样本都填充有0。问题是网络不会忽略0,而是将其处理并送入平均池中。相反,我希望平均池仅考虑实际点(而不填充0)。我确实有另一个存储长度[L_1, L_2, L_3, L_4...]的数组,但是我不确定如何使用它。

我的问题是:您如何以最优雅的方式处理一批不同的输入大小?

这是模型的定义方式:

encoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 128))
x = self.encoder(x)
x = x.max(dim=1)[0]
decoder = ...

0 个答案:

没有答案