我尝试使用自定义python损失层。当我在网上查看几个例子时,例如:
Euclidean loss layer,Dice loss layer,
我注意到一个变量&self.diff'总是在'转发'中分配。特别是对于Dice损失层,
self.diff[...] = bottom[1].data
我想知道是否有任何理由必须在forward
中引入此变量,或者我可以使用bottom[1].data
来访问基本事实标签?
此外,top[0].reshape(1)
中的reshape
有什么意义,因为根据forward
的定义,损失输出本身就是标量。
答案 0 :(得分:1)
您需要设置图层的 diff 属性,以实现整体一致性和数据通信协议;它可以在类中的其他位置以及丢失层对象出现的任何位置。 底部是一个本地参数,在同一表单的其他地方不可用。
通常,代码可扩展用于各种应用程序和更复杂的计算;重塑是其中的一部分,确保返回的值是标量,即使有人扩展输入以使用向量或矩阵。