在numpy数组

时间:2017-09-07 16:41:50

标签: python arrays numpy

我正在研究在numpy数组中存储点是否有助于我搜索点数,我有几个问题。

我有一个Point类,代表一个三维点。

class Point( object ):
  def __init__( self, x, y, z ):
    self.x = x
    self.y = y
    self.z = z

  def __repr__( self ):
    return "<Point (%r, %r, %r)>" % ( self.x, self.y, self.z )

我构建了一个Point对象列表。请注意,坐标(1, 2, 3)故意发生两次;这就是我要搜索的内容。

>>> points = [Point(1, 2, 3), Point(4, 5, 6), Point(1, 2, 3), Point(7, 8, 9)]

我将Point对象存储在numpy数组中。

>>> import numpy
>>> npoints = numpy.array( points )
>>> npoints
array([<Point (1, 2, 3)>, <Point (4, 5, 6)>, <Point (1, 2, 3)>,
   <Point (7, 8, 9)>], dtype=object)

我按以下方式搜索坐标为(1, 2, 3)的所有点。

>>> numpy.where( npoints == Point(1, 2, 3) )
>>> (array([], dtype=int64),)

但是,结果没有用。所以,这似乎不是正确的方法。是numpy.where要使用的东西吗?是否有其他方式来表达numpy.where成功的条件?

我接下来要尝试的是只存储numpy数组中点的坐标。

>>> npoints = numpy.array( [(p.x, p.y, p.z) for p in points ])
>>> npoints
array([[1, 2, 3],
      [4, 5, 6],
      [1, 2, 3],
      [7, 8, 9]])

我按以下方式搜索坐标为(1,2,3)的所有点。

>>> numpy.where( npoints == [1,2,3] )
(array([0, 0, 0, 2, 2, 2]), array([0, 1, 2, 0, 1, 2]))

结果至少是我可以处理的事情。第一个返回值array([0, 0, 0, 2, 2, 2])中的行索引数组确实告诉我,我正在搜索的坐标位于npoints的第0行和第2行。我可以做出类似以下的事情。

>>> rows, cols = numpy.where( npoints == [1,2,3] )
>>> rows
array([0, 0, 0, 2, 2, 2])
>>> cols
array([0, 1, 2, 0, 1, 2])
>>> foundRows = set( rows )
>>> foundRows
set([0, 2])
>>> for r in foundRows:
...   # Do something with npoints[r]

然而,我觉得我并没有恰当地使用numpy.where,而且我在这种特殊情况下才刚刚好运。

在numpy数组中找到所有n维点(即具有特定值的行)的适当方法是什么?

保留数组的顺序至关重要。

1 个答案:

答案 0 :(得分:0)

您可以在object.__eq__(self, other)课程中创建“丰富的比较”方法Point,以便能够在==个对象中使用Point

class Point( object ):
  def __init__( self, x, y, z ):
    self.x = x
    self.y = y
    self.z = z

  def __repr__( self ):
    return "<Point (%r, %r, %r)>" % ( self.x, self.y, self.z )
  def __eq__(self, other):
    return self.x == other.x and self.y == other.y and self.z == other.z

import numpy
points = [Point(1, 2, 3), Point(4, 5, 6), Point(1, 2, 3), Point(7, 8, 9)]
npoints = numpy.array( points )
found = numpy.where(npoints == Point(1, 2, 3))
print(found) # => (array([0, 2]),)