python将预测结果转换为热门

时间:2016-11-23 23:54:21

标签: python

我运行python神经网络预测,期待一个热门结果并获得如下数字:

[[0,1,0],
[0,0,1],
[0,1,0],
[1,0,0]
[0,1,0]]

如何将数组转换为一个热结果,即

let

3 个答案:

答案 0 :(得分:2)

根据一个热门结果,我假设您希望每个子列表的最大值为1,并且休息为0(基于当前结果中的模式)。您可以使用 list comprehension 执行此操作:

>>> [[int(item == max(sublist)) else 0 for item in sublist] for sublist in my_list]
#      ^  converts bool value returned by `==` into `int`. True -> 1, False -> 0
[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 0]]

其中my_list是您的初始列表。

但是在上面的方法中,每次迭代子列表时都会计算max()。更好的方法就是:

def get_hot_value(my_list):
    max_val = max(my_list)
    return [int(item == max_val) for item in my_list]

hot_list = [get_hot_value(sublist) for sublist in my_list]

修改:如果您应该在列表中只有一个1(如果超过1个元素的最大值),您可以修改get_hot_value功能如:

def get_hot_value(my_list):
    max_val, hot_list, is_max_found = max(my_list), [], False
    for item in my_list:
        if item == max_val and not is_max_found:
            hot_list.append(1)
        else:
            hot_list.append(0)
            is_max_found = True
    return hot_list

答案 1 :(得分:1)

我会建议:

n = [[0.33058667182922363, 0.3436272442340851, 0.3257860243320465],
     [0.32983461022377014, 0.3487854599952698, 0.4213798701763153],
     [0.3311253488063812, 0.3473075330257416, 0.3215670585632324],
     [0.38368630170822144, 0.35151687264442444, 0.3247968554496765],
     [0.3332786560058594, 0.343686580657959, 0.32303473353385925]]

hot_results = []

for row in n:
    hot_index = row.index(max(row))
    hot_result = [0] * len(row)
    hot_result[hot_index] = 1
    hot_results.append(hot_result)

print(hot_results)

答案 2 :(得分:1)

其他解决方案都很好,并解决了问题。或者,如果你有numpy,

import numpy as np
n = [[0.33058667182922363, 0.3436272442340851, 0.3257860243320465],
     [0.32983461022377014, 0.3487854599952698, 0.4213798701763153],
     [0.3311253488063812, 0.3473075330257416, 0.3215670585632324],
     [0.38368630170822144, 0.35151687264442444, 0.3247968554496765],
     [0.3332786560058594, 0.343686580657959, 0.32303473353385925]]

max_indices = np.argmax(n,axis=1)

final_values = [n[i] for i in max_indices]

argmax能够找到该行中最大值的索引,然后你只需要对其进行一次列表理解。应该是相当快的我猜?