如何使用fit_generator通过注意力机制(keras / tf)训练seq2seq模型?

时间:2020-03-11 20:26:16

标签: python-3.x tensorflow keras nlp attention-model

我正在尝试某种方法,我不确定它是否应该以这种方式工作。即,我尝试从此处尝试以下脚本:[https://github.com/lazyprogrammer/machine_learning_examples/blob/master/nlp_class3/attention.py]来训练一些翻译模型。虽然这一切正常,但在获取更多数据(更大的数据集)进行训练时,我却停滞不前。我的想法是使用fit_generator而不是fit来批量输入,因为无法容纳内存中的数据。

我没有使用r = model.fit(...(从354行开始),而是尝试了:

def generator(y, X, batch_size=100):
    number_of_batches = 10  #samples_per_epoch/batch_size
    counter = 0
    csize = np.shape(y[0])[0]    
    shuffle_index = np.arange(csize)
    print(csize, shuffle_index[0:6], shuffle_index[-6:-1])
    np.random.shuffle(shuffle_index)
    X =  X[shuffle_index, :, :]  #X.shape
    y =  [y[0][shuffle_index], y[1][shuffle_index], y[2][shuffle_index], y[3][shuffle_index]]
    while 1:
        index_batch = shuffle_index[batch_size*counter:batch_size*(counter+1)]
        X_batch = X[index_batch, :, :]  #X_batch.shape
        y_batch =  [y[0][index_batch], y[1][index_batch], y[2][index_batch], y[3][index_batch]]
        counter += 1
        print(counter, y[0][index_batch].shape, y[1][index_batch].shape, y[2][index_batch].shape, y[3][index_batch].shape, X_batch.shape)
        yield(y_batch, X_batch)
        if (counter < number_of_batches):
            np.random.shuffle(shuffle_index)
            counter=0

r = model.fit_generator(generator([encoder_inputs, decoder_inputs, z, z], decoder_targets_one_hot, 200),
                        epochs=EPOCHS, validation_freq=0.2, steps_per_epoch=10, shuffle=False)

但是出了点问题;我遇到了错误。这是最终的打印输出:

填充预训练的嵌入... C:\ Users \ cp \ Anaconda3 \ lib \ site-packages \ tensorflow_core \ python \ framework \ indexed_slices.py:433: 用户警告:将稀疏的IndexedSlices转换为一个密集的Tensor 形状未知。这可能会占用大量内存。
“将稀疏的IndexedSlices转换为形状未知的密集Tensor。” 时代1/50 C:\ Users \ cp \ Anaconda3 \ lib \ site-packages \ keras \ utils \ data_utils.py:718: UserWarning:无法检索输入。可能是因为 工人死亡。丢失的样本我们没有任何信息。
用户警告)1(200,255)(200,231)(200,256)(200,256)(200,231, 1880)1(200,255)(200,231)(200,256)(200,256)(200,231,1880)1 (200,255)(200,231)(200,256)(200,256)(200,231,1880)1(200, 255)(200、231)(200、256)(200、256)(200、231、1880)1(200、255) (200,231)(200,256)(200,256)(200,231,1880)1(200,255)(200, 231)(200,256)(200,256)(200,231,1880)1(200,255)(200,231) (200,256)(200,256)(200,231,1880)1(200,255)(200,231)(200, 256)(200、256)(200、231、1880)1(200、255)(200、231)(200、256) (200,256)(200,231,1880)

然后内核死亡,或者我必须停止它。它没有按预期帮助我,也没有限制内存使用。对于大型3D矩阵,甚至可以尝试np.memmap,但情况甚至更糟。

我显然违反了一些我不知道的规则-这是我目前的估计,因此我可以放弃我的想法。还是我接近了?请给我一些提示。

0 个答案:

没有答案