我正在使用支持多处理的批处理生成器在 tf 2.3.1 中训练模型。这是必需的,因为我之前收到了一条错误消息,该消息似乎不再出现。
但是,这在 use_multiprocessing 选项设置为 False 时有效。然而,这不应该实际使用多处理,而且训练过程也很慢。我尝试将选项设置为 True 和 workers=2,但训练似乎从未开始,因为它一直停留在 Epoch 1/30。我甚至设置了 batch_size=1,以尽量减少内存使用,以防万一。
生成器:
class waveformGenerator:
    """Thread-safe iterator that yields (X, Y) batches of windowed waveforms.

    Each call to ``__next__`` builds one batch: it picks ``batch_size`` unique
    paths, loads the input/target waveform pairs, splits them into fixed-size
    windows, packs them into zero-padded (and zero-masked) arrays, and returns
    the flattened pair ``(X_out, Y_out)``.

    Fixes versus the original paste:
      * ``Nfft``, ``stride``, ``iter_before_reset``, ``waveform_path_list`` and
        ``task`` were referenced without ``self.`` inside ``__next__``
        (NameError / wrong scope).
      * All per-epoch setup (deepcopy of the path list, sox transformers,
        running stats, filterbank) sat inside ``__next__``'s ``while True``
        loop before a ``return``, so every call re-ran the full setup and the
        batch-counter / reset / 'test'-termination code after ``return`` was
        unreachable. Setup is now done once (lazily), and the reset logic runs.
    """

    def __init__(self, waveform_path_list, batch_size, load_stats=None, window_size=None, iter_before_reset=10, task='train', Nfft=512, stride=0.25):
        self.waveform_path_list = waveform_path_list
        self.batch_size = batch_size
        self.load_stats = load_stats
        self.window_size = window_size
        self.iter_before_reset = iter_before_reset
        self.task = task
        self.Nfft = Nfft
        self.stride = stride
        # Serializes batch construction when Keras runs the generator from
        # several worker threads (fit(..., workers=N, use_multiprocessing=False)).
        self.lock = threading.Lock()
        # Per-epoch state is created lazily on the first __next__ call.
        self._initialized = False
        # Set once a 'test' run has exhausted its iterations.
        self._exhausted = False

    def __iter__(self):
        return self

    def _setup(self):
        """One-time initialization of the state that persists across batches."""
        # Batch counter; triggers a path-list reset at iter_before_reset.
        self.batch_idx = 0
        # Work on a copy so the caller's list is never consumed.
        self.updated_waveform_path_list = copy.deepcopy(self.waveform_path_list)
        # sox transformers (downsample, input, target).
        tfm_downsample, tfm, tfm_y = define_sox()
        self.tfm_list = [tfm_downsample, tfm, tfm_y]
        # Running statistics (real and imaginary part). NOTE(review): these and
        # the filterbank are instantiated but not used in the visible batch
        # path — kept to preserve any side effects; confirm they are needed.
        self.runstats_x = RunningStats(1, np.float32)
        self.runstats_y = RunningStats(1, np.float32)
        # Fixed: was `Nfft=Nfft, stride=stride` (unqualified names, NameError).
        self.fb = Tensorflow_FilterBank(Nfft=self.Nfft, stride=self.stride)
        # Hop size in samples between consecutive analysis windows.
        self.stepSize = np.round(self.stride * self.Nfft).astype(int)
        self._initialized = True

    def __next__(self):
        with self.lock:
            if self._exhausted:
                raise StopIteration
            if not self._initialized:
                self._setup()

            # One target frame is Nfft samples; the input window covers
            # window_size target frames.
            y_length = self.Nfft
            x_length = self.window_size * y_length

            # Draw batch_size unique paths, then load the (input, target)
            # waveform pair for each.
            self.updated_waveform_path_list, path_selection = update_list(
                self.updated_waveform_path_list, self.batch_size, self.task)
            x_array_list = []
            y_array_list = []
            for file_name in path_selection:
                x_array, y_array = make_in_out_wave(file_name, self.tfm_list)
                x_array_list.append(x_array)
                y_array_list.append(y_array)

            # Preallocate along the time dimension using the longest waveform
            # in the batch.
            _, max_index = maxdim_to_mem_preall(x_array_list)
            mem_preall_frames, _ = self.get_padded_waveform(x_array_list[max_index])
            X_batch = np.zeros([len(x_array_list), mem_preall_frames, x_length])
            Y_batch = np.zeros([len(x_array_list), mem_preall_frames, y_length])

            # Window each waveform pair and copy into the padded batch tensors.
            for i in range(len(x_array_list)):
                x_frames, y_central, n_windows = self.make_waveforms_windows(
                    x_array_list[i], y_array_list[i], x_length, y_length)
                X_batch[i, :n_windows, :x_frames.shape[1]] = x_frames
                Y_batch[i, :n_windows, :y_central.shape[1]] = y_central
            del x_array_list, y_array_list

            # Mask the zero padding, then flatten (batch, time) into one axis.
            X_batch = np.ma.masked_equal(X_batch, 0)
            Y_batch = np.ma.masked_equal(Y_batch, 0)
            X_out = X_batch.reshape((X_batch.shape[0] * X_batch.shape[1], X_batch.shape[2]))
            Y_out = Y_batch.reshape((Y_batch.shape[0] * Y_batch.shape[1], Y_batch.shape[2]))
            del X_batch, Y_batch

            # Advance the counter; after iter_before_reset batches restore the
            # full path list (and stop altogether for the 'test' task). This
            # logic was dead code in the original (it sat after `return`).
            self.batch_idx += 1
            if self.batch_idx == self.iter_before_reset:
                self.batch_idx = 0
                self.updated_waveform_path_list = copy.deepcopy(self.waveform_path_list)
                if self.task == 'test':
                    self._exhausted = True

            return X_out, Y_out
调用训练和验证生成器:
# create instance of feature_generator
# Training generator: only the first 32 paths; batch_size, window_size, n_iter,
# Nfft and stride come from earlier in the script (not shown here).
datagen = waveformGenerator(train_dataset_list[:32], batch_size, load_stats=None, window_size=window_size, iter_before_reset=n_iter, task='train', Nfft=Nfft, stride=stride)
# Validation generator: fixed batch size of 8, task 'dev' with its own iteration count.
datagen_dev = waveformGenerator(dev_dataset_list, 8, load_stats=None, window_size=window_size, iter_before_reset=n_iter_dev, task='dev', Nfft=Nfft, stride=stride)
拟合(fit)调用:
history = waveform_net.fit(datagen, steps_per_epoch=n_iter, epochs=n_epochs, validation_data=datagen_dev, validation_steps=n_iter_dev, callbacks=[checkpoint_best, tensorboard_callback], use_multiprocessing=False, verbose=1)
我读到这应该发生在 tf 2.0 中,但由于我使用的是更新版本,我希望这个错误已经得到修复。对此有什么想法吗?