我在tensorflow中有一个二进制分类任务。我已经使用tf.keras.Sequential类构建了3D-CNN,并将其命名为Net_1
。
我有一个不平衡训练数据集:0类有171个阴性样本,1类有84个阳性样本。
现在,给定epochs = 20
,我正在通过以下方式急于执行来训练模型:
#... code where I define loss_function, grad_function and optimizer...
for epoch in range(1, epochs): # make epochs start from 1
# define metrics
epoch_loss_avg = tf.keras.metrics.Mean(name='train_mean_loss')
epoch_bin_accuracy = tf.keras.metrics.BinaryAccuracy(name="train_binary_accuracy")
for _train_patches, train_labels in batched_train_dataset:
# Optimize the model
loss_value, grads = grad(Net_1, _train_patches, train_labels) # invoke function that computes the gradients
adam_optimizer.apply_gradients(zip(grads, Net_1.trainable_variables), global_step=global_step) # update model's parameters through the optimizer
net_prediction = Net_1(_train_patches) # compute model's output
# Track progress
epoch_loss_avg.update_state(loss_value) # accumulate loss metric over current batch
epoch_bin_accuracy.update_state(y_true=train_labels, y_pred=net_prediction) # accumulate accuracy metric over current batch
# end epoch: print loss and acc
print("Epoch {:03d}: Loss: {:.3f}, Bin_Accuracy: {:.3%}".format(epoch, epoch_loss_avg.result(), epoch_bin_accuracy.result()))
# reset training metrics
epoch_loss_avg.reset_states()
epoch_bin_accuracy.reset_states()
即使有时训练有效(损耗减少到零且准确性提高),但有时模型会陷入困境并始终输出多数类(在这种情况下为0),因此损失和准确性也会陷入困境。
如何在急切执行中解决?在诉诸少数类过采样之前,我正在寻找与class_weight
方法的fit
属性等效的内容。
谢谢!