在Python中针对“索引0超出轴0的大小0的范围”的保护

时间:2019-10-28 08:29:35

标签: python python-3.x numpy matplotlib

我有一个代码,可以在其中获得函数tan()的图上点的特定分布 上下限受直线限制:

import matplotlib.pyplot as plt
import numpy as np
import sys
import itertools
import multiprocessing
import tqdm

ic = range(1,10)
jc = range(1,10)

paramlist = list(itertools.product(ic,jc))

def func(params):
        ic = params[0]
        jc = params[1]

        fig = plt.figure(1, figsize=(10,6))

        x_all   = np.linspace(0, 10*np.pi, 10000, endpoint=False)
        x_above = x_all[ (-0.01)*ic*x_all < np.tan(x_all) ]
        x       = x_above[ np.tan(x_above) < 0.01*jc*x_above ]
        y       = np.tan(x)
        y2      = 0.01*jc*x
        y3      = (-0.01)*ic*x

        y_up   = np.diff(y) > 0
        y_diff = np.where( y_up, np.diff(y), 0 )
        x_diff = np.where( y_up, np.diff(x), 0 )
        diffs  = np.sqrt( x_diff**2 + y_diff**2 )
        length = diffs.sum()

        numbers = [2,4,6,8,10,12,14,16,18,20]

        p2 = []
        for d in range(len(numbers)):
            cumlenth = np.cumsum(diffs)
            s = np.abs(np.diff(np.sign(cumlenth-numbers[d]))).astype(bool)
            c = np.argwhere(s)[0][0]
            p = x[c], y[c]
            p2.append(p)

        p3 = sorted(p2, key=lambda x: x[0])
        x_max = p3[len(p3)-1][0]
        p4 = sorted(p2, key=lambda x: x[1])
        y_min = p4[0][1]
        y_max = p4[len(p3)-1][1]


        for b in range(len(p2)):
            plt.scatter( p2[b][0], p2[b][1], color="crimson", s=8)

        plt.plot(x, np.tan(x))
        plt.plot(x, y2)
        plt.plot(x, y3)

        ax = plt.gca()
        ax.set_xlim([0, x_max+0.5])
        ax.set_ylim([y_min-0.5, y_max+0.5]) 

        plt.savefig('C:\\Users\\tkp\\Desktop\\wykresy_4\\i='+str(ic)+'_j='+str(jc)+'.png', bbox_inches='tight')
        plt.show()

if __name__ == '__main__':

    p = multiprocessing.Pool(4)
    for params in tqdm.tqdm(p.imap_unordered(func, paramlist), total=len(paramlist)):
        #pass
        sys.stdout.write('\r'+ str(params))
        sys.stdout.flush()

    p.close()
    p.join()

例如在哪里收到剧情:

enter image description here

问题是,如果我在x_all = np.linspace(0, 10*np.pi, 10000, endpoint=False)中设置的范围太小,则会出现错误index 0 is out of bounds for axis 0 with size 0。我该如何保护自己呢?或者在这种情况下,我可以在“ linspace”函数中设置一个可变范围?

1 个答案:

答案 0 :(得分:2)

此错误发生在哪里?这是一条基本信息-对我们来说,尤其是对您来说!

@edison说它在argwhere表达式中。我将尝试重新创建该步骤,首先猜测一下diffs的样子:

In [8]: x = np.ones(5)*.1                                                       
In [9]: x                                                                       
Out[9]: array([0.1, 0.1, 0.1, 0.1, 0.1])
In [10]: s = np.cumsum(x)                                                       
In [11]: s                                                                      
Out[11]: array([0.1, 0.2, 0.3, 0.4, 0.5])
In [12]: s-1                                                                    
Out[12]: array([-0.9, -0.8, -0.7, -0.6, -0.5])
In [13]: np.sign(s-1)                                                           
Out[13]: array([-1., -1., -1., -1., -1.])
In [14]: np.diff(np.sign(s-1))                                                  
Out[14]: array([0., 0., 0., 0.])
In [15]: np.abs(np.diff(np.sign(s-1)))                                          
Out[15]: array([0., 0., 0., 0.])
In [16]: np.abs(np.diff(np.sign(s-1))).astype(bool)                             
Out[16]: array([False, False, False, False])

不管到现在有什么细节,都可以猜测s是仅包含False的数组。 where在该数组中找到True个元素;没有。

In [17]: np.where(_)                                                            
Out[17]: (array([], dtype=int64),)

argwhere是它的转置-每个维度一列,每个找到的项目一行。

In [18]: np.argwhere(_)                                                         
Out[18]: array([], shape=(0, 2), dtype=int64)
In [19]: _[0]                                                                   
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-19-aa79beb95eae> in <module>
----> 1 _[0]

IndexError: index 0 is out of bounds for axis 0 with size 0

因此,您的第一道防线是检查返回数组的形状:

c = np.argwhere(s)
if c.shape[0]>0:
    c = c[0,0]
    p = x[c], y[c]
else:
    # what do you want to do if non of `s` are true?

您可以从那里向后工作,注意确保diffsnumbers正确,并始终找到有效的c。但是无论如何,在使用whereargwhere时,请小心假定已找到给定数量的项目。