使用Keras使用Poisson采样标签提高MLP性能(用于多类别分类)

时间:2019-01-11 15:41:54

标签: keras neural-network poisson multiclass-classification

我正在尝试使用完全连接的神经网络或多层感知器执行多类分类:我的训练数据(X)是长度相等的不同DNA字符串。这些序列中的每个序列都有一个与之关联的浮点值(例如t_X),我可以通过以下方式使用它们为数据模拟标签(y)。 y〜np.random.poisson(常数* t_X)

在训练了Keras模型后(请参见下文),我对预测的标签和测试标签进行了直方图处理,我面临的问题是我的模型似乎对许多序列进行了错误分类,请参见下面的链接图像。

Histogram link

我的训练数据如下:

X , Y  
CTATTACCTGCCCACGGTAAAGGCGTTCTGG,    1
TTTCTGCCCGCGGCCTGGCAATTGATACCGC,    6
TTTTTACACGCCTTGCGTAAAGCGGCACGGC,    4
TTGCTGCCTGGCCGATGGTCTATGCCGCTGC,    7

我一口气编码我的Y,我的X序列变成张量的张量:(批大小,序列长度,字符数),这些数字大约是10,000 x 50 x 4

我的keras模型如下:

model = Sequential() 
model.add(Flatten())
model.add(Dense(100, activation='relu',input_shape=(50,4)))
model.add(Dropout(0.25))
model.add(Dense(50, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(len(one_hot_encoded_labels), activation='softmax'))

我尝试了以下不同的损失函数

#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.00001), metrics=['accuracy'])
#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.0001), metrics=['mean_absolute_error',r_square])
#model.compile(loss='kullback_leibler_divergence',optimizer=Adam(lr=0.00001), metrics=['categorical_accuracy'])
#model.compile(loss=log_poisson_loss,optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
#model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
model.compile(loss='poisson',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])

损失表现合理;它下降并随着时期的增加而趋于平坦。我尝试了不同的学习率,不同的优化器,每层中不同数量的神经元,不同数量的隐藏层以及不同类型的正则化。

我认为我的模型总是将大多数预测的标签放在测试数据的峰值附近(请参阅链接的直方图),但是它无法对测试集中计数较少的序列进行分类。这是常见问题吗?

如果不使用其他架构(例如卷积或递归),是否有人知道我如何能够改善此模型的分类性能?

Training data file

1 个答案:

答案 0 :(得分:0)

从直方图分布中可以明显看出,您的测试数据集非常不平衡。我假设,您的训练数据分布相同。这可能是神经网络性能差的原因,因为对于许多类来说,神经网络没有太多的数据来学习这些功能。您可以尝试一些采样技术,以便可以比较每个类之间的关系。

这里是link,它解释了这种不平衡数据集的各种方法。

第二,您可以通过交叉验证来检查模型的性能,可以在其中轻松找到该误差是可减少的还是不可减少的。如果那是无法避免的错误,则您无法再进行任何改进(针对这种情况,您必须尝试另一种方法)。

第三,序列之间存在相互关系。简单的前馈网络无法捕获这种关系。 Recurrent-network可以捕获数据集中的此类依存关系。这是简单的example。此示例适用于二进制类,可以根据情况将其扩展为multi-class

对于loss-function选择,这完全是特定于问题的。您可以check this link解释哪些时间以及哪些损失函数会有所帮助。