如何使用自定义生成器生成混淆矩阵

时间:2019-11-24 10:53:13

标签: keras scikit-learn generator

我已经实现了自己的自定义数据生成器,并且正在使用model.predict_generator(generator = testing_generator, steps=steps,verbose=0)来生成预测。为了生成混淆图,我正在使用:

conf_mat = confusion_matrix(testing_generator.classes, y_predict,labels=classes)

我遇到错误

  

“发电机”对象没有属性“类”

这显然是因为我没有实现任何方法来从生成器返回所有GT。我不确定如何实现这种事情?

我的生成器的一般形式为:

    def generator_interference(self,number_of_steps_per_batch):   

        for i in range(number_of_steps_per_batch):
            """
            Some code to generates batches of samples
            :return: 
            """
            yield features, labels

如何检索GT以创建混淆矩阵?我可以简单地添加一个名为self.classes的变量并累积标签吗?如果是这样,如何在预测之间清除它?

0 个答案:

没有答案