有没有办法让 NumPy 在写入 Python pickle 文件时保留非标准的步幅(strides)?
>>> # Create an array with non-standard striding
>>> x = numpy.arange(2*3*4, dtype='uint8').reshape((2,3,4)).transpose(0,2,1)
>>> x.strides
(12, 1, 4)
>>> # The pickling process converts it to a c-contiguous array.
>>> # Often, this is a good thing, but for some applications, the
>>> # non-standard striding is intentional and important to preserve.
>>> pickled = cPickle.dumps(x, protocol=cPickle.HIGHEST_PROTOCOL)
>>> cPickle.loads(pickled).strides
(12, 3, 1)
>>> # This is indeed happening during serialization, not deserialization
>>> pickletools.dis(pickled)
...
151: S STRING '\x00\x04\x08\x01\x05\t\x02\x06\n\x03\x07\x0b\x0c\x10\x14\r\x11\x15\x0e\x12\x16\x0f\x13\x17'
...
注意:NumPy 足够聪明,能够保留 C 连续(C-contiguous)或 Fortran 连续(Fortran-contiguous)的内存布局,但在 pickle 序列化和反序列化(pickling / unpickling)过程中,它并不会保留所有非标准的步幅模式。
答案 0(得分:3)
我能想到的唯一办法是自己手动实现:
# ---------------------------------------------
import numpy
from numpy.lib.stride_tricks import as_strided
import cPickle
def dumps(x, protocol=cPickle.HIGHEST_PROTOCOL):
    """Pickle an array while keeping the memory ordering of its axes.

    Plain pickling of an ndarray yields a C-contiguous array on load.  This
    function instead pickles the list ``[flat, shape, strides]`` where

    * ``flat`` is a 1-D copy of the elements in the order they are laid out
      in memory, and
    * ``strides`` are byte strides that are valid for that contiguous copy,

    so ``as_strided(flat, shape=shape, strides=strides)`` rebuilds an array
    equal to *x* whose axes have the same ordering in memory.

    When *x* densely fills one contiguous block with positive strides (e.g.
    a transpose of a C-contiguous array), its strides are reproduced exactly.
    Otherwise the layout is compacted: the axis ordering is kept but the
    strides become dense.  This covers slicing gaps such as ``x[:, ::2]``,
    negative strides and zero (broadcast) strides.  It is necessary because
    the original byte strides do not describe the copied buffer.

    :param x: array (or array-like) to serialize.
    :param protocol: pickle protocol forwarded to ``cPickle.dumps``.
    :returns: the pickled byte string.
    """
    x = numpy.asarray(x)
    strides = x.strides
    # Axes from slowest- to fastest-varying in memory.  sorted() is stable,
    # so ties (length-1 or broadcast axes) keep their original relative order.
    order = sorted(range(x.ndim), key=lambda axis: -abs(strides[axis]))
    # View x with its axes in memory order.  For a dense, positively strided
    # layout this view is already C-contiguous and no copy is made here.
    in_memory_order = x.transpose(order)
    if in_memory_order.flags.c_contiguous:
        compact = in_memory_order
    else:
        compact = in_memory_order.copy(order='C')
    # Invert the permutation to get back x's shape.  The resulting strides
    # address `compact`, which is exactly the buffer pickled as `flat`.
    inverse = [0] * x.ndim
    for position, axis in enumerate(order):
        inverse[axis] = position
    restored = compact.transpose(inverse)
    return cPickle.dumps([compact.reshape(-1), x.shape, restored.strides],
                         protocol=protocol)
def loads(pickled):
    """Rebuild an array serialized by ``dumps``, restoring its strides.

    The payload is expected to be the pickled list ``[flat, shape, strides]``
    written by ``dumps``.  The result is an ``as_strided`` view of ``flat``
    with the stored shape and byte strides.

    WARNING: unpickling can execute arbitrary code -- only call this on data
    from a trusted source.

    :param pickled: byte string produced by ``dumps``.
    :returns: ndarray view sharing memory with the unpickled buffer.
    :raises ValueError: if ``shape`` and ``strides`` differ in length, or if
        they would address bytes outside the unpickled buffer.
        ``as_strided`` performs no bounds checking itself, so without this
        guard an inconsistent payload would read arbitrary memory.
    """
    flat, shape, strides = cPickle.loads(pickled)
    shape = tuple(shape)
    strides = tuple(strides)
    if len(shape) != len(strides):
        raise ValueError('shape %r and strides %r have different lengths'
                         % (shape, strides))
    # An empty view touches no memory, so only non-empty views need checking.
    if all(extent > 0 for extent in shape):
        # Lowest and highest byte offsets, relative to flat's first byte,
        # that the strided view can reach.
        lowest = sum(step * (extent - 1)
                     for extent, step in zip(shape, strides) if step < 0)
        highest = sum(step * (extent - 1)
                      for extent, step in zip(shape, strides) if step > 0)
        # Unpickled ndarrays are contiguous, so flat's buffer spans exactly
        # flat.nbytes bytes starting at its data pointer.
        if lowest < 0 or highest + flat.itemsize > flat.nbytes:
            raise ValueError('shape %r with strides %r does not fit in a '
                             'buffer of %d bytes'
                             % (shape, strides, flat.nbytes))
    return as_strided(flat, shape=shape, strides=strides)
if __name__=='__main__':
x = numpy.arange(2*3*4, dtype='uint8').reshape((2,3,4)).transpose(0,2,1)
pickled = dumps(x)
y = loads(pickled)
print 'x strides =', x.strides
print 'y strides =', y.strides
print 'x==y:', (x==y).all()
# ---------------------------------------------
输出:
x strides = (12, 1, 4)
y strides = (12, 1, 4)
x==y: True