NumPy slicing and batch size

Time: 2018-02-09 09:38:51

Tags: python numpy numpy-slicing

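I have a NumPy array A with shape (550, 10). My batch size is 100, i.e. the number of rows I want to take from A at a time. In each iteration I want to extract 100 rows from A. But when I reach the last 50 rows, I want the last 50 rows of A followed by the first 50 rows of A.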

I have a function like this:

def train(index, batch_size):
    # A is the (550, 10) array described above
    if index + batch_size <= A.shape[0]:
        data_end_index = index + batch_size
        batch_data = A[index:data_end_index, :]
    else:
        data_end_index = index + batch_size - A.shape[0]  # e.g. 500 + 100 - 550 = 50
        batch_data = ...  # want A[500:550] followed by A[0:50]; how to slice here?

How do I perform the last step?
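One way to express the missing last step is to concatenate the tail of A with its head. The helper `wrapped_batch` below is hypothetical (not from the post) and is only a sketch; it assumes `batch_size` is no larger than the number of rows in A.

```python
import numpy as np

def wrapped_batch(A, index, batch_size):
    """Return batch_size rows of A starting at index, wrapping to the top of A."""
    n_rows = A.shape[0]
    end = index + batch_size
    if end <= n_rows:
        return A[index:end]
    # Tail of A (rows index..n_rows-1) followed by the head (rows 0..end-n_rows-1).
    return np.concatenate((A[index:], A[:end - n_rows]), axis=0)

A = np.arange(550 * 10).reshape(550, 10)
batch = wrapped_batch(A, 500, 100)
print(batch.shape)  # (100, 10)
```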

2 answers:

Answer 0 (score: 0)

You can try:

import numpy as np
data=np.random.rand(550,10)
batch_size=100

for index in range(0,data.shape[0],batch_size):
    batch=data[index:min(index+batch_size,data.shape[0]),:]
    print(batch.shape)

Output:

(100, 10)
(100, 10)
(100, 10)
(100, 10)
(100, 10)
(50, 10)
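This loop yields a final batch of only 50 rows rather than wrapping around as the question asks. A minimal sketch of a wrapping variant, assuming the same `data` and `batch_size`, uses `np.take` with `mode="wrap"`, which reduces out-of-range indices modulo the length of the chosen axis:

```python
import numpy as np

data = np.random.rand(550, 10)
batch_size = 100

batches = []
for index in range(0, data.shape[0], batch_size):
    rows = np.arange(index, index + batch_size)
    # mode="wrap" reduces out-of-range indices modulo data.shape[0],
    # so for index=500 the rows 550..599 become 0..49.
    batches.append(np.take(data, rows, axis=0, mode="wrap"))

print([b.shape for b in batches])
```

Every batch, including the last one, then has shape (100, 10).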

Answer 1 (score: 0)

Use numpy.split.

Stealing riccardo's example data:
import numpy as np
data=np.random.rand(550,10)
batch_size=100

q = data.shape[0] // batch_size  # number of full batches (5)
block_end = q * batch_size       # rows covered by full batches (500)

batch = np.split(data[:block_end], q) + [data[block_end:]]

[*map(np.shape, batch)]
Out[89]: [(100, 10), (100, 10), (100, 10), (100, 10), (100, 10), (50, 10)]
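A variant of the same idea, sketched under the same assumptions, passes the split points directly to `np.split`. Unlike the `q`/`block_end` version, it does not append an empty trailing array when the row count is an exact multiple of `batch_size`. Like the answer above, it does not wrap around; the last batch keeps only the leftover rows.

```python
import numpy as np

data = np.random.rand(550, 10)
batch_size = 100

# Split points 100, 200, ..., 500; the final piece holds the 50 leftover rows.
batch = np.split(data, list(range(batch_size, data.shape[0], batch_size)))
print([b.shape for b in batch])
```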