我有一个PyTorch模型,该模型会产生一个坐标列表,例如
>>> coords = model(inputs)
tensor([[1.1000, 3.8000, 0.2000, 4.4000, 2.5000],
[3.4000, 3.2000, 1.3000, 0.9000, 1.7000]])
其中coords[0]
是x坐标,coords[1]
是y坐标,代表像素值。这些像素值形成一个二进制“图像”:
>>> indexes = coords.round().long()
>>> torch.sparse.FloatTensor(indexes, torch.ones(indexes.shape[1]), torch.Size([5, 5])).to_dense().T
tensor([[0., 0., 0., 0., 0.],
[1., 0., 0., 0., 1.],
[0., 0., 1., 0., 0.],
[0., 1., 0., 0., 1.],
[0., 0., 0., 0., 0.]])
我想优化模型,使输出矩阵与二进制目标图像匹配。
但是,索引步骤不可区分。 .long()
操作导致
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
向后走
是否可以进行“ 2D一键编码”或其他一些操作,以可区分的方式将浮点坐标列表转换为二进制矩阵?