限制火炬中的多标签分类

时间:2019-12-01 14:33:44

标签: neural-network deep-learning constraints pytorch

具有一个具有N个S型输出的模型(LSTM),我需要限制此输出以具有偶数个正标号。

允许的预测示例:

y_hat = [1,0,1,1,0,0,1] # positive = 4 % 2 = 0
y_hat = [1,1,0,0,0,0,0] # positive = 2 % 2 = 0
y_hat = [1,0,1,1,1,1,1] # positive = 6 % 2 = 0

不允许的预测示例:

y_hat = [1,1,1,1,0,0,1] # positive = 5 % 2 = 1 (not even)
y_hat = [1,1,0,0,1,0,0] # positive = 3 % 2 = 1 (not even)
y_hat = [1,1,1,1,1,1,1] # positive = 7 % 2 = 1 (not even)

我尝试实现custum损失功能,但没有成功。

非常感谢您

0 个答案:

没有答案