如何在Tensorflow中没有显式当前会话的情况下评估张量?

时间:2019-07-25 12:25:16

标签: python tensorflow text-classification multilabel-classification

我正在实现以下示例:https://towardsdatascience.com/building-a-multi-label-text-classifier-using-bert-and-tensorflow-f188e0ecdc5d,该示例使用Tensorflow来微调BERT模型。

我想评估一些张量。现在,我知道评估张量的正确方法是调用tensor.eval(),并在必要时传递当前会话:tensor.eval(session=sess)

现在,在代码中,特别是在下面的代码中,我想评估一些张量。

elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, probabilities, is_real_example):

                logits_split = tf.split(probabilities, num_labels, axis=-1)
                label_ids_split = tf.split(label_ids, num_labels, axis=-1)
                # metrics change to auc of every class
                eval_dict = {}
                for j, logits in enumerate(logits_split):
                    label_id_ = tf.cast(label_ids_split[j], dtype=tf.int32)
                    current_auc, update_op_auc = tf.metrics.auc(label_id_, logits)
                    eval_dict[str(j)] = (current_auc, update_op_auc)
                eval_dict['eval_loss'] = tf.metrics.mean(values=per_example_loss)
                return eval_dict

            eval_metrics = metric_fn(per_example_loss, label_ids, probabilities, is_real_example)

我希望评估label_idsprobabilities,以便将它们传递给其他函数。这些是张量(?,6)。

但是,我对Tensorflow不够熟悉,无法找到我当前的会话在哪里。我以为它是与恢复模型一起提供的,但是因为它使用了BERT代码,所以我有些茫然。如果尝试致电label_ids.eval(),我会收到以下错误消息:

ValueError: Cannot evaluate tensor using `eval()`: No default session is registered. Use `with sess.as_default()` or pass an explicit session to `eval(session=sess)`

如果我使用sess = tf.get_default_session(),则会出现以下错误:

AttributeError: 'function' object has no attribute 'graph'

我如何评估这些张量?

0 个答案:

没有答案