数据不平衡:class_weight平衡的前馈深层网络不会超出主导类的学习范围

时间:2019-02-02 05:45:35

标签: tensorflow keras deep-learning multiclass-classification

我有5名输出的类数据。训练数据没有样本的以下这些5类:[706326,32211,2856,3050,901]

我使用以下keras(tf.keras)代码:

class_weights = class_weight.compute_class_weight('balanced',
                                                 np.unique(y_train),
                                                 y_train)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(50, input_shape=(dataX.shape[1],)),
    tf.keras.layers.Dropout(rate = 0.5),
    tf.keras.layers.Dense(50, activation=tf.nn.relu),
    tf.keras.layers.Dropout(rate = 0.5),
    tf.keras.layers.Dense(50, activation=tf.nn.relu),
    tf.keras.layers.Dropout(rate = 0.5),
    tf.keras.layers.Dense(50, activation=tf.nn.relu),
    tf.keras.layers.Dropout(rate = 0.5),
    tf.keras.layers.Dense(5, activation=tf.nn.softmax) ])
     adam = tf.keras.optimizers.Adam(lr=0.5)

model.compile(optimizer=adam, 
              loss='sparse_categorical_crossentropy',
              metrics=[metrics.sparse_categorical_accuracy])    
     model.fit(X_train,y_train, epochs=5, batch_size=32, class_weight=class_weights)

y_pred = np.argmax(model.predict(X_test), axis=1)

我正在使用sparse_categorical_crossentropy,它接受类别为整数(不需要将它们转换为一键编码),但是我也尝试了categorical_crossentropy,仍然存在相同的问题。

我当然尝试了不同的学习率,的batch_size,无历元,优化器,和深度/网络的长度。但它始终是停留在0.94〜精度基本上是,如果我预测一流所有的时间我会得到。

不确定此处缺少什么。我的一部分的任何错误,或Keras与class_weight一些bug?还是应该使用其他一些专门的深度网络?

0 个答案:

没有答案