我已经使用Keras Functional API构建了u-net架构,但是在使用稀疏分类交叉熵损失函数时遇到了麻烦。我的学习任务是对许多256x256图像进行多类,逐像素分类。预期的输出是256x256的蒙版图像,其整数值为0-31(并非每个蒙版都包含每个类)。我有32个类,所以一键式编码给我一个OOM错误,这就是为什么我不使用分类交叉熵的原因。多数遮罩像素为0(这可能是问题的一部分)。
我一直在亏损= nan。我已经将输入数据归一化为平均值= 0,标准差=1。如果保留原样的掩码,则精度约为0.97,输出掩码均为1(这显然是错误的)。如果我在进行训练之前在所有蒙版上加1,则精度为0。我在最后一个卷积层中使用带有SoftMax的relu激活。
似乎问题可能与我的输出数据的格式有关,所以我的主要问题是,稀疏分类交叉熵应采用哪种格式?我应该将遮罩值标准化为0-1吗?另外,还有其他可用于训练的损失函数或准确性指标吗?就多类分类而言,我所知道的唯一功能是分类交叉熵。如果需要,我可以提供有关我的数据,网络等的其他信息。