for循环评估准确性不会执行

时间:2016-10-02 16:25:45

标签: python numpy

所以我有以下的numpy数组。

  • X验证集,X_val:(47151,32,32,1)
  • y验证集(标签),y_val_dummy:(47151,5,10)
  • y验证预测集,y_pred:(47151,5,10)

当我运行代码时,似乎需要永远。有人可以建议为什么?我相信它是代码效率问题。我似乎无法完成这个过程。

y_pred_list = model.predict(X_val)
correct_preds = 0
# Iterate over sample dimension
for i in range(X_val.shape[0]):         
    pred_list_i = [y_pred_array[i] for y_pred in y_pred_array]
    val_list_i  = [y_val_dummy[i] for y_val in y_val_dummy]
    matching_preds = [pred.argmax(-1) == val.argmax(-1) for pred, val in zip(pred_list_i, val_list_i)]
    correct_preds = int(np.all(matching_preds))

total_acc = correct_preds / float(x_val.shape[0])

1 个答案:

答案 0 :(得分:0)

你的主要问题是,你没有真正的理由生成大量非常大的列表

for i in range(X_val.shape[0]):         
    pred_list_i = [y_pred for y_pred in y_pred_array[i]]
    val_list_i  = [y_val for y_val in y_val_dummy[i]]
    matching_preds = [pred.argmax(-1) == val.argmax(-1) for pred, val in zip(pred_list_i, val_list_i)]
    correct_preds = int(np.all(matching_preds))

正在发生的事情是迭代nd numpy数组迭代最慢的变化索引(即最左边的),因此每个列表理解都在47K条目上运行。

更好的是

correct_preds = 0.0
for pred, val in zip(y_pred_array, y_val_dummy):
    correct_preds += all(p.argmax(-1) == v.argmax(-1) 
                         for p, v in zip(pred, val))
total_accuracy = correct_preds / x_val.shape[0]

但是你仍然在复制很多阵列而不是真正的目的。以下代码应该做同样的事情,没有无用的复制。

np.argmax

这假定您的正确预测标准是准确的。 您可以通过对{{1}}的几次调用完全避免显式循环,但您必须自己解决这个问题。