Pyplot: The truth value of an array with more than one element is ambiguous

Time: 2017-06-03 15:09:03

Tags: python matplotlib knn density-plot

I am trying to implement a nearest-neighbours density estimate in 1D:

import numpy as np
import matplotlib.pyplot as plt

# nearest neighbors estimate
def nearest_n(x, k, data):
    # Order dataset
    #data = np.sort(data, kind='mergesort')
    nnb = []
    # iterate over all data and get k nearest neighbours around x
    for n in data:
        if len(nnb) < k:
            nnb.append(n)
        else:
            for nb in np.arange(0,k):
                if np.abs(x-n) < np.abs(x-nnb[nb]):
                    nnb[nb] = n
                    break

    nnb = np.array(nnb)
    # get volume(distance) v of k nearest neighbours around x
    v = nnb.max() - nnb.min()
    v = k/(len(data)*v)

    return v

interval = np.arange(-4.0, 8.0, 0.1)
plt.figure()
for k in (2,8,35):
    plt.plot(interval, nearest_n(interval, k,train_data), label=str(k))
plt.legend()
plt.show()

This raises:

  File "x", line 55, in nearest_n
    if np.abs(x-n) < np.abs(x-nnb[nb]):
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I know the error comes from the array I pass in inside the plot() call, but I don't know how to avoid it in a function that uses the operators > / == / <.

'data' comes from a 1D txt file containing floats.

I tried using vectorize:

nearest_n = np.vectorize(nearest_n)

which results in:

line 50, in nearest_n
    for n in data:
TypeError: 'numpy.float64' object is not iterable

Here is an example. Let's say:

data = [0.5,1.7,2.3,1.2,0.2,2.2]
k = 2

Then nearest_n(1.5) should pick the two nearest neighbours, 1.2 and 1.7 (so v = 1.7 - 1.2 = 0.5), and return 2 / (6 * 0.5) = 2/3.

The function does run for a scalar x; for example, nearest_n(2.0, 4, data) gives 0.0741586011463

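2 answers:

Answer 0: (score: 0)

You are passing np.arange(-4.0, 8.0, 0.1) as x, which is an array of values. So x - n is an array with the same length as x, 120 elements in this case, because subtracting a scalar from an array is done element-wise. The same applies to x - nnb[nb]. The result of the comparison is therefore a 120-element boolean array, where each entry says whether that element of np.abs(x-n) is less than the corresponding element of np.abs(x-nnb[nb]). This cannot be used directly as a condition; you need to reduce these values to a single boolean (using all() or any()), or just rethink your code.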
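To make the point above concrete, here is a small sketch that reproduces the ambiguity outside of nearest_n. The values 1.7 and 0.5 are just illustrative stand-ins for n and nnb[nb], and the variable names are my own, not from the question.

```python
import numpy as np

x = np.arange(-4.0, 8.0, 0.1)   # same grid as in the question
n = 1.7                          # illustrative data point
other = 0.5                      # illustrative current neighbour

# Element-wise comparison: one boolean per entry of x
mask = np.abs(x - n) < np.abs(x - other)

# Using the array directly as a condition raises the ValueError
try:
    if mask:
        pass
except ValueError as err:
    message = str(err)

# Reducing to a single boolean is unambiguous
any_closer = bool(mask.any())   # True if at least one entry is True
all_closer = bool(mask.all())   # True only if every entry is True
```

Note that any() or all() only silences the error; for this estimator neither is really what you want, because each x needs its own set of neighbours.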

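Answer 1: (score: 0)

plt.figure()
X = np.arange(-4.0,8.0,0.1)
for k in [2,8,35]:
    Y = []
    # evaluate the estimate one scalar x at a time
    for n in X:
        Y.append(nearest_n(n,k,train_data))
    plt.plot(X,Y,label=str(k))
plt.legend()
plt.show()

This works fine. I thought pyplot.plot would already do this for me, but I guess it doesn't...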
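As a possible alternative to the explicit loop, np.vectorize can be told to leave the data argument alone through its excluded parameter, which avoids the "'numpy.float64' object is not iterable" error from the question. The sketch below is only one way to do it: knn_density is a hypothetical rewrite that picks the neighbours by sorting distances instead of the asker's replacement loop, and it keeps the question's definition of the volume (max minus min of the neighbours).

```python
import numpy as np

def knn_density(x, k, data):
    """k-nearest-neighbour density estimate at a single scalar x."""
    data = np.asarray(data, dtype=float)
    # indices of the k samples closest to x
    nearest = np.argsort(np.abs(data - x))[:k]
    nnb = data[nearest]
    # same "volume" as in the question: spread of the k neighbours
    v = nnb.max() - nnb.min()
    # caveat: v is 0 for k == 1 or identical neighbours (division by zero)
    return k / (len(data) * v)

# Exclude `data` from vectorization so only x (and k) are broadcast
knn_density_vec = np.vectorize(knn_density, excluded=["data"])

data = [0.5, 1.7, 2.3, 1.2, 0.2, 2.2]   # example data from the question
single = knn_density(1.5, 2, data)       # scalar call
grid = np.array([1.5, 2.0])
curve = knn_density_vec(grid, 2, data=data)   # one estimate per grid point
```

With this, something like plt.plot(grid, curve, label=str(k)) should work directly, since curve has the same shape as grid. Keep in mind that np.vectorize is essentially a convenience loop, so it is not expected to be faster than the list-building version above.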