在急切的执行训练中是否等同于“ class_weight”参数?

时间:2019-09-04 10:12:31

标签: python python-3.x tensorflow keras eager-execution

我在tensorflow中有一个二进制分类任务。我已经使用tf.keras.Sequential类构建了3D-CNN,并将其命名为Net_1

我有一个不平衡训练数据集:0类有171个阴性样本,1类有84个阳性样本。

现在,给定epochs = 20,我正在通过以下方式急于执行来训练模型:

#... code where I define loss_function, grad_function and optimizer...

for epoch in range(1, epochs):  # make epochs start from 1

    # define metrics
    epoch_loss_avg = tf.keras.metrics.Mean(name='train_mean_loss')
    epoch_bin_accuracy = tf.keras.metrics.BinaryAccuracy(name="train_binary_accuracy")

    for _train_patches, train_labels in batched_train_dataset:
        # Optimize the model
        loss_value, grads = grad(Net_1, _train_patches, train_labels)  # invoke function that computes the gradients
        adam_optimizer.apply_gradients(zip(grads, Net_1.trainable_variables), global_step=global_step)  # update model's parameters through the optimizer

        net_prediction = Net_1(_train_patches)  # compute model's output

        # Track progress
        epoch_loss_avg.update_state(loss_value)  # accumulate loss metric over current batch
        epoch_bin_accuracy.update_state(y_true=train_labels, y_pred=net_prediction)  # accumulate accuracy metric over current batch

    # end epoch: print loss and acc
    print("Epoch {:03d}: Loss: {:.3f}, Bin_Accuracy: {:.3%}".format(epoch, epoch_loss_avg.result(), epoch_bin_accuracy.result()))

    # reset training metrics
    epoch_loss_avg.reset_states()
    epoch_bin_accuracy.reset_states()

即使有时训练有效(损耗减少到零且准确性提高),但有时模型会陷入困境并始终输出多数类(在这种情况下为0),因此损失和准确性也会陷入困境。

如何在急切执行中解决?在诉诸少数类过采样之前,我正在寻找与class_weight方法的fit属性等效的内容。

谢谢!

0 个答案:

没有答案