在TF对象检测中定义自定义丢失

时间:2020-04-11 07:11:13

标签: tensorflow tensorflow-model-garden

我正在尝试在tensorflow对象检测API中实现自己的损失函数。我已按照以下步骤操作。

1)将损失添加到core / losses.py(只是共享骨架)

"de-DE"

2。将此定义添加到protos / losses.proto

class ClassBalancedSigmoidFocalClassificationLoss(Loss):
  def __init__(self, gamma=2.0, alpha=0.25):
    super(ClassBalancedSigmoidFocalClassificationLoss, self).__init__()
    self._alpha = alpha
    self._gamma = gamma

  def _compute_loss(self,
                    prediction_tensor,
                    target_tensor,
                    weights,
                    class_indices=None):
    return

3。通过以下方式更新protos配置:

message ClassificationLoss {
  oneof classification_loss {
    WeightedSigmoidClassificationLoss weighted_sigmoid = 1;
    WeightedSoftmaxClassificationLoss weighted_softmax = 2;
    WeightedSoftmaxClassificationAgainstLogitsLoss weighted_logits_softmax = 5;
    BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3;
    SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4;
    ClassBalancedSigmoidFocalClassificationLoss my_loss = 6;
  }
}

但是,我在第3步遇到以下错误

./bin/protoc object_detection/protos/*.proto --python_out=.

我正在使用TensorFlow 1.15。请帮忙。

0 个答案:

没有答案