Tensorflow: `use_multiprocessing=True` freezes model training

Date: 2021-02-24 11:46:45

Tags: python generator tensorflow2.0

I am training a model in tf 2.3.1 using a batch generator that supports multiprocessing. This was necessary because I previously got an error message that no longer seems to appear.

Training works when the use_multiprocessing option is set to False, but then no multiprocessing actually takes place and training is slow. When I set the option to True with workers=2, training never seems to start: it just hangs at Epoch 1/30. I even set batch_size=1 to keep memory usage as low as possible, just in case.
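For context, the Keras documentation warns that plain Python generators are not safe to use with use_multiprocessing=True and recommends a tf.keras.utils.Sequence instead, since a Sequence is indexable and each worker can fetch batches independently. A minimal sketch of that interface, shown here without the TensorFlow dependency (in real code the class would subclass tf.keras.utils.Sequence; the names WaveformSequence and path_list are illustrative, not from my code):

```python
import numpy as np

# Sketch of the Sequence-style interface Keras expects for
# use_multiprocessing=True: __len__ gives batches per epoch,
# __getitem__ returns batch number `idx`.
class WaveformSequence:
    def __init__(self, path_list, batch_size):
        self.path_list = path_list
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(len(self.path_list) / self.batch_size))

    def __getitem__(self, idx):
        # each worker can fetch batch `idx` independently of the others
        batch_paths = self.path_list[idx * self.batch_size:(idx + 1) * self.batch_size]
        # placeholder loading: real code would read and window the waveforms
        X = np.zeros((len(batch_paths), 512), dtype=np.float32)
        Y = np.zeros((len(batch_paths), 512), dtype=np.float32)
        return X, Y

seq = WaveformSequence([f"file_{i}.wav" for i in range(10)], batch_size=4)
print(len(seq))         # 3 batches per epoch
print(seq[0][0].shape)  # (4, 512)
print(seq[2][0].shape)  # (2, 512) — last, partial batch
```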

The generator:

import copy
import threading

import numpy as np

# helpers (define_sox, RunningStats, Tensorflow_FilterBank, update_list,
# make_in_out_wave, maxdim_to_mem_preall) are defined elsewhere in my code

class waveformGenerator:
    
    def __init__(self, waveform_path_list, batch_size, load_stats=None, window_size=None, iter_before_reset=10, task='train', Nfft=512, stride=0.25):
        self.waveform_path_list = waveform_path_list
        self.batch_size = batch_size
        self.load_stats = load_stats
        self.window_size = window_size
        self.iter_before_reset = iter_before_reset
        self.task = task
        self.Nfft = Nfft
        self.stride = stride
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        
        with self.lock:
        
            # batch counter
            batch_idx = 0

            # deepcopy X_paths_list
            updated_waveform_path_list = copy.deepcopy(self.waveform_path_list)

            # sox transformers
            tfm_downsample, tfm, tfm_y = define_sox()
            tfm_list = [tfm_downsample, tfm, tfm_y]

            # initialize runstats (real and imaginary part)
            runstats_x = RunningStats(1, np.float32)
            runstats_y = RunningStats(1, np.float32)

            # instantiate filterbank
            fb = Tensorflow_FilterBank(Nfft=self.Nfft, stride=self.stride)

            # calculate split waveforms lengths
            y_length = self.Nfft # one frame is equivalent to y_length samples
            x_length = self.window_size*y_length

            # define step size
            self.stepSize = np.round(self.stride * self.Nfft).astype(int)


            ### START ITERATIONS
            while True:

                # select batch_size unique random elements from path list, then update list
                updated_waveform_path_list, path_selection = update_list(updated_waveform_path_list, self.batch_size, self.task)  

                # fill waveforms lists
                x_array_list = []
                y_array_list = []

                for file_name in path_selection:
                    x_array, y_array = make_in_out_wave(file_name, tfm_list)
                    x_array_list.append(x_array)
                    y_array_list.append(y_array)

                # calculate memory to preallocate for each batch along time dimension
                _, max_index = maxdim_to_mem_preall(x_array_list)
                mem_preall_frames, _ = self.get_padded_waveform(x_array_list[max_index])

                # initialize tensors
                X_batch = np.zeros([len(x_array_list), mem_preall_frames, x_length])
                Y_batch = np.zeros([len(x_array_list), mem_preall_frames, y_length])

                ### CREATE DATA BATCHES ###
                for i in range(len(x_array_list)):

                    # select spectrograms
                    x = x_array_list[i]
                    y = y_array_list[i]

                    # divide waveforms into windows
                    x_frames, y_central, n_windows = self.make_waveforms_windows(x, y, x_length, y_length)

                    del x, y

                    # fill batch tensors
                    X_batch[i, :n_windows, :x_frames.shape[1]] = x_frames
                    Y_batch[i, :n_windows, :y_central.shape[1]] = y_central

                    del x_frames, y_central

                del x_array_list, y_array_list

                # mask zeros
                X_batch = np.ma.masked_equal(X_batch, 0)
                Y_batch = np.ma.masked_equal(Y_batch, 0)

                X_out = X_batch.reshape((X_batch.shape[0] * X_batch.shape[1], X_batch.shape[2]))
                Y_out = Y_batch.reshape((Y_batch.shape[0] * Y_batch.shape[1], Y_batch.shape[2]))

                del X_batch, Y_batch

                return X_out, Y_out ### yield or return?

                del X_out, Y_out

                # increase batch counter
                batch_idx += 1

                ### RESET GENERATOR
                if batch_idx == self.iter_before_reset:
                    batch_idx = 0
                    updated_waveform_path_list = copy.deepcopy(self.waveform_path_list)

                    if self.task == 'test':
                        break
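Regarding the `### yield or return?` comment in the code above: `return` inside `__next__` exits the method immediately, so everything after it (the `del`, the batch counter increment, the reset block) never runs, and the whole method body restarts from scratch on the next call. A `yield`-based generator, by contrast, resumes where it stopped. A tiny demo of the difference:

```python
# `return` in __next__ re-runs the whole method body on every call,
# while a generator function with `yield` resumes after the yield point.

class ReturnStyle:
    def __iter__(self):
        return self

    def __next__(self):
        counter = 0            # re-initialized on EVERY call
        while True:
            counter += 1
            return counter     # lines after this never run

def yield_style():
    counter = 0                # initialized once
    while True:
        counter += 1
        yield counter          # resumes here on the next call

r = iter(ReturnStyle())
print(next(r), next(r), next(r))  # 1 1 1 — state is lost between calls

g = yield_style()
print(next(g), next(g), next(g))  # 1 2 3 — state persists
```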

Creating the training and validation generators:

# create instance of feature_generator
datagen = waveformGenerator(train_dataset_list[:32], batch_size, load_stats=None, window_size=window_size, iter_before_reset=n_iter, task='train', Nfft=Nfft, stride=stride)
datagen_dev = waveformGenerator(dev_dataset_list, 8, load_stats=None, window_size=window_size, iter_before_reset=n_iter_dev, task='dev', Nfft=Nfft, stride=stride)

The fit call:

history = waveform_net.fit(datagen, steps_per_epoch=n_iter, epochs=n_epochs, validation_data=datagen_dev, validation_steps=n_iter_dev, callbacks=[checkpoint_best, tensorboard_callback], use_multiprocessing=False, verbose=1)
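One thing I am unsure about (this is an assumption, not something I have verified in the TensorFlow source): with use_multiprocessing=True, Keras may need to ship the generator object to worker processes, and my generator holds a threading.Lock, which cannot be pickled. A quick check:

```python
import pickle
import threading

# A class holding a lock, like the generator in the question
class HasLock:
    def __init__(self):
        self.lock = threading.Lock()

try:
    pickle.dumps(HasLock())
    result = "picklable"
except TypeError:
    # CPython raises TypeError: cannot pickle '_thread.lock' object
    result = "not picklable"

print(result)  # not picklable
```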

I have read that this could happen in tf 2.0, but since I am using a newer version I was hoping the bug had been fixed. Any ideas?

0 Answers