标签: python tensorflow deep-learning lstm sampling
我正在尝试处理Quora问题对数据集,以对它们是否重复进行分类。我正在使用BI-LSTM和三重损失作为损失函数。三重损失应该使重复的问题更加接近,而将非重复的问题推远。经过几次训练后,我计算出了l2距离,l2距离的分布如下:
对于重复的样品:
对于非重复问题:
从图片中可以明显看出,存在很多错误分类。我不确定如何继续优化网络。我使用的批次大小低至32。我认为这应该在训练时照顾异常样本。我还有其他批量采样策略可以用来优化它吗?