How to use evaluate_generator in keras?

Time: 2019-01-07 13:34:53

Tags: python-3.x keras deep-learning

I have trained a model and saved its weights to an HDF5 file. When I then try to evaluate the model's performance on test_dataset, something goes wrong with the metrics parameter of model.compile. If I set metrics as

model.compile(optimizer=adam,
              loss=losses.sparse_categorical_crossentropy,
              metrics=[cus_acc, miou])  

the following error occurs:

tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [`labels` out of bound] [Condition x < y did not hold element-wise:] [x (metrics/cus_acc/confusion_matrix/control_dependency:0) = ] [107 118 135...] [y (metrics/cus_acc/confusion_matrix/Cast_2:0) = ] [2]  [[Node: metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_INT64, DT_STRING, DT_INT64], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/Switch/_3827, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/data_0, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/data_1, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/data_2, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/Switch_1/_3829, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/data_4, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/Switch_2/_3831)]]

Note that the optimizer adam has already been defined, and cus_acc and miou are custom metrics. If I remove the metrics, evaluation works, so I think the problem lies in them; both miou and cus_acc are computed from a confusion_matrix.
My question is: what causes this error, and how do I use evaluate_generator to evaluate model performance in Keras? Example code would be best.
Any help would be appreciated. Thanks in advance. :D
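The assertion in the traceback is the label range check of the confusion-matrix op: every label must be smaller than the number of classes, and values such as 107, 118, 135 are not. Since the actual cus_acc and miou implementations are not shown, the following is only a NumPy sketch of confusion-matrix-based pixel accuracy and mean IoU, assuming rows are true classes and columns are predicted classes. The names confusion_matrix_np, pixel_accuracy and mean_iou are hypothetical.

```python
import numpy as np


def confusion_matrix_np(labels, predictions, num_classes):
    """Sketch of a confusion matrix with a range check similar to the one
    behind the "`labels` out of bound" assertion: labels must be in [0, num_classes)."""
    labels = np.asarray(labels, dtype=np.int64).ravel()
    predictions = np.asarray(predictions, dtype=np.int64).ravel()
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("labels out of bound: max label %d, num_classes %d"
                         % (labels.max(), num_classes))
    if predictions.min() < 0 or predictions.max() >= num_classes:
        raise ValueError("predictions out of bound")
    # Row = true class, column = predicted class; add.at accumulates repeated pairs.
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (labels, predictions), 1)
    return cm


def pixel_accuracy(cm):
    # Correctly classified pixels (the diagonal) divided by all pixels.
    return float(np.trace(cm) / cm.sum())


def mean_iou(cm):
    # Per-class IoU = TP / (TP + FP + FN); classes that never occur are skipped.
    tp = np.diag(cm).astype(np.float64)
    denom = cm.sum(axis=0) + cm.sum(axis=1) - tp
    valid = denom > 0
    return float(np.mean(tp[valid] / denom[valid]))
```

With labels 0, 127, 255 and num_classes=3, confusion_matrix_np raises immediately, which mirrors the assertion in the traceback.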

1 answer:

Answer 0 (score: 0)

Solved.
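The error occurs when the confusion_matrix is created: the class indices given by nb_classes, i.e. 0, 1, 2, must match the labels in the ground truth, i.e. 0, 1, 2. For example, if the ground truth has not been preprocessed and its pixel values are 0, 127 and 255, then InvalidArgumentError: assertion failed appears.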
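Assuming the masks use exactly the three grey levels from the example above (0, 127, 255, kept sorted), they can be mapped to class indices 0, 1, 2 before being fed to the model. RAW_VALUES and remap_mask are hypothetical names for this sketch. Any other value is rejected, because it usually means the mask is not what you expect, e.g. an image folder used by mistake, or a mask saved as JPEG or resized with a non-nearest interpolation, both of which introduce intermediate grey levels.

```python
import numpy as np

# Assumed raw grey levels of the un-preprocessed masks; must be sorted.
RAW_VALUES = np.array([0, 127, 255])


def remap_mask(mask, raw_values=RAW_VALUES):
    """Map raw mask grey levels to class indices 0..len(raw_values)-1."""
    mask = np.asarray(mask)
    unknown = ~np.isin(mask, raw_values)
    if unknown.any():
        raise ValueError("mask contains unexpected values: %s" % np.unique(mask[unknown]))
    # raw_values is sorted, so each value's position in it is its class index.
    return np.searchsorted(raw_values, mask).astype(np.int64)
```

In a data generator, remap_mask would be applied to every mask batch before the batch is yielded to the model.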
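What actually broke my code was not the custom metrics: when I created the data_generator for the ground truth, I set its directory parameter to the same folder as for the images, so the "labels" were raw image pixels (such as 107, 118, 135 in the traceback).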
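To answer the evaluate_generator part of the question: in Keras 2.x, Model.evaluate_generator accepts a Python generator that yields (inputs, targets) batches. The sketch below, assuming two separate batch iterators for images and masks, zips them and checks each mask batch so that a wrong mask directory fails loudly with a readable message instead of deep inside the metric graph. paired_generator is a hypothetical name.

```python
import numpy as np


def paired_generator(image_batches, mask_batches, num_classes):
    """Yield (images, masks) pairs from two batch iterators, checking the masks.

    image_batches / mask_batches: iterators of NumPy arrays, e.g. two Keras
    flow_from_directory iterators built from *different* directories.
    """
    for x, y in zip(image_batches, mask_batches):
        y = np.asarray(y)
        if y.min() < 0 or y.max() >= num_classes:
            # Catches "masks" that are really images (values like 107, 118, 135).
            raise ValueError("mask values %d..%d outside [0, %d)"
                             % (y.min(), y.max(), num_classes))
        yield x, y
```

For example, image_gen and mask_gen could be two ImageDataGenerator.flow_from_directory iterators, one pointed at the image folder and one at the mask folder, both with class_mode=None and the same seed so their shuffling stays aligned (color_mode="grayscale" for the masks), with mask values already converted to class indices. Then model.evaluate_generator(paired_generator(image_gen, mask_gen, 3), steps=steps) returns the loss followed by the compiled metrics, in the order given by model.metrics_names; steps is the number of batches to draw, e.g. ceil(num_test_images / batch_size).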