使用get_preds
之后,我会得到张量列表。像这样:
[
[0.9, 0.1],
[0.85, 0.15],
[0.92, 0.08],
...
]
我的班级是[0,1]。我应该如何将这些张量转换为相应的(最可能的)类?请在下面查看我当前的方法
probs = learn.get_preds(ds_type=DatasetType.Test)[0]
def probs2class(item):
return max(range(len(item)), key=item.__getitem__)
print(map(probs2class, probs))
我疯狂地搜索了文档,但也许是错误的词?从概率到班级预测的一般方法是什么?