计算二进制矩阵和2D坐标列表之间的损失

时间:2020-11-02 11:44:29

标签: pytorch

我有一个PyTorch模型,该模型会产生一个坐标列表,例如

>>> coords = model(inputs)
tensor([[1.1000, 3.8000, 0.2000, 4.4000, 2.5000],
        [3.4000, 3.2000, 1.3000, 0.9000, 1.7000]])

其中coords[0]是x坐标,coords[1]是y坐标,代表像素值。这些像素值形成一个二进制“图像”:

>>> indexes = coords.round().long()
>>> torch.sparse.FloatTensor(indexes, torch.ones(indexes.shape[1]), torch.Size([5, 5])).to_dense().T
tensor([[0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])

我想优化模型,使输出矩阵与二进制目标图像匹配。

但是,索引步骤不可区分。 .long()操作导致

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

向后走

是否可以进行“ 2D一键编码”或其他一些操作,以可区分的方式将浮点坐标列表转换为二进制矩阵?

0 个答案:

没有答案