我正在尝试在tensorflow对象检测API中实现自己的损失函数。我已按照以下步骤操作。
1)将损失添加到core / losses.py(只是共享骨架)
"de-DE"
2。将此定义添加到protos / losses.proto
class ClassBalancedSigmoidFocalClassificationLoss(Loss):
def __init__(self, gamma=2.0, alpha=0.25):
super(ClassBalancedSigmoidFocalClassificationLoss, self).__init__()
self._alpha = alpha
self._gamma = gamma
def _compute_loss(self,
prediction_tensor,
target_tensor,
weights,
class_indices=None):
return
3。通过以下方式更新protos配置:
message ClassificationLoss {
oneof classification_loss {
WeightedSigmoidClassificationLoss weighted_sigmoid = 1;
WeightedSoftmaxClassificationLoss weighted_softmax = 2;
WeightedSoftmaxClassificationAgainstLogitsLoss weighted_logits_softmax = 5;
BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3;
SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4;
ClassBalancedSigmoidFocalClassificationLoss my_loss = 6;
}
}
但是,我在第3步遇到以下错误
./bin/protoc object_detection/protos/*.proto --python_out=.
我正在使用TensorFlow 1.15。请帮忙。