我一直在尝试使用Google的基于RNN的seq2seq model.
我一直在训练一个文本摘要模型,并提供大约1GB大小的文本数据。该模型快速填满我的整个RAM(8GB),开始填满交换内存(进一步8GB)和崩溃后我必须做一个硬关机。
我的LSTM网络的配置如下:
model: AttentionSeq2Seq
model_params:
attention.class: seq2seq.decoders.attention.AttentionLayerDot
attention.params:
num_units: 128
bridge.class: seq2seq.models.bridges.ZeroBridge
embedding.dim: 128
encoder.class: seq2seq.encoders.BidirectionalRNNEncoder
encoder.params:
rnn_cell:
cell_class: GRUCell
cell_params:
num_units: 128
dropout_input_keep_prob: 0.8
dropout_output_keep_prob: 1.0
num_layers: 1
decoder.class: seq2seq.decoders.AttentionDecoder
decoder.params:
rnn_cell:
cell_class: GRUCell
cell_params:
num_units: 128
dropout_input_keep_prob: 0.8
dropout_output_keep_prob: 1.0
num_layers: 1
optimizer.name: Adam
optimizer.params:
epsilon: 0.0000008
optimizer.learning_rate: 0.0001
source.max_seq_len: 50
source.reverse: false
target.max_seq_len: 50
我尝试将批量大小从32减少到16,但它仍然没有帮助。我应该做些什么具体的改变,以防止我的模型占用整个RAM并崩溃? (如减少数据大小,减少堆叠的LSTM单元数量,进一步减少批量大小等)
我的系统运行Python 2.7x,TensorFlow版本1.1.0和CUDA 8.0。该系统配备了Nvidia Geforce GTX-1050Ti(768 CUDA内核),内存为4GB,系统内存为8GB,另外还有8GB交换内存。
答案 0 :(得分:0)
你的模特看起来很小。唯一有点大的是火车数据。请检查以确保您的get_batch()
功能没有错误。如果存在错误,您可能实际上正在加载整个数据集以进行培训。
为了快速证明这一点,只需将训练数据大小减小到非常小的范围(例如当前大小的1/10),看看是否有帮助。请注意,它不应该有用,因为您使用的是迷你批处理。但如果可以解决问题,请修复您的private String tableDetails;
private String logpath;
功能。