有没有办法避免这种循环,所以优化代码?
import numpy as np
cLoss = 0
dist_ = np.array([0,1,0,1,1,0,0,1,1,0]) # just an example, longer in reality
TLabels = np.array([-1,1,1,1,1,-1,-1,1,-1,-1]) # just an example, longer in reality
t = float(dist_.size)
for i in range(len(dist_)):
labels = TLabels[dist_ == dist_[i]]
cLoss+= 1 - TLabels[i]*(1. * np.sum(labels)/t)
print cLoss
注意: dist_
和TLabels
都是具有相同形状的numpy数组(t,1)
答案 0 :(得分:2)
我首先想知道,循环中每一步的labels
是什么?
使用dist_ = array([2,1,2])
和TLabels=array([1,2,3])
我得到了
[-1 1]
[1]
[-1 1]
不同的长度立即引发警告标志 - 可能很难对此进行矢量化。
在编辑的示例中使用较长的数组
[-1 1 -1 -1 -1]
[ 1 1 1 1 -1]
[-1 1 -1 -1 -1]
[ 1 1 1 1 -1]
[ 1 1 1 1 -1]
[-1 1 -1 -1 -1]
[-1 1 -1 -1 -1]
[ 1 1 1 1 -1]
[ 1 1 1 1 -1]
[-1 1 -1 -1 -1]
labels
向量的长度都相同。这是正常的,还是仅仅是值的巧合?
从dist_
删除几个元素,labels
为:
In [375]: for i in range(len(dist_)):
labels = TLabels[dist_ == dist_[i]]
v = (1.*np.sum(labels)/t); v1 = 1-TLabels[i]*v
print(labels, v, TLabels[i], v1)
cLoss += v1
.....:
(array([-1, 1, -1, -1]), -0.25, -1, 0.75)
(array([1, 1, 1, 1]), 0.5, 1, 0.5)
(array([-1, 1, -1, -1]), -0.25, 1, 1.25)
(array([1, 1, 1, 1]), 0.5, 1, 0.5)
(array([1, 1, 1, 1]), 0.5, 1, 0.5)
(array([-1, 1, -1, -1]), -0.25, -1, 0.75)
(array([-1, 1, -1, -1]), -0.25, -1, 0.75)
(array([1, 1, 1, 1]), 0.5, 1, 0.5)
再次标注不同长度,但实际上只有少量计算。每个不同的v
值都有1 dist_
个值。
如果不计算所有细节,看起来您只是为每个不同的labels*labels
值计算dist_
,然后将它们相加。
这看起来像groupBy
问题。您希望将dist_
划分为具有公共值的组,并将其对应的TLabels
值的某些函数求和。 Python itertools
具有groupBy
函数,pandas
也是如此。我认为两者都要求你排序dist_
。
尝试排序dist_
并查看是否可以增加问题的清晰度。
答案 1 :(得分:2)
我不确定你究竟想做什么,但是你知道scipy.ndimage.measurements
计算带有标签的数组吗?它看起来像你想要的东西:
cLoss = len(dist_) - sum(TLabels * scipy.ndimage.measurements.sum(TLabels,dist_,dist_) / len(dist_))
答案 2 :(得分:1)
我不确定这是否更好,因为我并不完全理解你为什么要这样做。循环中的许多变量都是双向的,因此可以提前计算。
dist_
的条目也可以用作布尔开关,但无论如何我都使用了显式副本。
dist_ = np.array([0,1,0,1,1,0,0,1,1,0])
TLabels = np.array([-1,1,1,1,1,-1,-1,1,-1,-1])
t = len(dist)
dist_zeros = dist_== 0
one_zero_sum = [sum(TLabels[dist_zeros])/t , sum(TLabels[~dist_zeros])/t]
cLoss = sum([1-x*one_zero_sum[dist_[y]] for y,x in enumerate(TLabels)])
会产生cLoss = 8.2
。我正在使用Python3,所以没有在Python2中检查这是否是真正的除法。