#pred_probas is probabilities of each class type
def find_class_idx(label):
"""
Should return the class index of a particular label.
:param label: label of class
:type label: str
:return: class index
:rtype: int
"""
#ind = mx.nd.argmax(label, axis=1).astype('int')
topk_indices=mx.nd.topk(pred_probas,k=100)
return max(topk_indices)*100
答案 0 :(得分:0)
因为GluonCV的网络输出类是列表类型。我们可以使用此功能list.index(label)
访问列表的索引def find_class_idx(label):
"""
Should return the class index of a particular label.
:param label: label of class
:type label: str
:return: class index
:rtype: int
"""
return network.classes.index(label)