使用LSTM单元训练RNN时的RAM内存使用情况

时间:2020-01-22 10:40:38

标签: python keras lstm recurrent-neural-network

我正在跟踪有关递归神经网络的教程,并且正在训练RNN,以学习如何根据给定的字母序列来预测字母表中的下一个字母。问题是,我训练网络的每个时期,我的RAM使用率都在缓慢上升。我无法完成该网络的培训,因为我只有8192MB的RAM内存,并且在+-100个纪元后耗尽了。 这是为什么?我认为这与LSTM的工作方式有关,因为它们确实将一些信息保留在内存中,但是如果有人可以向我解释更多详细信息,那就太好了。

我使用的代码相对简单,并且完全独立(您可以复制/粘贴并运行它,不需要外部数据集,因为该数据集只是字母)。因此,我将其全部包含在内,因此该问题很容易重现。

我正在使用的tensorflow版本是1.14。

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.utils import np_utils
from keras_preprocessing.sequence import pad_sequences
np.random.seed(7)

# define the raw dataset
alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

# create mapping of characters to integers (0-25) and the reverse
char_to_int = dict((c, i) for i, c in enumerate(alphabet))
int_to_char = dict((i, c) for i, c in enumerate(alphabet))

num_inputs = 1000
max_len = 5
dataX = []
dataY = []
for i in range(num_inputs):
    start = np.random.randint(len(alphabet)-2)
    end = np.random.randint(start, min(start+max_len,len(alphabet)-1))
    sequence_in = alphabet[start:end+1]
    sequence_out = alphabet[end + 1]
    dataX.append([char_to_int[char] for char in sequence_in])
    dataY.append(char_to_int[sequence_out])
    print(sequence_in, "->" , sequence_out)

#Pad sequences with 0's, reshape X, then normalize data
X = pad_sequences(dataX, maxlen=max_len, dtype= "float32" )
X = np.reshape(X, (X.shape[0], max_len, 1))
X = X / float(len(alphabet))
print(X.shape)

#OHE the output variable.
y = np_utils.to_categorical(dataY)

#Create & fit the model
batch_size=1
model = Sequential()
model.add(LSTM(32, input_shape=(X.shape[1], 1)))
model.add(Dense(y.shape[1], activation= "softmax" ))
model.compile(loss= "categorical_crossentropy" , optimizer= "adam" , metrics=[ "accuracy" ])
model.fit(X, y, epochs=500, batch_size=batch_size, verbose=2)

1 个答案:

答案 0 :(得分:0)

问题是您的序列很长(1000个连续输入)。由于LSTM单元确实会在某个时期保持某种状态,并且您正尝试对其进行500个时期的训练(这很多),尤其是在CPU上进行训练时,RAM会随着时间的流逝而泛滥。我建议您尝试在GPU上进行训练,GPU具有自己的专用内存。另请检查以下问题:https://github.com/Element-Research/rnn/issues/5