Understanding Cross Entropy Loss

Time: 2018-03-25 19:59:43

Tags: python machine-learning neural-network loss-function

I see a lot of explanations of cross entropy loss (CEL) or binary cross entropy (BCE) loss in the context where the ground truth is, say, 0 or 1, and then you get a function like:

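from math import log

def CrossEntropy(yHat, y):
    # Here yHat is the ground-truth label (0 or 1) and y is the predicted probability.
    if yHat == 1:
        return -log(y)
    else:
        return -log(1 - y)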

However, I'm confused about how BCE works when yHat is not a discrete 0 or 1. For example, if I want to compute the reconstruction loss for an MNIST digit where my ground truths satisfy 0 < yHat < 1 and my predictions lie in the same range, how does this change my function?

EDIT:

Apologies, let me give some more context for my confusion. In the PyTorch tutorials on VAEs they use BCE to calculate the reconstruction loss, where yHat (as far as I understand) is not discrete. See:

https://github.com/pytorch/examples/blob/master/vae/main.py

The implementation works, but I don't understand how the BCE loss is calculated in this case.

3 answers:

Answer 0 (score: 3)

Cross entropy measures how far apart two probability distributions are (it is not a true distance metric, since it is not symmetric). In what you describe (the VAE), MNIST image pixels are interpreted as probabilities of pixels being "on" or "off". In that case your target probability distribution is simply not a Dirac distribution (0 or 1) but can take other values. See the cross entropy definition on Wikipedia.
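For reference, here is the standard definition, followed by its special case for a single binary pixel. The symbols t (target probability that the pixel is on) and s (predicted probability that it is on) are my own notation, not the answer's:

```latex
H(p, q) = -\sum_{x} p(x)\,\log q(x)

% One binary pixel: p(\text{on}) = t,\ p(\text{off}) = 1 - t;\quad q(\text{on}) = s,\ q(\text{off}) = 1 - s
H(t, s) = -\bigl[\, t \log s + (1 - t)\log(1 - s) \,\bigr]
```

Nothing in this formula requires t to be exactly 0 or 1. With t = 1 or t = 0 one term vanishes and you recover the question's two-branch function.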

With the above as a reference, let's say your model outputs a reconstruction of 0.7 for a certain pixel. This essentially says that your model estimates p(pixel=1) = 0.7, and accordingly p(pixel=0) = 0.3.
If the target pixels were just 0 or 1, your cross entropy for this pixel would be either -log(0.3) if the true pixel is 0, or -log(0.7) (a smaller value) if the true pixel is 1.
The full formula would be -(0*log(0.3) + 1*log(0.7)) if the true pixel is 1, or -(1*log(0.3) + 0*log(0.7)) otherwise.
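A small sketch of the full formula for one pixel, checking the two hard-target cases for a prediction of 0.7. `binary_cross_entropy` is a hypothetical helper written for illustration. It is not PyTorch's function of the same name.

```python
import math

def binary_cross_entropy(target: float, pred: float) -> float:
    """Cross entropy between Bernoulli(target) and Bernoulli(pred).

    target: ground-truth probability that the pixel is "on".
    pred:   model's estimated probability that the pixel is "on".
    """
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

pred = 0.7  # model says p(pixel=1) = 0.7, p(pixel=0) = 0.3
print(round(binary_cross_entropy(1.0, pred), 4))  # true pixel is 1: -log(0.7) ≈ 0.3567
print(round(binary_cross_entropy(0.0, pred), 4))  # true pixel is 0: -log(0.3) ≈ 1.204
```

Because one of the two weights is always 0 for a hard target, a single expression covers both branches of the `if` in the question.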

Let's say your target pixel is actually 0.6! This essentially says that the pixel has a probability of 0.6 of being on and 0.4 of being off.
This simply changes the cross entropy computation to -(0.4*log(0.3) + 0.6*log(0.7)).
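The same expression handles the soft target of 0.6 directly. The sketch below (again using a hypothetical `binary_cross_entropy` helper) also shows something that often confuses people with soft targets: the loss never reaches 0. Its minimum is at pred == target, where it equals the entropy of the target distribution. This is why a VAE's BCE reconstruction loss stays above zero even for a perfect reconstruction of a grey-scale image.

```python
import math

def binary_cross_entropy(target, pred):
    # Hypothetical helper: -(t*log(s) + (1-t)*log(1-s)) for one pixel.
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

# Soft target: the pixel is "on" with probability 0.6; the model predicts 0.7.
loss = binary_cross_entropy(0.6, 0.7)  # = -(0.4*log(0.3) + 0.6*log(0.7))
print(round(loss, 4))  # ≈ 0.6956

# The best possible prediction is pred == target; the loss there is the
# entropy of Bernoulli(0.6), not 0.
best = binary_cross_entropy(0.6, 0.6)
print(round(best, 4))  # ≈ 0.673
```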

Finally, you can simply average/sum these per-pixel cross-entropies over the image.
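A minimal NumPy sketch of reducing the per-pixel values over a whole image. `image_bce` is a hypothetical helper. The random 28×28 arrays only stand in for a grey-scale MNIST image and a decoder's sigmoid output. The `eps` clipping that keeps `log` finite is my own assumption; frameworks guard against log(0) in their own ways (PyTorch's BCE documentation, for instance, describes clamping log outputs at -100).

```python
import numpy as np

def image_bce(target, pred, reduction="sum", eps=1e-7):
    """Per-pixel binary cross entropy, reduced over the image.

    target, pred: arrays of the same shape with values in [0, 1].
    eps: clips pred away from 0 and 1 so np.log stays finite (an assumption).
    """
    pred = np.clip(pred, eps, 1 - eps)
    per_pixel = -(target * np.log(pred) + (1 - target) * np.log(1 - pred))
    if reduction == "sum":
        return per_pixel.sum()
    if reduction == "mean":
        return per_pixel.mean()
    raise ValueError(f"unknown reduction: {reduction!r}")

rng = np.random.default_rng(0)
target = rng.random((28, 28))  # stand-in for a grey-scale image in [0, 1]
pred = rng.random((28, 28))    # stand-in for the model's reconstruction
total = image_bce(target, pred, "sum")
average = image_bce(target, pred, "mean")
print(total, average)
```

Summing and averaging differ only by a constant factor (the number of pixels, 784 here). The choice mainly changes how strongly the reconstruction term is weighted against any other loss terms, such as a VAE's KL term.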

Answer 1 (score: 0)

The cross-entropy loss is only used in classification problems, i.e., where your target (yHat) is discrete. If you have a regression problem instead, something like the mean squared error (MSE) loss would be more appropriate. You can find a variety of losses for the PyTorch library, and their implementations, here.

In the case of the MNIST dataset, you actually have a multiclass classification problem (you're trying to predict the correct digit out of 10 possible digits), so the binary cross-entropy loss isn't suitable, and you should use the general cross-entropy loss instead.
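As a sketch of the general (multiclass) cross entropy for one sample with a hard label: only the probability the model assigns to the true class contributes. `multiclass_cross_entropy` and the probability vector are hypothetical, chosen only for illustration.

```python
import math

def multiclass_cross_entropy(probs, label):
    """Cross entropy for one sample with a hard (integer) class label.

    probs: predicted probabilities for the 10 digits (assumed to sum to 1).
    label: index of the correct digit.
    """
    return -math.log(probs[label])

# Hypothetical prediction that puts most of its mass on digit 2.
probs = [0.01, 0.01, 0.90, 0.01, 0.01, 0.01, 0.01, 0.01, 0.02, 0.01]
print(multiclass_cross_entropy(probs, 2))  # true digit is 2: small loss, -log(0.9)
print(multiclass_cross_entropy(probs, 5))  # true digit is 5: large loss, -log(0.01)
```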

Regardless, the first step in your investigation should be identifying whether your problem is "classification" or "regression". A loss function suitable for one problem is generally not suitable for the other.

EDIT: You can find a more detailed explanation of the cross-entropy loss in the context of the MNIST problem at the "MNIST for ML Beginners" tutorial on the TensorFlow website.

Answer 2 (score: 0)

You generally shouldn't encode a set of classes that isn't binary as values between 0 and 1. In the case of MNIST, if you were to label the digits 0, 0.1, 0.2, etc., this would imply that an image of a 2 is more similar to an image of a 0 than to an image of a 5, which isn't necessarily true.

A better approach is to "one-hot encode" your labels instead: start with a 10-element array of 0s, then set the index corresponding to the digit in the image to 1.
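A plain-Python sketch of that encoding. `one_hot` is a hypothetical helper and the batch of labels is made up. With NumPy, `np.eye(10)[labels]` produces the same rows as a float array.

```python
def one_hot(digit, num_classes=10):
    """Return a list of num_classes zeros with a 1 at index `digit`."""
    if not 0 <= digit < num_classes:
        raise ValueError(f"digit must be in [0, {num_classes})")
    vec = [0] * num_classes
    vec[digit] = 1
    return vec

labels = [3, 0, 7]  # hypothetical batch of digit labels
encoded = [one_hot(d) for d in labels]
print(encoded[0])  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```

Every class is now equally far from every other, so the encoding no longer implies any ordering between digits.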

As mentioned above, you would then use the regular cross-entropy loss function. Your model should then output, for each sample, a vector of conditional probabilities with one entry per possible class, typically produced by a softmax function.
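A minimal sketch of the pipeline described above: raw model scores ("logits", hypothetical values here) go through a softmax to become class probabilities, and the cross entropy is taken against a one-hot target. `softmax` and `cross_entropy` are illustrative helpers, not a library API.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(one_hot_target, probs):
    # -sum_k y_k * log(p_k); with a one-hot target only the true class term is nonzero.
    return -sum(y * math.log(p) for y, p in zip(one_hot_target, probs) if y > 0)

logits = [1.0, 0.5, 3.0, 0.2, 0.1, 0.0, -0.3, 0.4, 1.2, 0.0]  # hypothetical model outputs
target = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]  # one-hot label for digit 2
probs = softmax(logits)
print(sum(probs))                    # ~1.0: a valid probability distribution
print(cross_entropy(target, probs))  # equals -log(probs[2])
```

In PyTorch, `torch.nn.CrossEntropyLoss` combines these two steps: it takes raw logits plus integer class indices and applies log-softmax internally, so you would not apply a softmax yourself before it.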