加速迭代Numpy Arrays

时间:2011-07-11 17:58:51

标签: python for-loop numpy gdal

我正在使用Numpy进行图像处理,特别是运行标准偏差拉伸。这读取X列数,找到Std。并执行百分比线性拉伸。然后它迭代到列的下一个“组”并执行相同的操作。输入图像是1GB,32位单波段光栅,处理(小时)需要相当长的时间。下面是代码。

我意识到我有3个嵌套for循环,这可能是瓶颈发生的地方。如果我在“框”中处理图像,也就是说加载一个[500,500]的数组并且迭代图像处理时间非常短。不幸的是,相机错误要求我在非常长的条带(52,000 x 4)(y,x)中进行迭代以避免条带化。

任何有关加快这项工作的建议都将受到赞赏:

def box(dataset, outdataset, sampleSize, n):

    quiet = 0
    sample = sampleSize
    #iterate over all of the bands
    for j in xrange(1, dataset.RasterCount + 1): #1 based counter

        band = dataset.GetRasterBand(j)
        NDV = band.GetNoDataValue()

        print "Processing band: " + str(j)       

        #define the interval at which blocks are created
        intervalY = int(band.YSize/1)    
        intervalX = int(band.XSize/2000) #to be changed to sampleSize when working

        #iterate through the rows
        scanBlockCounter = 0

        for i in xrange(0,band.YSize,intervalY):

            #If the next i is going to fail due to the edge of the image/array
            if i + (intervalY*2) < band.YSize:
                numberRows = intervalY
            else:
                numberRows = band.YSize - i

            for h in xrange(0,band.XSize, intervalX):

                if h + (intervalX*2) < band.XSize:
                    numberColumns = intervalX
                else:
                    numberColumns = band.XSize - h

                scanBlock = band.ReadAsArray(h,i,numberColumns, numberRows).astype(numpy.float)

                standardDeviation = numpy.std(scanBlock)
                mean = numpy.mean(scanBlock)

                newMin = mean - (standardDeviation * n)
                newMax = mean + (standardDeviation * n)

                outputBlock = ((scanBlock - newMin)/(newMax-newMin))*255
                outRaster = outdataset.GetRasterBand(j).WriteArray(outputBlock,h,i)#array, xOffset, yOffset


                scanBlockCounter = scanBlockCounter + 1
                #print str(scanBlockCounter) + ": " + str(scanBlock.shape) + str(h)+ ", " + str(intervalX)
                if numberColumns == band.XSize - h:
                    break

                #update progress line
                if not quiet:
                    gdal.TermProgress_nocb( (float(h+1) / band.YSize) )

这是一个更新: 不使用配置文件模块,因为我不想开始将代码的小部分包装到函数中,所以我使用了print和exit语句的混合来大致了解哪些行花费的时间最多。幸运的是(我确实理解我是多么幸运)一条线拖着一切。

    outRaster = outdataset.GetRasterBand(j).WriteArray(outputBlock,h,i)#array, xOffset, yOffset

在打开输出文件并写出数组时,GDAL似乎效率很低。考虑到这一点,我决定将修改后的数组“outBlock”添加到python列表中,然后写出块。这是我改变的部分:

刚刚修改了outputBlock ......

         #Add the array to a list (tuple)
            outputArrayList.append(outputBlock)

            #Check the interval counter and if it is "time" write out the array
            if len(outputArrayList) >= (intervalX * writeSize) or finisher == 1:

                #Convert the tuple to a numpy array.  Here we horizontally stack the tuple of arrays.
                stacked = numpy.hstack(outputArrayList)

                #Write out the array
                outRaster = outdataset.GetRasterBand(j).WriteArray(stacked,xOffset,i)#array, xOffset, yOffset
                xOffset = xOffset + (intervalX*(intervalX * writeSize))

                #Cleanup to conserve memory
                outputArrayList = list()
                stacked = None
                finisher=0

Finisher只是一个处理边缘的旗帜。花了一些时间来弄清楚如何从列表中构建一个数组。在那,使用numpy.array创建一个三维数组(任何人都在解释为什么?)和写数组需要一个二维数组。现在总处理时间从不到2分钟到5分钟不等。知道为什么时间范围可能存在吗?

非常感谢发布的所有人!下一步是真正进入Numpy并学习矢量化以进行额外的优化。

3 个答案:

答案 0 :(得分:6)

加速numpy数据操作的一种方法是使用vectorize。从本质上讲,vectorize采用函数f并创建一个新函数g,将f映射到数组a上。然后调用g,如下所示:g(a)

>>> sqrt_vec = numpy.vectorize(lambda x: x ** 0.5)
>>> sqrt_vec(numpy.arange(10))
array([ 0.        ,  1.        ,  1.41421356,  1.73205081,  2.        ,
        2.23606798,  2.44948974,  2.64575131,  2.82842712,  3.        ])

如果没有您正在使用的数据,我无法确定这是否有帮助,但也许您可以将上述内容重写为一组可以vectorized的函数。也许在这种情况下,您可以将一系列索引向量化为ReadAsArray(h,i,numberColumns, numberRows)。以下是潜在好处的一个例子:

>>> print setup1
import numpy
sqrt_vec = numpy.vectorize(lambda x: x ** 0.5)
>>> print setup2
import numpy
def sqrt_vec(a):
    r = numpy.zeros(len(a))
    for i in xrange(len(a)):
        r[i] = a[i] ** 0.5
    return r
>>> timeit.timeit(stmt='a = sqrt_vec(numpy.arange(1000000))', setup=setup1, number=1)
0.30318188667297363
>>> timeit.timeit(stmt='a = sqrt_vec(numpy.arange(1000000))', setup=setup2, number=1)
4.5400981903076172

加速15倍!另请注意,numpy切片优雅地处理ndarray的边缘:

>>> a = numpy.arange(25).reshape((5, 5))
>>> a[3:7, 3:7]
array([[18, 19],
       [23, 24]])

因此,如果您可以将ReadAsArray数据转换为ndarray,则无需执行任何边缘检查诡计。


关于重塑的问题 - 重塑根本不会从根本上改变数据。它只是改变numpy索引数据的“步幅”。当您调用reshape方法时,返回的值是数据的新视图;数据不会被复制或更改,旧视图也不会被复制或更改。

>>> a = numpy.arange(25)
>>> b = a.reshape((5, 5))
>>> a
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24])
>>> b
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])
>>> a[5]
5
>>> b[1][0]
5
>>> a[5] = 4792
>>> b[1][0]
4792
>>> a.strides
(8,)
>>> b.strides
(40, 8)

答案 1 :(得分:5)

按要求回答。

如果你是IO绑定的,你应该对你的读/写进行分块。尝试将~500 MB的数据转储到ndarray,处理所有数据,写出来然后获取下一个~500 MB。确保重用ndarray。

答案 2 :(得分:2)

我没有尝试完全理解你在做什么,我注意到你没有使用任何numpy slicesarray broadcasting,这两者都可以加速你的代码,或者至少,使其更具可读性。如果这些与你的问题没有密切关系,我表示道歉。