从一站式展示到标签

时间:2019-01-27 20:27:49

标签: python numpy tensor threshold

我的预测在张量pred下,并且pred.shape(4254, 10, 3)。因此,我们有维度为4254的{​​{1}}个矩阵。让我们看一看这些矩阵。

(10, 3)

如上例所示,有10个向量代表标签的一键式表示。例如W = array([[0.04592975, 0.09632163, 0.85774857], [0.03408821, 0.27141285, 0.6944989 ], [0.02538731, 0.4691383 , 0.50547445], [0.01959289, 0.6456455 , 0.33476162], [0.01333424, 0.7494791 , 0.23718661], [0.0109237 , 0.77042925, 0.218647 ], [0.01438793, 0.7796771 , 0.20593494], [0.01474626, 0.6817438 , 0.30350992], [0.02189695, 0.57687664, 0.40122634], [0.03810155, 0.5130332 , 0.44886518]], dtype=float32)

为什么我要按批处理10个向量?我正在研究时间序列预测问题,其中在时间np.argmax([0.04592975, 0.09632163, 0.85774857]) = 2到时间t_0到时间t_1的下10个标签进行预测。

对于这些矩阵中的每一个,我都会想找回原始标签。因此对于矩阵t_10,我应该得到数组W

让我们定义阈值数组array([2, 2, 2, 1, 1, 1, 1, 1, 1, 1])并取回threshold_array = np.array([0.6, 0.65, 0.70, 0.75, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80])。假设中立位置为labels = array([2, 2, 2, 1, 1, 1, 1, 1, 1, 1]),动作为10。此处的目的是根据2和矩阵labels修改threshold_array

如果我拿W,我们知道W[0]np.argmax(W[0]) = 2。与W[0][2] = 0.85774857一样,然后W[0][2] >= threshold_array[0]将保留labels[0]

另一个例子有点不同。如果我采用2,我们知道W[2]np.argmax(W[2]) = 2。与W[2][2] = 0.50547445一样,然后W[2][2] < threshold_array[2]labels[2]变为2

如果我将该策略应用于来自0的每个矢量,则W现在设置为labels。请注意,只有一个动作可以变成中间位置,而不是相反的位置。

如何在python中为array([2, 2, 0, 1, 1, 1, 1, 1, 1, 1])内的每个矩阵W编码该策略以获得维度pred的标签矩阵?

1 个答案:

答案 0 :(得分:0)

我不确定这是否是解决该问题的最佳方法,但这是一个答案。

import numpy as np

threshold_array = np.array([0.6, 0.65, 0.70, 0.75, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80])   

def get_labels(W, threshold_array):

    labels = []
    for i, vect in enumerate(W):
        neutral_position = 1
        label = np.argmax(vect)
        if label in [0, 2]:
            if vect[label] < threshold_array[i]:
                labels.append(neutral_position)
            else:
                labels.append(label)
        else:
            labels.append(label)
    return np.array(labels)

if __name__ == "__main__":
    labels = []
    for matrix in pred:
        labels.append(get_labels(matrix, theshold_array))
    labels = np.array(labels)