我正在尝试某种方法,我不确定它是否应该以这种方式工作。即,我尝试从此处尝试以下脚本:[https://github.com/lazyprogrammer/machine_learning_examples/blob/master/nlp_class3/attention.py]来训练一些翻译模型。虽然这一切正常,但在获取更多数据(更大的数据集)进行训练时,我却停滞不前。我的想法是使用fit_generator
而不是fit
来批量输入,因为无法容纳内存中的数据。
我没有使用r = model.fit(...
(从354行开始),而是尝试了:
def generator(y, X, batch_size=100):
number_of_batches = 10 #samples_per_epoch/batch_size
counter = 0
csize = np.shape(y[0])[0]
shuffle_index = np.arange(csize)
print(csize, shuffle_index[0:6], shuffle_index[-6:-1])
np.random.shuffle(shuffle_index)
X = X[shuffle_index, :, :] #X.shape
y = [y[0][shuffle_index], y[1][shuffle_index], y[2][shuffle_index], y[3][shuffle_index]]
while 1:
index_batch = shuffle_index[batch_size*counter:batch_size*(counter+1)]
X_batch = X[index_batch, :, :] #X_batch.shape
y_batch = [y[0][index_batch], y[1][index_batch], y[2][index_batch], y[3][index_batch]]
counter += 1
print(counter, y[0][index_batch].shape, y[1][index_batch].shape, y[2][index_batch].shape, y[3][index_batch].shape, X_batch.shape)
yield(y_batch, X_batch)
if (counter < number_of_batches):
np.random.shuffle(shuffle_index)
counter=0
r = model.fit_generator(generator([encoder_inputs, decoder_inputs, z, z], decoder_targets_one_hot, 200),
epochs=EPOCHS, validation_freq=0.2, steps_per_epoch=10, shuffle=False)
但是出了点问题;我遇到了错误。这是最终的打印输出:
填充预训练的嵌入... C:\ Users \ cp \ Anaconda3 \ lib \ site-packages \ tensorflow_core \ python \ framework \ indexed_slices.py:433: 用户警告:将稀疏的IndexedSlices转换为一个密集的Tensor 形状未知。这可能会占用大量内存。
“将稀疏的IndexedSlices转换为形状未知的密集Tensor。” 时代1/50 C:\ Users \ cp \ Anaconda3 \ lib \ site-packages \ keras \ utils \ data_utils.py:718: UserWarning:无法检索输入。可能是因为 工人死亡。丢失的样本我们没有任何信息。
用户警告)1(200,255)(200,231)(200,256)(200,256)(200,231, 1880)1(200,255)(200,231)(200,256)(200,256)(200,231,1880)1 (200,255)(200,231)(200,256)(200,256)(200,231,1880)1(200, 255)(200、231)(200、256)(200、256)(200、231、1880)1(200、255) (200,231)(200,256)(200,256)(200,231,1880)1(200,255)(200, 231)(200,256)(200,256)(200,231,1880)1(200,255)(200,231) (200,256)(200,256)(200,231,1880)1(200,255)(200,231)(200, 256)(200、256)(200、231、1880)1(200、255)(200、231)(200、256) (200,256)(200,231,1880)
然后内核死亡,或者我必须停止它。它没有按预期帮助我,也没有限制内存使用。对于大型3D矩阵,甚至可以尝试np.memmap
,但情况甚至更糟。
我显然违反了一些我不知道的规则-这是我目前的估计,因此我可以放弃我的想法。还是我接近了?请给我一些提示。