从algo中删除numpy concat

时间:2015-12-14 21:17:43

标签: python numpy

我有一个名为gen_data的函数,它将通过列表进行单次传递并构建一个3D数组。然后我遍历列表列表,应用函数gen_data,然后将结果连接在一起。

fst = lambda x: x[0]
snd = lambda x: x[1]

def gen_data(data,p=0, batch_size = BATCH_SIZE, n_session = N_SESSION, 
    x = np.zeros((batch_size,SEQ_LENGTH,vocab_size))
    y = np.zeros(batch_size)

    for n in range(batch_size):
        ptr = n
        for i in range(SEQ_LENGTH):
            x[n,i,char_to_ix[data[p+ptr+i]]] = 1.
        if(return_target):
            y[n] = char_to_ix[data[p+ptr+SEQ_LENGTH]]
    return x, np.array(y,dtype='int32')

def batch_data(data):
    nest = [gen_data(datum) for datum in data]
    x = np.concatenate(map(fst,nest))
    y = np.concatenate(map(snd,nest))
    return (x,y)

组合这些函数的最佳方法是什么,所以我不需要通过数据进行多次传递来连接结果?

澄清一下,目标是删除zip / concat / splat / list comp的一般需求。为了能够将x张量初始化为正确的尺寸,然后在一次遍历中迭代每个数据/ SEQ_LENGTH,batch_size。

1 个答案:

答案 0 :(得分:0)

没有测试的东西,这里有一些快速修复:

def gen_data(data,p=0, batch_size = BATCH_SIZE, n_session = N_SESSION, 
    x = np.zeros((batch_size,SEQ_LENGTH,vocab_size))
    y = np.zeros(batch_size, dtype=int)   # initial to desired type

    for n in range(batch_size):
        ptr = n
        for i in range(SEQ_LENGTH):
            x[n,i,char_to_ix[data[p+ptr+i]]] = 1.
        if(return_target):
            y[n] = char_to_ix[data[p+ptr+SEQ_LENGTH]]
    return x, y 
    # y is already an array; don't need this: np.array(y,dtype='int32')
我认为

nest = [gen_data(datum) for datum in data]会产生

[(x0,y0), (x1,y1),...]其中x为3d(n,m,y),y为1d(n)

x = np.concatenate([n[0] for n in nest])(我喜欢这种格式而不是映射)对我来说没问题。与所有列表理解操作相比,concatenate相对便宜。看看np.vstack等的内容,看看那些如何使用理解以及连接。

一个小例子:

In [515]: def gen():
    return np.arange(8).reshape(2,4),np.arange(1,3)
   .....: 

In [516]: gen()
Out[516]: 
(array([[0, 1, 2, 3],
       [4, 5, 6, 7]]), array([1, 2]))

In [517]: nest=[gen() for _ in range(3)]

In [518]: nest
Out[518]: 
[(array([[0, 1, 2, 3],
       [4, 5, 6, 7]]), array([1, 2])),
 (array([[0, 1, 2, 3],
       [4, 5, 6, 7]]), array([1, 2])),
 (array([[0, 1, 2, 3],
       [4, 5, 6, 7]]), array([1, 2]))]

In [519]: np.concatenate([x[0] for x in nest])
Out[519]: 
array([[0, 1, 2, 3],
       [4, 5, 6, 7],
       [0, 1, 2, 3],
       [4, 5, 6, 7],
       [0, 1, 2, 3],
       [4, 5, 6, 7]])

In [520]: np.concatenate([x[1] for x in nest])
Out[520]: array([1, 2, 1, 2, 1, 2])

zip*实际上有效地解决了这个问题。在嵌套列表上,因此可以使用以下内容构造数组:

In [532]: nest1=zip(*nest)

In [533]: np.concatenate(nest1[0])
Out[533]: 
array([[0, 1, 2, 3],
       [4, 5, 6, 7],
       [0, 1, 2, 3],
       [4, 5, 6, 7],
       [0, 1, 2, 3],
       [4, 5, 6, 7]])

In [534]: np.concatenate(nest1[1])
Out[534]: array([1, 2, 1, 2, 1, 2])

仍需要连接。

由于nest是元组列表,因此它可以作为结构化数组的输入:

In [524]: arr=np.array(nest,dtype=[('x','(2,4)int'),('y','(2,)int')])

In [525]: arr['x']
Out[525]: 
array([[[0, 1, 2, 3],
        [4, 5, 6, 7]],

       [[0, 1, 2, 3],
        [4, 5, 6, 7]],

       [[0, 1, 2, 3],
        [4, 5, 6, 7]]])

In [526]: arr['y']
Out[526]: 
array([[1, 2],
       [1, 2],
       [1, 2]])

另一种可能性是初始化xy,并进行迭代。但是你已经在gen_data中这样做了。唯一新的事情就是我要分配更大的块。

x = ...
y = ...
for i in range(...):
    x[i,...], y[i] = gen(data[i])

我更喜欢理解解决方案,但我不会推测速度。

就速度而言,我认为gen_data中的低级迭代是时间消费者。连接较大的块相对较快。

另一个想法 - 因为你正在迭代gen_data中的数组行,如何将视图传递给该函数,并迭代这些。

def gen_data(data,x=None,y=None):
    # accept array or make own
    if x is None:
        x = np.zeros((3,4),int)
    if y is None:
        y = np.zeros(3,int)
    for n in range(3):
        x[n,...] = np.arange(4)+n
        y[n] = n
    return x,y

没有输入,像以前一样生成数组:

In [543]: gen_data(None)
Out[543]: 
(array([[0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5]]),
 array([0, 1, 2]))

或初始化一对,并迭代视图:

In [544]: x,y = np.zeros((9,4),int),np.zeros(9,int)

In [546]: for i in range(0,9,3):
   .....:     gen_data(None,x[i:i+3,...],y[i:i+3])

In [547]: x
Out[547]: 
array([[0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5],
       [0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5],
       [0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5]])
In [548]: y
Out[548]: array([0, 1, 2, 0, 1, 2, 0, 1, 2])