Tensorflow Precision, Recall, F1 - multi-label classification

Asked: 2017-07-24 18:02:42

Tags: python machine-learning tensorflow deep-learning

I am trying to implement a multi-label sentence classification model with TensorFlow. There are about 1500 labels. The model trains fine, but I am not sure about the metrics it generates.

This is the code snippet that generates the metrics:

    with tf.name_scope('loss'):
        losses = tf.nn.softmax_cross_entropy_with_logits(labels=self.input_y, logits=self.scores) #  only named arguments accepted
        self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

    with tf.name_scope('accuracy'):
        correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name='accuracy')

    with tf.name_scope('num_correct'):
        correct = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
        self.num_correct = tf.reduce_sum(tf.cast(correct, 'float'))

    with tf.name_scope('fp'):
        fp = tf.metrics.false_positives(labels=tf.argmax(self.input_y, 1), predictions=self.predictions)
        self.fp = tf.reduce_sum(tf.cast(fp, 'float'), name='fp')

    with tf.name_scope('fn'):
        fn = tf.metrics.false_negatives(labels=tf.argmax(self.input_y, 1), predictions=self.predictions)
        self.fn = tf.reduce_sum(tf.cast(fn, 'float'), name='fn')

    with tf.name_scope('recall'):
        self.recall = self.num_correct / (self.num_correct + self.fn)

    with tf.name_scope('precision'):
        self.precision = self.num_correct / (self.num_correct + self.fp)

    with tf.name_scope('F1'):
        self.F1 = (2 * self.precision * self.recall) / (self.precision + self.recall)

    with tf.name_scope('merged_summary'):
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("accuracy", self.accuracy)
        tf.summary.scalar("recall", self.recall)
        tf.summary.scalar("precision", self.precision)
        tf.summary.scalar("f-measure", self.F1)
        self.merged_summary = tf.summary.merge_all()
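For reference, this is roughly how I understand the streaming tf.metrics API is meant to be used (a minimal sketch with names of my own; each metric comes back as a (value, update_op) pair whose counters live in local variables):

    # Sketch only: how the streaming tf.metrics ops are normally wired up.
    # The names labels_idx, precision_val, etc. are mine, not part of the model above.
    labels_idx = tf.argmax(self.input_y, 1)
    precision_val, precision_update = tf.metrics.precision(labels=labels_idx, predictions=self.predictions)
    recall_val, recall_update = tf.metrics.recall(labels=labels_idx, predictions=self.predictions)
    # The internal counters are local variables, so the session needs
    # sess.run(tf.local_variables_initializer()) once, and the *_update ops
    # have to be run for every batch before the metric values are read.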

Then, in the training part, I create the summary writer for TensorBoard:

    summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())

Finally, training logs and saves the metrics as follows:

    for train_batch in train_batches:
        x_train_batch, y_train_batch = zip(*train_batch)
        train_step(x_train_batch, y_train_batch)
        current_step = tf.train.global_step(sess, global_step)

        # Evaluate the model with x_dev and y_dev
        if current_step % params['evaluate_every'] == 0:
            dev_batches = data_helper.batch_iter(list(zip(x_dev, y_dev)), params['batch_size'], 1)

            total_dev_correct = 0
            for dev_batch in dev_batches:
                x_dev_batch, y_dev_batch = zip(*dev_batch)
                acc, loss, num_dev_correct, predictions, recall, precision, f1, summary = dev_step(x_dev_batch, y_dev_batch)
                total_dev_correct += num_dev_correct
            accuracy = float(total_dev_correct) / len(y_dev)
            logging.info('Accuracy on dev set: {}'.format(accuracy))
            # added loss
            logging.info('Loss on dev set: {}'.format(loss))
            # adding more measures
            logging.info('Recall on dev set: {}'.format(recall))
            logging.info('Precision on dev set: {}'.format(precision))
            logging.info('F1 on dev set: {}'.format(f1))
            summary_writer.add_summary(summary, current_step)

            if accuracy >= best_accuracy:
                best_accuracy, best_loss, best_at_step, best_recall, best_precision, best_f1 = accuracy, loss, current_step, recall, precision, f1
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                logging.critical('Saved model {} at step {}'.format(path, best_at_step))
                logging.critical('Best accuracy {} at step {}'.format(best_accuracy, best_at_step))
                logging.critical('Best loss {} at step {}'.format(best_loss, best_at_step))
                logging.critical('Best recall {} at step {}'.format(best_recall, best_at_step))
                logging.critical('Best precision {} at step {}'.format(best_precision, best_at_step))
                logging.critical('Best F1 {} at step {}'.format(best_f1, best_at_step))
    logging.critical('Training is complete, testing the best model on x_test and y_test')

dev_step and train_step look like this:

    def train_step(x_batch, y_batch):
        feed_dict = {
            cnn_rnn.input_x: x_batch,
            cnn_rnn.input_y: y_batch,
            cnn_rnn.dropout_keep_prob: params['dropout_keep_prob'],
            cnn_rnn.batch_size: len(x_batch),
            cnn_rnn.pad: np.zeros([len(x_batch), 1, params['embedding_dim'], 1]),
            cnn_rnn.real_len: real_len(x_batch),
        }
        _, step, loss, accuracy = sess.run([train_op, global_step, cnn_rnn.loss, cnn_rnn.accuracy], feed_dict)

    def dev_step(x_batch, y_batch):
        feed_dict = {
            cnn_rnn.input_x: x_batch,
            cnn_rnn.input_y: y_batch,
            cnn_rnn.dropout_keep_prob: 1.0,
            cnn_rnn.batch_size: len(x_batch),
            cnn_rnn.pad: np.zeros([len(x_batch), 1, params['embedding_dim'], 1]),
            cnn_rnn.real_len: real_len(x_batch),
        }
        step, loss, accuracy, num_correct, predictions, recall, precision, f1, summary = sess.run(
            [global_step, cnn_rnn.loss, cnn_rnn.accuracy, cnn_rnn.num_correct, cnn_rnn.predictions,
             cnn_rnn.recall, cnn_rnn.precision, cnn_rnn.F1, cnn_rnn.merged_summary], feed_dict)
        return accuracy, loss, num_correct, predictions, recall, precision, f1, summary

My question is: are the metrics generated correctly for this multi-label classification problem, or should I compute them from a confusion matrix instead? If I should use a confusion matrix, should I add:

    tf.confusion_matrix(labels=, predictions=)

in the first part of the code, where I declare the metrics? If yes, what should I do next to get precision and recall?
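For instance, this is roughly what I have in mind for getting per-class precision, recall, and F1 out of a confusion matrix (a rough, untested sketch; cm, per_class_precision, and the other names are mine, and num_classes would be the total number of labels):

    # Hypothetical sketch: derive per-class metrics from a confusion matrix.
    cm = tf.confusion_matrix(labels=tf.argmax(self.input_y, 1),
                             predictions=self.predictions,
                             num_classes=num_classes)
    cm = tf.cast(cm, tf.float32)
    true_pos = tf.diag_part(cm)                     # diagonal: correctly predicted per class
    pred_per_class = tf.reduce_sum(cm, axis=0)      # column sums: predicted count per class
    actual_per_class = tf.reduce_sum(cm, axis=1)    # row sums: true count per class
    per_class_precision = true_pos / (pred_per_class + 1e-8)
    per_class_recall = true_pos / (actual_per_class + 1e-8)
    per_class_f1 = (2 * per_class_precision * per_class_recall) / (per_class_precision + per_class_recall + 1e-8)
    macro_f1 = tf.reduce_mean(per_class_f1)         # unweighted average over classes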

EDIT: I have added the following, but the resulting image in TensorBoard is just a black screen.

    batch_confusion = tf.confusion_matrix(labels=tf.argmax(self.input_y, 1), predictions=self.predictions,
                                          name='batch_confusion', num_classes=num_classes)
    confusion = tf.Variable(tf.zeros([num_classes, num_classes], dtype=tf.int32), name='confusion')
    confusion_image = tf.reshape(tf.cast(confusion, tf.float32), [1, num_classes, num_classes, 1])
    tf.summary.image('confusion', confusion_image)
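My guess is that the confusion variable is never actually filled from batch_confusion, so the image stays all zeros; maybe an update op like the one below is missing, but I am not sure:

    # Hypothetical: accumulate the per-batch confusion matrix into the variable,
    # and run confusion_update in sess.run() for every dev batch before the
    # image summary is evaluated.
    confusion_update = confusion.assign_add(batch_confusion)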

Thanks for your help,

0 Answers:

No answers yet.