我遵循了此处提供的Keras fit_generator
线程安全生成器的示例:https://keras.io/utils/#sequence
看来批次(idx
)的索引已锁定到每个线程。就我而言,我想将线程锁定到示例索引。这是我的实现:
class CustomGenerator():
def __init__(self):
self.input_ = np.arange(0, 1000)
self.labels = np.arange(0, 1000) * 0.1
self.batch_sz = 5
self.example_index = 0
def __len__(self):
return np.ceil(len(self.input_) / float(self.batch_sz))
def __getitem__(self, batch_idx):
batch_x = np.zeros(self.batch_sz)
batch_y = np.zeros(self.batch_sz)
row = 0
while row < self.batch_sz:
if self.example_index % 2 == 0:
batch_x[row] = self.input_[self.example_index]
batch_y[row] = self.labels[self.example_index]
row += 1
self.example_index += 1
return batch_x, batch_y
cg = CustomGenerator()
batch_idx = 0
while True:
print(cg.__getitem__(batch_idx))
batch_idx += 1
它输出正确的输出:
(array([0., 2., 4., 6., 8.]), array([0. , 0.2, 0.4, 0.6, 0.8]))
(array([10., 12., 14., 16., 18.]), array([1. , 1.2, 1.4, 1.6, 1.8]))
(array([20., 22., 24., 26., 28.]), array([2. , 2.2, 2.4, 2.6, 2.8]))
(array([30., 32., 34., 36., 38.]), array([3. , 3.2, 3.4, 3.6, 3.8]))
如何确保该实现以线程安全的方式工作,即在生成批处理时,不同的工作人员将不会使用相同的example_index
。