假设我有这个numpy数组
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]]
我想分两批拆分,然后迭代:
[[1, 2, 3], Batch 1
[4, 5, 6]]
[[7, 8, 9], Batch 2
[10, 11, 12]]
最简单的方法是什么?
编辑:我很抱歉我错过了这样的信息:一旦我打算继续进行迭代,原始数组会因为分割和迭代批量而被破坏。批量迭代完成后,我需要从第一批重新开始,因此我应该保留原始数组不会被销毁。整个想法是与需要迭代批量的随机梯度下降算法一致。在一个典型的例子中,我可以进行100000次迭代For循环,只需要一次又一次地重放1000个批次。
答案 0 :(得分:8)
考虑数组<script src="https://ajax.googleapis.com/ajax/libs/jquery/1.12.4/jquery.min.js"></script>
<p class="wff-post-text" style=" "><a class="wff-link-tab" href="http://facebook.com/1100591156696075" style="color: indigo;" target="_blank">Visby hemtjänst</a> updated cover photo.</p>
a
选项1
使用a = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]])
和reshape
//
选项2
如果你想要两个小组而不是两个小组
a.reshape(a.shape[0] // 2, -1, a.shape[1])
array([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
选项3
使用发电机
a.reshape(-1, 2, a.shape[1])
array([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
答案 1 :(得分:7)
您可以使用numpy.split
沿第一个轴n
分割,其中n
是所需批次的数量。因此,实现看起来像这样 -
np.split(arr,n,axis=0) # n is number of batches
由于axis
的默认值为0
本身,因此我们可以跳过设置。所以,我们只需要 -
np.split(arr,n)
样品运行 -
In [132]: arr # Input array of shape (10,3)
Out[132]:
array([[170, 52, 204],
[114, 235, 191],
[ 63, 145, 171],
[ 16, 97, 173],
[197, 36, 246],
[218, 75, 68],
[223, 198, 84],
[206, 211, 151],
[187, 132, 18],
[121, 212, 140]])
In [133]: np.split(arr,2) # Split into 2 batches
Out[133]:
[array([[170, 52, 204],
[114, 235, 191],
[ 63, 145, 171],
[ 16, 97, 173],
[197, 36, 246]]), array([[218, 75, 68],
[223, 198, 84],
[206, 211, 151],
[187, 132, 18],
[121, 212, 140]])]
In [134]: np.split(arr,5) # Split into 5 batches
Out[134]:
[array([[170, 52, 204],
[114, 235, 191]]), array([[ 63, 145, 171],
[ 16, 97, 173]]), array([[197, 36, 246],
[218, 75, 68]]), array([[223, 198, 84],
[206, 211, 151]]), array([[187, 132, 18],
[121, 212, 140]])]
答案 2 :(得分:0)
这样做:
a = [[1, 2, 3],[4, 5, 6],
[7, 8, 9],[10, 11, 12]]
b = a[0:2]
c = a[2:4]
答案 3 :(得分:0)
这是我用来迭代的东西。我使用b.next()
方法生成索引,然后将输出传递给切片numpy数组,例如a[b.next()]
,其中a是一个numpy数组。
class Batch():
def __init__(self, total, batch_size):
self.total = total
self.batch_size = batch_size
self.current = 0
def next(self):
max_index = self.current + self.batch_size
indices = [i if i < self.total else i - self.total
for i in range(self.current, max_index)]
self.current = max_index % self.total
return indices
b = Batch(10, 3)
print(b.next()) # [0, 1, 2]
print(b.next()) # [3, 4, 5]
print(b.next()) # [6, 7, 8]
print(b.next()) # [9, 0, 1]
print(b.next()) # [2, 3, 4]
print(b.next()) # [5, 6, 7]
答案 4 :(得分:0)
为避免错误“数组拆分不会导致等分”,
np.array_split(arr, n, axis=0)
比np.split(arr, n, axis=0)
更好。
例如,
a = np.array([[170, 52, 204],
[114, 235, 191],
[ 63, 145, 171],
[ 16, 97, 173]])
然后
print(np.array_split(a, 2))
[array([[170, 52, 204],
[114, 235, 191]]), array([[ 63, 145, 171],
[ 16, 97, 173]])]
print(np.array_split(a, 3))
[array([[170, 52, 204],
[114, 235, 191]]), array([[ 63, 145, 171]]), array([[ 16, 97, 173]])]
但是,由于print(np.array_split(a, 3))
不是整数,4/3
会引发错误。