我有5名输出的类数据。训练数据没有样本的以下这些5类:[706326,32211,2856,3050,901]
我使用以下keras(tf.keras)代码:
class_weights = class_weight.compute_class_weight('balanced',
np.unique(y_train),
y_train)
model = tf.keras.Sequential([
tf.keras.layers.Dense(50, input_shape=(dataX.shape[1],)),
tf.keras.layers.Dropout(rate = 0.5),
tf.keras.layers.Dense(50, activation=tf.nn.relu),
tf.keras.layers.Dropout(rate = 0.5),
tf.keras.layers.Dense(50, activation=tf.nn.relu),
tf.keras.layers.Dropout(rate = 0.5),
tf.keras.layers.Dense(50, activation=tf.nn.relu),
tf.keras.layers.Dropout(rate = 0.5),
tf.keras.layers.Dense(5, activation=tf.nn.softmax) ])
adam = tf.keras.optimizers.Adam(lr=0.5)
model.compile(optimizer=adam,
loss='sparse_categorical_crossentropy',
metrics=[metrics.sparse_categorical_accuracy])
model.fit(X_train,y_train, epochs=5, batch_size=32, class_weight=class_weights)
y_pred = np.argmax(model.predict(X_test), axis=1)
我正在使用sparse_categorical_crossentropy,它接受类别为整数(不需要将它们转换为一键编码),但是我也尝试了categorical_crossentropy,仍然存在相同的问题。
我当然尝试了不同的学习率,的batch_size,无历元,优化器,和深度/网络的长度。但它始终是停留在0.94〜精度基本上是,如果我预测一流所有的时间我会得到。
不确定此处缺少什么。我的一部分的任何错误,或Keras与class_weight一些bug?还是应该使用其他一些专门的深度网络?