I have trained a model and saved its parameters as an HDF5 file. When I then try to evaluate the model's performance on test_dataset, something goes wrong with the metrics parameter of model.compile. If I set metrics as follows:
```python
model.compile(optimizer=adam,
              loss=losses.sparse_categorical_crossentropy,
              metrics=[cus_acc, miou])
```
it raises the following error:
```
tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [`labels` out of bound] [Condition x < y did not hold element-wise:] [x (metrics/cus_acc/confusion_matrix/control_dependency:0) = ] [107 118 135...] [y (metrics/cus_acc/confusion_matrix/Cast_2:0) = ] [2] [[Node: metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_INT64, DT_STRING, DT_INT64], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/Switch/_3827, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/data_0, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/data_1, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/data_2, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/Switch_1/_3829, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/data_4, metrics/cus_acc/confusion_matrix/assert_less/Assert/AssertGuard/Assert/Switch_2/_3831)]]
```
Note that the optimizer adam has already been defined, and cus_acc and miou are custom metrics. If I remove the metrics, evaluation works, so I suspect the problem lies in them: both miou and cus_acc are computed from a confusion_matrix.
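Since cus_acc and miou are not shown, here is a hypothetical NumPy sketch of what confusion-matrix-based pixel accuracy and mean IoU usually compute. The function names are illustrative, not the asker's code. It reproduces the bounds check visible in the traceback (the assert_less under confusion_matrix): every label must lie in [0, num_classes), otherwise the matrix cannot be built.

```python
import numpy as np

def confusion_matrix(labels, preds, num_classes):
    """Rows are true classes, columns are predicted classes."""
    labels = np.asarray(labels, dtype=np.int64).ravel()
    preds = np.asarray(preds, dtype=np.int64).ravel()
    # Same idea as TensorFlow's "`labels` out of bound" assertion:
    # every label (and prediction) must be in [0, num_classes).
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("labels out of bound: max label %d, num_classes=%d"
                         % (labels.max(), num_classes))
    if preds.min() < 0 or preds.max() >= num_classes:
        raise ValueError("predictions out of bound")
    # Encode each (true, pred) pair as a single index and count occurrences.
    idx = labels * num_classes + preds
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

def pixel_accuracy(cm):
    # Correct pixels lie on the diagonal.
    return float(np.trace(cm) / cm.sum())

def mean_iou(cm):
    tp = np.diag(cm).astype(np.float64)
    # union = predicted as class + truly class - intersection
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    valid = union > 0  # skip classes absent from both labels and predictions
    return float(np.mean(tp[valid] / union[valid]))
```

With labels such as 0/127/255 and num_classes=3, the first check fails, which corresponds to the assertion in the error above.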
My questions are: what causes this error, and how do I use evaluate_generator to evaluate model performance in Keras? Example code would be best.
Any help would be appreciated. Thanks in advance. :D
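As a rough illustration of what evaluate_generator consumes: in Keras 2 the generator should yield (inputs, targets) tuples and loop over the data indefinitely, with steps telling Keras how many batches to draw. The helper below, paired_batches, is a hypothetical in-memory sketch of such a generator, assuming the test images and masks are already loaded as arrays of equal length.

```python
import numpy as np

def paired_batches(images, masks, batch_size):
    """Yield (inputs, targets) batches forever, as Keras *_generator methods expect."""
    images = np.asarray(images)
    masks = np.asarray(masks)
    if len(images) != len(masks):
        raise ValueError("images and masks must have the same length")
    n = len(images)
    while True:  # Keras stops after `steps` batches, so loop indefinitely
        for start in range(0, n, batch_size):
            yield images[start:start + batch_size], masks[start:start + batch_size]
```

With a compiled model, it could be passed as model.evaluate_generator(paired_batches(x_test, y_test, 8), steps=int(np.ceil(len(x_test) / 8))); the returned values are ordered as in model.metrics_names.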
Answer (score: 0)
Solved.
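The error occurs when the confusion_matrix is created: the class indices given by nb_classes, i.e. 0, 1, 2, must match the labels in the ground truth, which must also be 0, 1, 2. For example, if the ground truth has not been preprocessed and its pixels are 0, 127 and 255, then InvalidArgumentError: assertion failed is raised.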
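A minimal sketch of the preprocessing step, assuming the raw masks use exactly the pixel values 0, 127 and 255 for the three classes; remap_mask is a hypothetical helper. It refuses any other value, since a mask saved lossily or resized with interpolation can contain intermediate values that would again trip the bounds check.

```python
import numpy as np

def remap_mask(mask, values=(0, 127, 255)):
    """Map raw pixel values to class indices 0..len(values)-1."""
    mask = np.asarray(mask)
    out = np.full(mask.shape, -1, dtype=np.int64)
    for class_idx, v in enumerate(values):
        out[mask == v] = class_idx
    # Any pixel still at -1 had a value not listed in `values`.
    if (out < 0).any():
        raise ValueError("unexpected pixel values in mask: %s"
                         % np.unique(mask[out < 0]))
    return out
```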
What actually broke my code was not the custom metrics: when I created the data_generator for the ground truth, I set its directory parameter to the same path as the images.
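With Keras' flow_from_directory, images and masks are commonly loaded by two generators created with the same seed (and class_mode=None) and then zipped, so each needs its own directory. The hypothetical check below guards against the mix-up described above, assuming each mask has the same file name as its image.

```python
import os

def paired_files(image_dir, mask_dir):
    """Pair image and mask files by name; reject using one directory for both."""
    if os.path.realpath(image_dir) == os.path.realpath(mask_dir):
        raise ValueError("image_dir and mask_dir are the same directory: %s" % image_dir)
    images = sorted(os.listdir(image_dir))
    masks = set(os.listdir(mask_dir))
    missing = [name for name in images if name not in masks]
    if missing:
        raise ValueError("no mask found for: %s" % missing)
    return [(os.path.join(image_dir, n), os.path.join(mask_dir, n)) for n in images]
```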