如何从mxnet nd数组访问类标签的索引?输出必须是使用功能输入的标签名称的索引

时间:2019-12-02 16:47:50

标签: python arrays function indexing mxnet

#pred_probas is probabilities of each class type
def find_class_idx(label):
    """
    Should return the class index of a particular label.

    :param label: label of class
    :type label: str

    :return: class index
    :rtype: int
    """
    #ind = mx.nd.argmax(label, axis=1).astype('int')
    topk_indices=mx.nd.topk(pred_probas,k=100)
    return max(topk_indices)*100

1 个答案:

答案 0 :(得分:0)

因为GluonCV的网络输出类是列表类型。我们可以使用此功能list.index(label)

访问列表的索引
def find_class_idx(label):
"""
Should return the class index of a particular label.

:param label: label of class
:type label: str

:return: class index
:rtype: int
"""

return network.classes.index(label)