来自张量数据的类别混淆矩阵计算

时间:2021-02-06 19:28:50

标签: tensorflow machine-learning image-processing deep-learning tensor

我正在尝试从中计算类别混淆矩阵和以后的精度。以下是我的功能:

def calc_confusion(flat_labels, flat_logits, n_classes, loss_mask=None, name = 'calc_confusion'):
    with tf.compat.v1.variable_scope(name):
        metric = tfa.metrics.MultiLabelConfusionMatrix(num_classes=n_classes)
        metric.update_state(flat_labels, flat_logits)
        conf_matrix = metric.result()
        pr,re,sp,ac,fs, dice={},{},{},{},{},{}
        tns = conf_matrix[:,0,0]
        fps = conf_matrix[:,0,1]
        fns = conf_matrix[:,1,0]
        tps = conf_matrix[:,1,1]
        for c in range(tps.shape[0]):  #here tps.shape[0] represent no of classes
            pr[c]=tf.divide(tps[c],tf.add(tps[c],fps[c]))
        return {'precision': pr} 

这里 flat_labels、flat_logits 是张量输入。

主要问题是::我可以使用列表变量返回所有类精度值并在 tf.session 中运行该变量吗?

以下是错误:

(0) 失败的前提条件:从容器读取资源变量 prob_unet/calc_confusion/false_positives 时出错:localhost。这可能意味着该变量未初始化。未找到:资源 localhost/prob_unet/calc_confusion/false_positives/N10tensorflow3VarE 不存在。 [[节点 prob_unet/calc_confusion/packed/ReadVariableOp_1(定义于 /lib/python3.6/site-packages/tensorflow_addons/metrics/multilabel_confusion_matrix.py:158)]] [[prob_unet/prob_unet/conv_decoder/conv2_d_3/BiasAdd/_3349]]

有关更多详细信息:我通过“probabilistic_unet.py”文件在“train.py”中调用上述函数。此处提供的文件链接:https://github.com/sandeepsinghsengar/confusion

张量变量初始化(会话/图形分配)的某个地方存在问题。

让我知道更多信息。只关注混淆矩阵和精度(忽略其他)。

0 个答案:

没有答案