截断字符串时numpy ndarray抛出异常

时间:2016-07-26 22:41:39

标签: python-3.x numpy

我有一个不同长度的ndarray ascii字符串。到目前为止,我使用了dtype=object。然而,分析显示这实际上是我的程序中的瓶颈。使用dtype=np.string_更快,但它有一个缺点,它默默地截断设置值。由于这是一个很难找到错误的完美配方,我想知道是否有可能重新缩放(我知道在整个重新分配的情况下这可能是昂贵的)数组或在截断的情况下引发异常?

我无法更改ndarray.__setitem__,因为它是一个只读属性。这是一些代码来证明我的意思:

import numpy as np


def Foo(vec):
    vec[1] = 'FAIL'

    print('{:6s}: {}'.format(str(vec.dtype), vec))


VALUES = ['OK', 'OK', 'OK']

Foo(np.array(VALUES, dtype=object)) # Slow but it works
Foo(np.array(VALUES, dtype=np.string_)) # Fast but may fail silently

导致:

object: ['OK' 'FAIL' 'OK']
|S2   : [b'OK' b'FA' b'OK']

2 个答案:

答案 0 :(得分:1)

让我们看看我能解释发生了什么

In [32]: ll=['one','two','three']
In [33]: a1=np.array(ll,dtype=object)
In [34]: a1
Out[34]: array(['one', 'two', 'three'], dtype=object)
In [35]: a1[1]='eleven'
In [36]: a1
Out[36]: array(['one', 'eleven', 'three'], dtype=object)

a1就像ll一样由指针组成 - 指向驻留在内存中其他位置的字符串的指针。我可以更改任何指针,就像我在列表中一样。在大多数情况下,a1的行为就像一个列表 - 除了可以重塑,并做一些其他基本的array事情。

In [37]: a1.reshape(3,1)
Out[37]: 
array([['one'],
       ['eleven'],
       ['three']], dtype=object)

但是如果我制作一个string数组

In [38]: a2=np.array(ll)
In [39]: a2
Out[39]: 
array(['one', 'two', 'three'], 
      dtype='<U5')
In [42]: a1.itemsize
Out[42]: 4
In [43]: a2.itemsize
Out[43]: 20

值存储在数组的数据缓冲区中。这里它创建了一个数组,每个元素有5个unicode字符(Python3)(每个5 * 4字节)。

现在,如果我替换a2的元素,我可以截断

In [44]: a2[1]='eleven'
In [45]: a2
Out[45]: 
array(['one', 'eleve', 'three'], 
      dtype='<U5')

因为新值中只有5个字符符合分配的空间。

所以有一个权衡 - 更快的访问速度,因为字节存储在一个固定的,已知大小的数组中,但你不能存储更大的东西。

您可以为每个元素分配更多空间:

In [46]: a3=np.array(ll,dtype='|U10')
In [47]: a3
Out[47]: 
array(['one', 'two', 'three'], 
      dtype='<U10')
In [48]: a3[1]='eleven'
In [49]: a3
Out[49]: 
array(['one', 'eleven', 'three'], 
      dtype='<U10')

genfromtxt是使用字符串dtypes创建数组的常用工具。等到它在设置字符串长度之前已经读取了所有文件(至少如果使用dtype=None)。字符串字段通常是多字段结构化数组的一部分。字符串字段通常是标签或ID,而不是您经常更改的内容。

我可以想象编写一个函数来检查字符串长度与dtype,并在发生截断时引发错误。但这会减慢行动速度。

def foo(A, i, astr):
    if A.itemsize/4<len(astr):
        raise ValueError('too long str')
    A[i] = astr

In [69]: foo(a2,1,'four')
In [70]: a2
Out[70]: 
array(['one', 'four', 'three'], 
      dtype='<U5')
In [72]: foo(a2,1,'eleven')
...
ValueError: too long str

但值得额外的工作吗?

答案 1 :(得分:0)

我通过继承ndarray找到了一个非灵活的解决方案。我不会接受这个答案,直到星期五也许有人想出更好的东西。它履行其职责,甚至在视图上(例如StringArray(...)[1:4])

import numpy as np

class StringArray(np.ndarray):
    def __new__(cls, val):
        field_length = max(map(len, val))
        # Could also be <U for unicode
        vec = super().__new__(cls, len(val), dtype='|S' + str(field_length))
        vec[:] = val[:]
        return vec

    def __setitem__(self, key, val):
        if isinstance(val, (list, tuple, nd.array)):
            if max(map(len, val)) > self.dtype.itemsize:
                raise ValueError('Itemsize too big')
        elif isinstance(val, str):
            if len(val) > self.dtype.itemsize:
                raise ValueError('Itemsize too big')
        else:
            raise ValueError('Unknown type')
        super().__setitem__(key, val)


val = StringArray(['a', 'ab', 'abc'])
print(val)
val[0] = 'xy'
print(val)
try:
    val[0] = 'xyze'
except ValueError:
    print('Catch')

try:
    val[1:2] = ['xyze', 'sd']
except ValueError:
    print('Catch')

产生

[b'a' b'ab' b'abc']
[b'xy' b'ab' b'abc']
Catch
Catch